diff options
-rw-r--r-- | drivers/iommu/amd/iommu.c | 42 |
1 files changed, 26 insertions, 16 deletions
diff --git a/drivers/iommu/amd/iommu.c b/drivers/iommu/amd/iommu.c index 8364cd6fa47d..6285fd1afd50 100644 --- a/drivers/iommu/amd/iommu.c +++ b/drivers/iommu/amd/iommu.c @@ -2265,28 +2265,36 @@ void protection_domain_free(struct protection_domain *domain) struct protection_domain *protection_domain_alloc(unsigned int type, int nid) { - struct io_pgtable_ops *pgtbl_ops; struct protection_domain *domain; - int pgtable; domain = kzalloc(sizeof(*domain), GFP_KERNEL); if (!domain) return NULL; domain->id = domain_id_alloc(); - if (!domain->id) - goto err_free; + if (!domain->id) { + kfree(domain); + return NULL; + } spin_lock_init(&domain->lock); INIT_LIST_HEAD(&domain->dev_list); INIT_LIST_HEAD(&domain->dev_data_list); domain->iop.pgtbl.cfg.amd.nid = nid; + return domain; +} + +static int pdom_setup_pgtable(struct protection_domain *domain, + unsigned int type) +{ + struct io_pgtable_ops *pgtbl_ops; + int pgtable; + switch (type) { /* No need to allocate io pgtable ops in passthrough mode */ case IOMMU_DOMAIN_IDENTITY: - case IOMMU_DOMAIN_SVA: - return domain; + return 0; case IOMMU_DOMAIN_DMA: pgtable = amd_iommu_pgtable; break; @@ -2298,7 +2306,7 @@ struct protection_domain *protection_domain_alloc(unsigned int type, int nid) pgtable = AMD_IOMMU_V1; break; default: - goto err_id; + return -EINVAL; } switch (pgtable) { @@ -2309,20 +2317,14 @@ struct protection_domain *protection_domain_alloc(unsigned int type, int nid) domain->pd_mode = PD_MODE_V2; break; default: - goto err_id; + return -EINVAL; } - pgtbl_ops = alloc_io_pgtable_ops(pgtable, &domain->iop.pgtbl.cfg, domain); if (!pgtbl_ops) - goto err_id; + return -ENOMEM; - return domain; -err_id: - domain_id_free(domain->id); -err_free: - kfree(domain); - return NULL; + return 0; } static inline u64 dma_max_address(void) @@ -2345,6 +2347,7 @@ static struct iommu_domain *do_iommu_domain_alloc(unsigned int type, bool dirty_tracking = flags & IOMMU_HWPT_ALLOC_DIRTY_TRACKING; struct protection_domain *domain; struct amd_iommu *iommu = NULL; + int ret; if (dev) iommu = get_amd_iommu_from_dev(dev); @@ -2364,6 +2367,13 @@ static struct iommu_domain *do_iommu_domain_alloc(unsigned int type, if (!domain) return ERR_PTR(-ENOMEM); + ret = pdom_setup_pgtable(domain, type); + if (ret) { + domain_id_free(domain->id); + kfree(domain); + return ERR_PTR(ret); + } + domain->domain.geometry.aperture_start = 0; domain->domain.geometry.aperture_end = dma_max_address(); domain->domain.geometry.force_aperture = true; |