diff --git a/include/linux/ksm.h b/include/linux/ksm.h
index a485c14ecd5d52f892cbfe5ca3e22f0463c42a48..2d64ff30c0de44e1b62f172a5cf4af8094aacc65 100644
--- a/include/linux/ksm.h
+++ b/include/linux/ksm.h
@@ -12,11 +12,14 @@
 #include <linux/sched.h>
 #include <linux/vmstat.h>
 
+struct mmu_gather;
+
 #ifdef CONFIG_KSM
 int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
 		unsigned long end, int advice, unsigned long *vm_flags);
 int __ksm_enter(struct mm_struct *mm);
-void __ksm_exit(struct mm_struct *mm);
+void __ksm_exit(struct mm_struct *mm,
+		struct mmu_gather **tlbp, unsigned long end);
 
 static inline int ksm_fork(struct mm_struct *mm, struct mm_struct *oldmm)
 {
@@ -25,10 +28,24 @@ static inline int ksm_fork(struct mm_struct *mm, struct mm_struct *oldmm)
 	return 0;
 }
 
-static inline void ksm_exit(struct mm_struct *mm)
+/*
+ * For KSM to handle OOM without deadlock when it's breaking COW in a
+ * likely victim of the OOM killer, exit_mmap() has to serialize with
+ * ksm_exit() after freeing mm's pages but before freeing its page tables.
+ * That leaves a window in which KSM might refault pages which have just
+ * been finally unmapped: guard against that with ksm_test_exit(), and
+ * use it after getting mmap_sem in ksm.c, to check if mm is exiting.
+ */
+static inline bool ksm_test_exit(struct mm_struct *mm)
+{
+	return atomic_read(&mm->mm_users) == 0;
+}
+
+static inline void ksm_exit(struct mm_struct *mm,
+			    struct mmu_gather **tlbp, unsigned long end)
 {
 	if (test_bit(MMF_VM_MERGEABLE, &mm->flags))
-		__ksm_exit(mm);
+		__ksm_exit(mm, tlbp, end);
 }
 
 /*
@@ -64,7 +81,13 @@ static inline int ksm_fork(struct mm_struct *mm, struct mm_struct *oldmm)
 	return 0;
 }
 
-static inline void ksm_exit(struct mm_struct *mm)
+static inline bool ksm_test_exit(struct mm_struct *mm)
+{
+	return 0;
+}
+
+static inline void ksm_exit(struct mm_struct *mm,
+			    struct mmu_gather **tlbp, unsigned long end)
 {
 }
 
diff --git a/kernel/fork.c b/kernel/fork.c
index 73a442b7be6dff8a2abce0d3537a3cd9cb0db910..42f20f565b16b96786a98042798230ec2080007e 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -501,7 +501,6 @@ void mmput(struct mm_struct *mm)
 
 	if (atomic_dec_and_test(&mm->mm_users)) {
 		exit_aio(mm);
-		ksm_exit(mm);
 		exit_mmap(mm);
 		set_mm_exe_file(mm, NULL);
 		if (!list_empty(&mm->mmlist)) {
diff --git a/mm/ksm.c b/mm/ksm.c
index 7e4d255dadc0609c51e36951e6c22f046ccba234..722e3f2a8dc5b56f768287096e6787c57b6f2469 100644
--- a/mm/ksm.c
+++ b/mm/ksm.c
@@ -32,6 +32,7 @@
 #include <linux/mmu_notifier.h>
 #include <linux/ksm.h>
 
+#include <asm/tlb.h>
 #include <asm/tlbflush.h>
 
 /*
@@ -347,6 +348,8 @@ static void break_cow(struct mm_struct *mm, unsigned long addr)
 	struct vm_area_struct *vma;
 
 	down_read(&mm->mmap_sem);
+	if (ksm_test_exit(mm))
+		goto out;
 	vma = find_vma(mm, addr);
 	if (!vma || vma->vm_start > addr)
 		goto out;
@@ -365,6 +368,8 @@ static struct page *get_mergeable_page(struct rmap_item *rmap_item)
 	struct page *page;
 
 	down_read(&mm->mmap_sem);
+	if (ksm_test_exit(mm))
+		goto out;
 	vma = find_vma(mm, addr);
 	if (!vma || vma->vm_start > addr)
 		goto out;
@@ -439,11 +444,11 @@ static void remove_rmap_item_from_tree(struct rmap_item *rmap_item)
 	} else if (rmap_item->address & NODE_FLAG) {
 		unsigned char age;
 		/*
-		 * ksm_thread can and must skip the rb_erase, because
+		 * Usually ksmd can and must skip the rb_erase, because
 		 * root_unstable_tree was already reset to RB_ROOT.
-		 * But __ksm_exit has to be careful: do the rb_erase
-		 * if it's interrupting a scan, and this rmap_item was
-		 * inserted by this scan rather than left from before.
+		 * But be careful when an mm is exiting: do the rb_erase
+		 * if this rmap_item was inserted by this scan, rather
+		 * than left over from before.
 		 */
 		age = (unsigned char)(ksm_scan.seqnr - rmap_item->address);
 		BUG_ON(age > 1);
