summaryrefslogtreecommitdiff
path: root/mm/memcontrol.c
diff options
context:
space:
mode:
Diffstat (limited to 'mm/memcontrol.c')
-rw-r--r--mm/memcontrol.c25
1 files changed, 20 insertions, 5 deletions
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index afdb35cf4d0e..3801ac1fcfbc 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -678,7 +678,7 @@ struct mem_cgroup *mem_cgroup_from_task(struct task_struct *p)
}
EXPORT_SYMBOL(mem_cgroup_from_task);
-static struct mem_cgroup *get_mem_cgroup_from_mm(struct mm_struct *mm)
+struct mem_cgroup *get_mem_cgroup_from_mm(struct mm_struct *mm)
{
struct mem_cgroup *memcg = NULL;
@@ -701,6 +701,15 @@ static struct mem_cgroup *get_mem_cgroup_from_mm(struct mm_struct *mm)
return memcg;
}
+static struct mem_cgroup *get_mem_cgroup(struct mem_cgroup *memcg)
+{
+ rcu_read_lock();
+ if (!css_tryget_online(&memcg->css))
+ memcg = NULL;
+ rcu_read_unlock();
+ return memcg;
+}
+
/**
* mem_cgroup_iter - iterate over memory cgroup hierarchy
* @root: hierarchy root
@@ -2248,7 +2257,7 @@ static inline bool memcg_kmem_bypass(void)
*/
struct kmem_cache *memcg_kmem_get_cache(struct kmem_cache *cachep)
{
- struct mem_cgroup *memcg;
+ struct mem_cgroup *memcg = NULL;
struct kmem_cache *memcg_cachep;
int kmemcg_id;
@@ -2260,7 +2269,10 @@ struct kmem_cache *memcg_kmem_get_cache(struct kmem_cache *cachep)
if (current->memcg_kmem_skip_account)
return cachep;
- memcg = get_mem_cgroup_from_mm(current->mm);
+ if (current->target_memcg)
+ memcg = get_mem_cgroup(current->target_memcg);
+ if (!memcg)
+ memcg = get_mem_cgroup_from_mm(current->mm);
kmemcg_id = READ_ONCE(memcg->kmemcg_id);
if (kmemcg_id < 0)
goto out;
@@ -2338,13 +2350,16 @@ int memcg_kmem_charge_memcg(struct page *page, gfp_t gfp, int order,
*/
int memcg_kmem_charge(struct page *page, gfp_t gfp, int order)
{
- struct mem_cgroup *memcg;
+ struct mem_cgroup *memcg = NULL;
int ret = 0;
if (memcg_kmem_bypass())
return 0;
- memcg = get_mem_cgroup_from_mm(current->mm);
+ if (current->target_memcg)
+ memcg = get_mem_cgroup(current->target_memcg);
+ if (!memcg)
+ memcg = get_mem_cgroup_from_mm(current->mm);
if (!mem_cgroup_is_root(memcg)) {
ret = memcg_kmem_charge_memcg(page, gfp, order, memcg);
if (!ret)