summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--drivers/iommu/amd/iommu.c42
1 files changed, 26 insertions, 16 deletions
diff --git a/drivers/iommu/amd/iommu.c b/drivers/iommu/amd/iommu.c
index 8364cd6fa47d..6285fd1afd50 100644
--- a/drivers/iommu/amd/iommu.c
+++ b/drivers/iommu/amd/iommu.c
@@ -2265,28 +2265,36 @@ void protection_domain_free(struct protection_domain *domain)
struct protection_domain *protection_domain_alloc(unsigned int type, int nid)
{
- struct io_pgtable_ops *pgtbl_ops;
struct protection_domain *domain;
- int pgtable;
domain = kzalloc(sizeof(*domain), GFP_KERNEL);
if (!domain)
return NULL;
domain->id = domain_id_alloc();
- if (!domain->id)
- goto err_free;
+ if (!domain->id) {
+ kfree(domain);
+ return NULL;
+ }
spin_lock_init(&domain->lock);
INIT_LIST_HEAD(&domain->dev_list);
INIT_LIST_HEAD(&domain->dev_data_list);
domain->iop.pgtbl.cfg.amd.nid = nid;
+ return domain;
+}
+
+static int pdom_setup_pgtable(struct protection_domain *domain,
+ unsigned int type)
+{
+ struct io_pgtable_ops *pgtbl_ops;
+ int pgtable;
+
switch (type) {
/* No need to allocate io pgtable ops in passthrough mode */
case IOMMU_DOMAIN_IDENTITY:
- case IOMMU_DOMAIN_SVA:
- return domain;
+ return 0;
case IOMMU_DOMAIN_DMA:
pgtable = amd_iommu_pgtable;
break;
@@ -2298,7 +2306,7 @@ struct protection_domain *protection_domain_alloc(unsigned int type, int nid)
pgtable = AMD_IOMMU_V1;
break;
default:
- goto err_id;
+ return -EINVAL;
}
switch (pgtable) {
@@ -2309,20 +2317,14 @@ struct protection_domain *protection_domain_alloc(unsigned int type, int nid)
domain->pd_mode = PD_MODE_V2;
break;
default:
- goto err_id;
+ return -EINVAL;
}
-
pgtbl_ops =
alloc_io_pgtable_ops(pgtable, &domain->iop.pgtbl.cfg, domain);
if (!pgtbl_ops)
- goto err_id;
+ return -ENOMEM;
- return domain;
-err_id:
- domain_id_free(domain->id);
-err_free:
- kfree(domain);
- return NULL;
+ return 0;
}
static inline u64 dma_max_address(void)
@@ -2345,6 +2347,7 @@ static struct iommu_domain *do_iommu_domain_alloc(unsigned int type,
bool dirty_tracking = flags & IOMMU_HWPT_ALLOC_DIRTY_TRACKING;
struct protection_domain *domain;
struct amd_iommu *iommu = NULL;
+ int ret;
if (dev)
iommu = get_amd_iommu_from_dev(dev);
@@ -2364,6 +2367,13 @@ static struct iommu_domain *do_iommu_domain_alloc(unsigned int type,
if (!domain)
return ERR_PTR(-ENOMEM);
+ ret = pdom_setup_pgtable(domain, type);
+ if (ret) {
+ domain_id_free(domain->id);
+ kfree(domain);
+ return ERR_PTR(ret);
+ }
+
domain->domain.geometry.aperture_start = 0;
domain->domain.geometry.aperture_end = dma_max_address();
domain->domain.geometry.force_aperture = true;