/*
 * MMUOPS notifier callout functions
 */
static void gru_invalidate_range_start(struct mmu_notifier *mn,
	       struct mm_struct *mm, unsigned long start, unsigned long end)
{
	struct gru_mm_struct *owner;
	unsigned long len;

	owner = container_of(mn, struct gru_mm_struct, ms_notifier);

	/*
	 * Mark a range invalidation as in progress before flushing; the
	 * matching decrement happens in gru_invalidate_range_end().
	 */
	atomic_inc(&owner->ms_range_active);

	/* After ->release has run there is nothing left to flush. */
	if (owner->ms_released)
		return;

	/* gru_flush_tlb_range() takes a start address and a byte length. */
	len = end - start;
	gru_flush_tlb_range(owner, start, len);
}

static void gru_invalidate_range_end(struct mmu_notifier *mn,
		struct mm_struct *mm, unsigned long start, unsigned long end)
{
	struct gru_mm_struct *owner;

	owner = container_of(mn, struct gru_mm_struct, ms_notifier);

	/* Close the window opened by gru_invalidate_range_start(). */
	atomic_dec(&owner->ms_range_active);

	/*
	 * Wake every sleeper on ms_wait_queue (presumably code waiting for
	 * ms_range_active to drop -- confirm against the waiters).
	 */
	wake_up_all(&owner->ms_wait_queue);
}

/*
 * Single-page invalidation callout: flush the GRU TLB entries covering the
 * page at @address, unless the mm has already been released.
 *
 * gru_flush_tlb_range() takes (gms, start, length) -- see the range_start
 * callout, which passes "end - start". The length here is therefore one
 * page, not "address + PAGE_SIZE" (an end address), which would flush an
 * enormous, wrong range.
 */
static void gru_invalidate_page(struct mmu_notifier *mn, struct mm_struct *mm,
				       unsigned long address)
{
	struct gru_mm_struct *gms = container_of(mn, struct gru_mm_struct,
					ms_notifier);

	if (!gms->ms_released)
		gru_flush_tlb_range(gms, address, PAGE_SIZE);
}

/*
 * clear_flush_young callout: unconditionally reports 1 for every address;
 * none of the arguments are consulted.
 */
static int gru_clear_flush_young(struct mmu_notifier *mn, struct mm_struct *mm,
				       unsigned long address)
{
	const int young = 1;

	return young;
}

/*
 * ->release callout: record that the mm is going away. Other callouts and
 * gru_drop_mmu_notifier() test ms_released to skip flushing/unregistering.
 */
static void gru_mmu_release(struct mmu_notifier *mn, struct mm_struct *mm)
{
	container_of(mn, struct gru_mm_struct, ms_notifier)->ms_released = 1;
}

/*
 * Callout table for the GRU mmu_notifier. mmu_find_ops() identifies GRU
 * notifiers by comparing against the address of this object.
 */
struct mmu_notifier_ops gru_mmuops = {
	.invalidate_range_start	= gru_invalidate_range_start,
	.invalidate_range_end	= gru_invalidate_range_end,
	.invalidate_page	= gru_invalidate_page,
	.clear_flush_young	= gru_clear_flush_young,
	.release		= gru_mmu_release,
};

/* Move this to the basic mmu_notifier file. But for now... */
static struct mmu_notifier *mmu_find_ops(struct mm_struct *mm)
{
	struct mmu_notifier *mn;
	struct hlist_node *n;

	if (mm->mmu_notifier_mm)
		hlist_for_each_entry_rcu(mn, n, &mm->mmu_notifier_mm->list, hlist)
			if (mn->ops == &gru_mmuops)
				return mn;
	return NULL;
}

struct gru_mm_struct *gru_register_mmu_notifier(void)
{
	struct gru_mm_struct *gms;
	struct mmu_notifier *mn;

	mn = mmu_find_ops(current->mm);
	if (mn) {
		gms = container_of(mn, struct gru_mm_struct, ms_notifier);
		atomic_inc(&gms->ms_refcnt);
	} else {
		gms = kzalloc(sizeof(*gms), GFP_KERNEL);
		if (gms) {
			spin_lock_init(&gms->ms_asid_lock);
			gms->ms_notifier.ops = &gru_mmuops;
			atomic_set(&gms->ms_refcnt, 1);
			init_waitqueue_head(&gms->ms_wait_queue);
			mmu_notifier_register(&gms->ms_notifier, current->mm);
			synchronize_rcu();
		}
	}
	return gms;
}

/*
 * Drop one reference to @gms. When the last reference goes away, the
 * notifier is unregistered from current->mm (unless ->release already ran)
 * and the structure is freed.
 */
void gru_drop_mmu_notifier(struct gru_mm_struct *gms)
{
	/* Other holders remain: nothing to tear down yet. */
	if (atomic_dec_return(&gms->ms_refcnt) != 0)
		return;

	/* After ->release the notifier must not be unregistered again here. */
	if (!gms->ms_released)
		mmu_notifier_unregister(&gms->ms_notifier, current->mm);
	kfree(gms);
}

