diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
index 27123e597eca0296361cec929eb86e4c169453d9..6e0126230878621e56e3f6883898448a9e66a7e1 100644
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -163,6 +163,7 @@ struct mem_cgroup {
 
 	/* Accounted resources */
 	struct page_counter memory;
+	struct page_counter swap;
 
 	/* Legacy consumer-oriented counters */
 	struct page_counter memsw;
diff --git a/include/linux/swap.h b/include/linux/swap.h
index 414e101cd06195fe60339e71a8dfd2487001df7e..83b95f343ab10b25b09ca38fac839a87ac3f095e 100644
--- a/include/linux/swap.h
+++ b/include/linux/swap.h
@@ -368,11 +368,17 @@ static inline int mem_cgroup_swappiness(struct mem_cgroup *mem)
 #endif
 #ifdef CONFIG_MEMCG_SWAP
 extern void mem_cgroup_swapout(struct page *page, swp_entry_t entry);
+extern int mem_cgroup_try_charge_swap(struct page *page, swp_entry_t entry);
 extern void mem_cgroup_uncharge_swap(swp_entry_t entry);
 #else
 static inline void mem_cgroup_swapout(struct page *page, swp_entry_t entry)
 {
 }
+static inline int mem_cgroup_try_charge_swap(struct page *page,
+					     swp_entry_t entry)
+{
+	return 0;
+}
 static inline void mem_cgroup_uncharge_swap(swp_entry_t entry)
 {
 }
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index f6bc78f4ed137bd4fd3b46b439170b6ebc485bd0..1ff552e3722b0eb8eafb609d40ad3a7cccd18311 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -1220,7 +1220,7 @@ void mem_cgroup_print_oom_info(struct mem_cgroup *memcg, struct task_struct *p)
 		pr_cont(":");
 
 		for (i = 0; i < MEM_CGROUP_STAT_NSTATS; i++) {
-			if (i == MEM_CGROUP_STAT_SWAP && !do_memsw_account())
+			if (i == MEM_CGROUP_STAT_SWAP && !do_swap_account)
 				continue;
 			pr_cont(" %s:%luKB", mem_cgroup_stat_names[i],
 				K(mem_cgroup_read_stat(iter, i)));
@@ -1259,9 +1259,12 @@ static unsigned long mem_cgroup_get_limit(struct mem_cgroup *memcg)
 	limit = memcg->memory.limit;
 	if (mem_cgroup_swappiness(memcg)) {
 		unsigned long memsw_limit;
+		unsigned long swap_limit;
 
 		memsw_limit = memcg->memsw.limit;
-		limit = min(limit + total_swap_pages, memsw_limit);
+		swap_limit = memcg->swap.limit;
+		swap_limit = min(swap_limit, (unsigned long)total_swap_pages);
+		limit = min(limit + swap_limit, memsw_limit);
 	}
 	return limit;
 }
@@ -4201,11 +4204,13 @@ mem_cgroup_css_alloc(struct cgroup_subsys_state *parent_css)
 	if (parent && parent->use_hierarchy) {
 		memcg->use_hierarchy = true;
 		page_counter_init(&memcg->memory, &parent->memory);
+		page_counter_init(&memcg->swap, &parent->swap);
 		page_counter_init(&memcg->memsw, &parent->memsw);
 		page_counter_init(&memcg->kmem, &parent->kmem);
 		page_counter_init(&memcg->tcpmem, &parent->tcpmem);
 	} else {
 		page_counter_init(&memcg->memory, NULL);
+		page_counter_init(&memcg->swap, NULL);
 		page_counter_init(&memcg->memsw, NULL);
 		page_counter_init(&memcg->kmem, NULL);
 		page_counter_init(&memcg->tcpmem, NULL);
@@ -5224,7 +5229,7 @@ int mem_cgroup_try_charge(struct page *page, struct mm_struct *mm,
 		if (page->mem_cgroup)
 			goto out;
 
-		if (do_memsw_account()) {
+		if (do_swap_account) {
 			swp_entry_t ent = { .val = page_private(page), };
 			unsigned short id = lookup_swap_cgroup_id(ent);
 
@@ -5677,26 +5682,66 @@ void mem_cgroup_swapout(struct page *page, swp_entry_t entry)
 	memcg_check_events(memcg, page);
 }
 
+/*
+ * mem_cgroup_try_charge_swap - try charging a swap entry
+ * @page: page being added to swap
+ * @entry: swap entry to charge
+ *
+ * Try to charge @entry to the memcg that @page belongs to.
+ *
+ * Returns 0 on success, -ENOMEM on failure.
+ */
+int mem_cgroup_try_charge_swap(struct page *page, swp_entry_t entry)
+{
+	struct mem_cgroup *memcg;
+	struct page_counter *counter;
+	unsigned short oldid;
+
+	if (!cgroup_subsys_on_dfl(memory_cgrp_subsys) || !do_swap_account)
+		return 0;
+
+	memcg = page->mem_cgroup;
+
+	/* Readahead page, never charged */
+	if (!memcg)
+		return 0;
+
+	if (!mem_cgroup_is_root(memcg) &&
+	    !page_counter_try_charge(&memcg->swap, 1, &counter))
+		return -ENOMEM;
+
+	oldid = swap_cgroup_record(entry, mem_cgroup_id(memcg));
+	VM_BUG_ON_PAGE(oldid, page);
+	mem_cgroup_swap_statistics(memcg, true);
+
+	css_get(&memcg->css);
+	return 0;
+}
+
 /**
  * mem_cgroup_uncharge_swap - uncharge a swap entry
  * @entry: swap entry to uncharge
  *
- * Drop the memsw charge associated with @entry.
+ * Drop the swap charge associated with @entry.
  */
 void mem_cgroup_uncharge_swap(swp_entry_t entry)
 {
 	struct mem_cgroup *memcg;
 	unsigned short id;
 
-	if (!do_memsw_account())
+	if (!do_swap_account)
 		return;
 
 	id = swap_cgroup_record(entry, 0);
 	rcu_read_lock();
 	memcg = mem_cgroup_from_id(id);
 	if (memcg) {
-		if (!mem_cgroup_is_root(memcg))
-			page_counter_uncharge(&memcg->memsw, 1);
+		if (!mem_cgroup_is_root(memcg)) {
+			if (cgroup_subsys_on_dfl(memory_cgrp_subsys))
+				page_counter_uncharge(&memcg->swap, 1);
+			else
+				page_counter_uncharge(&memcg->memsw, 1);
+		}
 		mem_cgroup_swap_statistics(memcg, false);
 		css_put(&memcg->css);
 	}
@@ -5720,6 +5765,63 @@ static int __init enable_swap_account(char *s)
 }
 __setup("swapaccount=", enable_swap_account);
 
+static u64 swap_current_read(struct cgroup_subsys_state *css,
+			     struct cftype *cft)
+{
+	struct mem_cgroup *memcg = mem_cgroup_from_css(css);
+
+	return (u64)page_counter_read(&memcg->swap) * PAGE_SIZE;
+}
+
+static int swap_max_show(struct seq_file *m, void *v)
+{
+	struct mem_cgroup *memcg = mem_cgroup_from_css(seq_css(m));
+	unsigned long max = READ_ONCE(memcg->swap.limit);
+
+	if (max == PAGE_COUNTER_MAX)
+		seq_puts(m, "max\n");
+	else
+		seq_printf(m, "%llu\n", (u64)max * PAGE_SIZE);
+
+	return 0;
+}
+
+static ssize_t swap_max_write(struct kernfs_open_file *of,
+			      char *buf, size_t nbytes, loff_t off)
+{
+	struct mem_cgroup *memcg = mem_cgroup_from_css(of_css(of));
+	unsigned long max;
+	int err;
+
+	buf = strstrip(buf);
+	err = page_counter_memparse(buf, "max", &max);
+	if (err)
+		return err;
+
+	mutex_lock(&memcg_limit_mutex);
+	err = page_counter_limit(&memcg->swap, max);
+	mutex_unlock(&memcg_limit_mutex);
+	if (err)
+		return err;
+
+	return nbytes;
+}
+
+static struct cftype swap_files[] = {
+	{
+		.name = "swap.current",
+		.flags = CFTYPE_NOT_ON_ROOT,
+		.read_u64 = swap_current_read,
+	},
+	{
+		.name = "swap.max",
+		.flags = CFTYPE_NOT_ON_ROOT,
+		.seq_show = swap_max_show,
+		.write = swap_max_write,
+	},
+	{ }	/* terminate */
+};
+
 static struct cftype memsw_cgroup_files[] = {
 	{
 		.name = "memsw.usage_in_bytes",
@@ -5751,6 +5853,8 @@ static int __init mem_cgroup_swap_init(void)
 {
 	if (!mem_cgroup_disabled() && really_do_swap_account) {
 		do_swap_account = 1;
+		WARN_ON(cgroup_add_dfl_cftypes(&memory_cgrp_subsys,
+					       swap_files));
 		WARN_ON(cgroup_add_legacy_cftypes(&memory_cgrp_subsys,
 						  memsw_cgroup_files));
 	}
diff --git a/mm/shmem.c b/mm/shmem.c
index b98e1011858cdefc67108291a3717a48d0324096..fa2ceb2d2655dbd8e06ecd4e02384df5ad5b5891 100644
--- a/mm/shmem.c
+++ b/mm/shmem.c
@@ -912,6 +912,9 @@ static int shmem_writepage(struct page *page, struct writeback_control *wbc)
 	if (!swap.val)
 		goto redirty;
 
+	if (mem_cgroup_try_charge_swap(page, swap))
+		goto free_swap;
+
 	/*
 	 * Add inode to shmem_unuse()'s list of swapped-out inodes,
 	 * if it's not already there.  Do it now before the page is
@@ -940,6 +943,7 @@ static int shmem_writepage(struct page *page, struct writeback_control *wbc)
 	}
 
 	mutex_unlock(&shmem_swaplist_mutex);
+free_swap:
 	swapcache_free(swap);
 redirty:
 	set_page_dirty(page);
diff --git a/mm/swap_state.c b/mm/swap_state.c
index 676ff2991380120275ba5d81d9660592bdc15b75..69cb2464e7dcd598dcf18e8601f97c1c4f8bb627 100644
--- a/mm/swap_state.c
+++ b/mm/swap_state.c
@@ -170,6 +170,11 @@ int add_to_swap(struct page *page, struct list_head *list)
 	if (!entry.val)
 		return 0;
 
+	if (mem_cgroup_try_charge_swap(page, entry)) {
+		swapcache_free(entry);
+		return 0;
+	}
+
 	if (unlikely(PageTransHuge(page)))
 		if (unlikely(split_huge_page_to_list(page, list))) {
 			swapcache_free(entry);
diff --git a/mm/swapfile.c b/mm/swapfile.c
index 2bb30aa3a4123a547bd29e1585a1234644a92152..22a7a1fc1e478bdb0b6171057855f1dc48ac2d06 100644
--- a/mm/swapfile.c
+++ b/mm/swapfile.c
@@ -785,14 +785,12 @@ static unsigned char swap_entry_free(struct swap_info_struct *p,
 			count--;
 	}
 
-	if (!count)
-		mem_cgroup_uncharge_swap(entry);
-
 	usage = count | has_cache;
 	p->swap_map[offset] = usage;
 
 	/* free if no reference */
 	if (!usage) {
+		mem_cgroup_uncharge_swap(entry);
 		dec_cluster_info_page(p, p->cluster_info, offset);
 		if (offset < p->lowest_bit)
 			p->lowest_bit = offset;