@@ -491,6 +496,8 @@ static int unmerge_ksm_pages(struct vm_area_struct *vma,
 	int err = 0;
 
 	for (addr = start; addr < end && !err; addr += PAGE_SIZE) {
+		if (ksm_test_exit(vma->vm_mm))
+			break;
 		if (signal_pending(current))
 			err = -ERESTARTSYS;
 		else
@@ -507,34 +514,50 @@ static int unmerge_and_remove_all_rmap_items(void)
 	int err = 0;
 
 	spin_lock(&ksm_mmlist_lock);
-	mm_slot = list_entry(ksm_mm_head.mm_list.next,
+	ksm_scan.mm_slot = list_entry(ksm_mm_head.mm_list.next,
 						struct mm_slot, mm_list);
 	spin_unlock(&ksm_mmlist_lock);
 
-	while (mm_slot != &ksm_mm_head) {
+	for (mm_slot = ksm_scan.mm_slot;
+			mm_slot != &ksm_mm_head; mm_slot = ksm_scan.mm_slot) {
 		mm = mm_slot->mm;
 		down_read(&mm->mmap_sem);
 		for (vma = mm->mmap; vma; vma = vma->vm_next) {
+			if (ksm_test_exit(mm))
+				break;
 			if (!(vma->vm_flags & VM_MERGEABLE) || !vma->anon_vma)
 				continue;
 			err = unmerge_ksm_pages(vma,
 						vma->vm_start, vma->vm_end);
-			if (err) {
-				up_read(&mm->mmap_sem);
-				goto out;
-			}
+			if (err)
+				goto error;
 		}
+
 		remove_trailing_rmap_items(mm_slot, mm_slot->rmap_list.next);
-		up_read(&mm->mmap_sem);
 
 		spin_lock(&ksm_mmlist_lock);
-		mm_slot = list_entry(mm_slot->mm_list.next,
+		ksm_scan.mm_slot = list_entry(mm_slot->mm_list.next,
 						struct mm_slot, mm_list);
-		spin_unlock(&ksm_mmlist_lock);
+		if (ksm_test_exit(mm)) {
+			hlist_del(&mm_slot->link);
+			list_del(&mm_slot->mm_list);
+			spin_unlock(&ksm_mmlist_lock);
+
+			free_mm_slot(mm_slot);
+			clear_bit(MMF_VM_MERGEABLE, &mm->flags);
+			up_read(&mm->mmap_sem);
+			mmdrop(mm);
+		} else {
+			spin_unlock(&ksm_mmlist_lock);
+			up_read(&mm->mmap_sem);
+		}
 	}
 
 	ksm_scan.seqnr = 0;
-out:
+	return 0;
+
+error:
+	up_read(&mm->mmap_sem);
 	spin_lock(&ksm_mmlist_lock);
 	ksm_scan.mm_slot = &ksm_mm_head;
 	spin_unlock(&ksm_mmlist_lock);
@@ -755,6 +778,9 @@ static int try_to_merge_with_ksm_page(struct mm_struct *mm1,
 	int err = -EFAULT;
 
 	down_read(&mm1->mmap_sem);
+	if (ksm_test_exit(mm1))
+		goto out;
+
 	vma = find_vma(mm1, addr1);
 	if (!vma || vma->vm_start > addr1)
 		goto out;
@@ -796,6 +822,10 @@ static int try_to_merge_two_pages(struct mm_struct *mm1, unsigned long addr1,
 		return err;
 
 	down_read(&mm1->mmap_sem);
+	if (ksm_test_exit(mm1)) {
+		up_read(&mm1->mmap_sem);
+		goto out;
+	}
 	vma = find_vma(mm1, addr1);
 	if (!vma || vma->vm_start > addr1) {
 		up_read(&mm1->mmap_sem);
@@ -1174,7 +1204,12 @@ static struct rmap_item *scan_get_next_rmap_item(struct page **page)
 
 	mm = slot->mm;
 	down_read(&mm->mmap_sem);
-	for (vma = find_vma(mm, ksm_scan.address); vma; vma = vma->vm_next) {
+	if (ksm_test_exit(mm))
+		vma = NULL;
+	else
+		vma = find_vma(mm, ksm_scan.address);
+
+	for (; vma; vma = vma->vm_next) {
 		if (!(vma->vm_flags & VM_MERGEABLE))
 			continue;
 		if (ksm_scan.address < vma->vm_start)
@@ -1183,6 +1218,8 @@ static struct rmap_item *scan_get_next_rmap_item(struct page **page)
 			ksm_scan.address = vma->vm_end;
 
 		while (ksm_scan.address < vma->vm_end) {
+			if (ksm_test_exit(mm))
+				break;
 			*page = follow_page(vma, ksm_scan.address, FOLL_GET);
 			if (*page && PageAnon(*page)) {
 				flush_anon_page(vma, *page, ksm_scan.address);
@@ -1205,6 +1242,11 @@ static struct rmap_item *scan_get_next_rmap_item(struct page **page)
 		}
 	}
 
+	if (ksm_test_exit(mm)) {
+		ksm_scan.address = 0;
+		ksm_scan.rmap_item = list_entry(&slot->rmap_list,
+						struct rmap_item, link);
+	}
 	/*
 	 * Nuke all the rmap_items that are above this current rmap:
 	 * because there were no VM_MERGEABLE vmas with such addresses.
@@ -1219,24 +1261,29 @@ static struct rmap_item *scan_get_next_rmap_item(struct page **page)
 		 * We've completed a full scan of all vmas, holding mmap_sem
 		 * throughout, and found no VM_MERGEABLE: so do the same as
 		 * __ksm_exit does to remove this mm from all our lists now.
+		 * This applies either when cleaning up after __ksm_exit
+		 * (but beware: we can reach here even before __ksm_exit),
+		 * or when all VM_MERGEABLE areas have been unmapped (and
+		 * mmap_sem then protects against race with MADV_MERGEABLE).
 		 */
 		hlist_del(&slot->link);
 		list_del(&slot->mm_list);
+		spin_unlock(&ksm_mmlist_lock);
+
 		free_mm_slot(slot);
 		clear_bit(MMF_VM_MERGEABLE, &mm->flags);
+		up_read(&mm->mmap_sem);
+		mmdrop(mm);
+	} else {
+		spin_unlock(&ksm_mmlist_lock);
+		up_read(&mm->mmap_sem);
 	}
-	spin_unlock(&ksm_mmlist_lock);
-	up_read(&mm->mmap_sem);
 
 	/* Repeat until we've completed scanning the whole list */
 	slot = ksm_scan.mm_slot;
 	if (slot != &ksm_mm_head)
 		goto next_mm;
 
-	/*
-	 * Bump seqnr here rather than at top, so that __ksm_exit
-	 * can skip rb_erase on unstable tree until we run again.
-	 */
 	ksm_scan.seqnr++;
 	return NULL;
 }
@@ -1361,6 +1408,7 @@ int __ksm_enter(struct mm_struct *mm)
 	spin_unlock(&ksm_mmlist_lock);
 
 	set_bit(MMF_VM_MERGEABLE, &mm->flags);
+	atomic_inc(&mm->mm_count);
 
 	if (needs_wakeup)
 		wake_up_interruptible(&ksm_thread_wait);
@@ -1368,41 +1416,45 @@ int __ksm_enter(struct mm_struct *mm)
 	return 0;
 }
 
-void __ksm_exit(struct mm_struct *mm)
+void __ksm_exit(struct mm_struct *mm,
+		struct mmu_gather **tlbp, unsigned long end)
 {
 	struct mm_slot *mm_slot;
+	int easy_to_free = 0;
 
 	/*
-	 * This process is exiting: doesn't hold and doesn't need mmap_sem;
-	 * but we do need to exclude ksmd and other exiters while we modify
-	 * the various lists and trees.
+	 * This process is exiting: if it's straightforward (as is the
+	 * case when ksmd was never running), free mm_slot immediately.
+	 * But if it's at the cursor or has rmap_items linked to it, use
+	 * mmap_sem to synchronize with any break_cows before pagetables
+	 * are freed, and leave the mm_slot on the list for ksmd to free.
+	 * Beware: ksm may already have noticed it exiting and freed the slot.
 	 */
-	mutex_lock(&ksm_thread_mutex);
+
 	spin_lock(&ksm_mmlist_lock);
 	mm_slot = get_mm_slot(mm);
-	if (!list_empty(&mm_slot->rmap_list)) {
-		spin_unlock(&ksm_mmlist_lock);
-		remove_trailing_rmap_items(mm_slot, mm_slot->rmap_list.next);
-		spin_lock(&ksm_mmlist_lock);
-	}
-
-	if (ksm_scan.mm_slot == mm_slot) {
-		ksm_scan.mm_slot = list_entry(
-			mm_slot->mm_list.next, struct mm_slot, mm_list);
-		ksm_scan.address = 0;
-		ksm_scan.rmap_item = list_entry(
-			&ksm_scan.mm_slot->rmap_list, struct rmap_item, link);
-		if (ksm_scan.mm_slot == &ksm_mm_head)
-			ksm_scan.seqnr++;
+	if (mm_slot && ksm_scan.mm_slot != mm_slot) {
+		if (list_empty(&mm_slot->rmap_list)) {
+			hlist_del(&mm_slot->link);
+			list_del(&mm_slot->mm_list);
+			easy_to_free = 1;
+		} else {
+			list_move(&mm_slot->mm_list,
+				  &ksm_scan.mm_slot->mm_list);
+		}
 	}
-
-	hlist_del(&mm_slot->link);
-	list_del(&mm_slot->mm_list);
 	spin_unlock(&ksm_mmlist_lock);
 
-	free_mm_slot(mm_slot);
-	clear_bit(MMF_VM_MERGEABLE, &mm->flags);
-	mutex_unlock(&ksm_thread_mutex);
+	if (easy_to_free) {
+		free_mm_slot(mm_slot);
+		clear_bit(MMF_VM_MERGEABLE, &mm->flags);
+		mmdrop(mm);
+	} else if (mm_slot) {
+		tlb_finish_mmu(*tlbp, 0, end);
+		down_write(&mm->mmap_sem);
+		up_write(&mm->mmap_sem);
+		*tlbp = tlb_gather_mmu(mm, 1);
+	}
 }
 
 #define KSM_ATTR_RO(_name) \
diff --git a/mm/memory.c b/mm/memory.c
index 1a435b81876c3c3c8707fdcc4ec1f43495621c9f..f47ffe9710122451e065fcbdefdd4617fa7db4e5 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -2648,8 +2648,9 @@ static int do_anonymous_page(struct mm_struct *mm, struct vm_area_struct *vma,
 	entry = maybe_mkwrite(pte_mkdirty(entry), vma);
 
 	page_table = pte_offset_map_lock(mm, pmd, address, &ptl);
-	if (!pte_none(*page_table))
+	if (!pte_none(*page_table) || ksm_test_exit(mm))
 		goto release;
+
 	inc_mm_counter(mm, anon_rss);
 	page_add_new_anon_rmap(page, vma, address);
 	set_pte_at(mm, address, page_table, entry);
@@ -2791,7 +2792,7 @@ static int __do_fault(struct mm_struct *mm, struct vm_area_struct *vma,
 	 * handle that later.
 	 */
 	/* Only go through if we didn't race with anybody else... */
-	if (likely(pte_same(*page_table, orig_pte))) {
+	if (likely(pte_same(*page_table, orig_pte) && !ksm_test_exit(mm))) {
 		flush_icache_page(vma, page);
 		entry = mk_pte(page, vma->vm_page_prot);
 		if (flags & FAULT_FLAG_WRITE)
diff --git a/mm/mmap.c b/mm/mmap.c
index 376492ed08f4a66cbc782171cc0924adb1287018..e02f1aa66a1a8176aa953152bf40bc4112e3f26c 100644
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -27,6 +27,7 @@
 #include <linux/mount.h>
 #include <linux/mempolicy.h>
 #include <linux/rmap.h>
+#include <linux/ksm.h>
 #include <linux/mmu_notifier.h>
 #include <linux/perf_event.h>
 
@@ -2111,6 +2112,14 @@ void exit_mmap(struct mm_struct *mm)
 	/* Use -1 here to ensure all VMAs in the mm are unmapped */
 	end = unmap_vmas(&tlb, vma, 0, -1, &nr_accounted, NULL);
 	vm_unacct_memory(nr_accounted);
+
+	/*
+	 * For KSM to handle OOM without deadlock when it's breaking COW in a
+	 * likely victim of the OOM killer, we must serialize with ksm_exit()
+	 * after freeing mm's pages but before freeing its page tables.
+	 */
+	ksm_exit(mm, &tlb, end);
+
 	free_pgtables(tlb, vma, FIRST_USER_ADDRESS, 0);
 	tlb_finish_mmu(tlb, 0, end);