diff --git a/include/linux/hmm.h b/include/linux/hmm.h index bf013e965257..0fa8ea34ccef 100644 --- a/include/linux/hmm.h +++ b/include/linux/hmm.h @@ -86,7 +86,7 @@ struct hmm { struct mm_struct *mm; struct kref kref; - struct mutex lock; + spinlock_t ranges_lock; struct list_head ranges; struct list_head mirrors; struct mmu_notifier mmu_notifier; diff --git a/mm/hmm.c b/mm/hmm.c index b224ea635a77..de35289df20d 100644 --- a/mm/hmm.c +++ b/mm/hmm.c @@ -64,7 +64,7 @@ static struct hmm *hmm_get_or_create(struct mm_struct *mm) init_rwsem(&hmm->mirrors_sem); hmm->mmu_notifier.ops = NULL; INIT_LIST_HEAD(&hmm->ranges); - mutex_init(&hmm->lock); + spin_lock_init(&hmm->ranges_lock); kref_init(&hmm->kref); hmm->notifiers = 0; hmm->mm = mm; @@ -144,6 +144,25 @@ static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm) hmm_put(hmm); } +static void notifiers_decrement(struct hmm *hmm) +{ + unsigned long flags; + + spin_lock_irqsave(&hmm->ranges_lock, flags); + hmm->notifiers--; + if (!hmm->notifiers) { + struct hmm_range *range; + + list_for_each_entry(range, &hmm->ranges, list) { + if (range->valid) + continue; + range->valid = true; + } + wake_up_all(&hmm->wq); + } + spin_unlock_irqrestore(&hmm->ranges_lock, flags); +} + static int hmm_invalidate_range_start(struct mmu_notifier *mn, const struct mmu_notifier_range *nrange) { @@ -151,6 +170,7 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn, struct hmm_mirror *mirror; struct hmm_update update; struct hmm_range *range; + unsigned long flags; int ret = 0; if (!kref_get_unless_zero(&hmm->kref)) @@ -161,12 +181,7 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn, update.event = HMM_UPDATE_INVALIDATE; update.blockable = mmu_notifier_range_blockable(nrange); - if (mmu_notifier_range_blockable(nrange)) - mutex_lock(&hmm->lock); - else if (!mutex_trylock(&hmm->lock)) { - ret = -EAGAIN; - goto out; - } + spin_lock_irqsave(&hmm->ranges_lock, flags); hmm->notifiers++; list_for_each_entry(range, &hmm->ranges, list) { if (update.end < range->start || update.start >= range->end) @@ -174,7 +189,7 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn, range->valid = false; } - mutex_unlock(&hmm->lock); + spin_unlock_irqrestore(&hmm->ranges_lock, flags); if (mmu_notifier_range_blockable(nrange)) down_read(&hmm->mirrors_sem); @@ -182,16 +197,23 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn, ret = -EAGAIN; goto out; } - list_for_each_entry(mirror, &hmm->mirrors, list) { - int ret; - ret = mirror->ops->sync_cpu_device_pagetables(mirror, &update); - if (!update.blockable && ret == -EAGAIN) + list_for_each_entry(mirror, &hmm->mirrors, list) { + int rc; + + rc = mirror->ops->sync_cpu_device_pagetables(mirror, &update); + if (rc) { + if (WARN_ON(update.blockable || rc != -EAGAIN)) + continue; + ret = -EAGAIN; break; + } } up_read(&hmm->mirrors_sem); out: + if (ret) + notifiers_decrement(hmm); hmm_put(hmm); return ret; } @@ -204,20 +226,7 @@ static void hmm_invalidate_range_end(struct mmu_notifier *mn, if (!kref_get_unless_zero(&hmm->kref)) return; - mutex_lock(&hmm->lock); - hmm->notifiers--; - if (!hmm->notifiers) { - struct hmm_range *range; - - list_for_each_entry(range, &hmm->ranges, list) { - if (range->valid) - continue; - range->valid = true; - } - wake_up_all(&hmm->wq); - } - mutex_unlock(&hmm->lock); - + notifiers_decrement(hmm); hmm_put(hmm); } @@ -868,6 +877,7 @@ int hmm_range_register(struct hmm_range *range, { unsigned long mask = ((1UL << page_shift) - 1UL); struct hmm *hmm = mirror->hmm; + unsigned long flags; range->valid = false; range->hmm = NULL; @@ -886,7 +896,7 @@ int hmm_range_register(struct hmm_range *range, return -EFAULT; /* Initialize range to track CPU page table updates. */ - mutex_lock(&hmm->lock); + spin_lock_irqsave(&hmm->ranges_lock, flags); range->hmm = hmm; kref_get(&hmm->kref); @@ -898,7 +908,7 @@ int hmm_range_register(struct hmm_range *range, */ if (!hmm->notifiers) range->valid = true; - mutex_unlock(&hmm->lock); + spin_unlock_irqrestore(&hmm->ranges_lock, flags); return 0; } @@ -914,10 +924,11 @@ EXPORT_SYMBOL(hmm_range_register); void hmm_range_unregister(struct hmm_range *range) { struct hmm *hmm = range->hmm; + unsigned long flags; - mutex_lock(&hmm->lock); + spin_lock_irqsave(&hmm->ranges_lock, flags); list_del_init(&range->list); - mutex_unlock(&hmm->lock); + spin_unlock_irqrestore(&hmm->ranges_lock, flags); /* Drop reference taken by hmm_range_register() */ mmput(hmm->mm);