The Linux Kernel Mailing List
 help / color / mirror / Atom feed
* [PATCH] sched_ext: Add scx_ai_numa scheduler example for AI workloads
@ 2026-05-08  7:51 Qiliang Yuan
  2026-05-08  7:56 ` Andrea Righi
  0 siblings, 1 reply; 4+ messages in thread
From: Qiliang Yuan @ 2026-05-08  7:51 UTC (permalink / raw)
  To: Tejun Heo, David Vernet, Andrea Righi, Changwoo Min
  Cc: linux-kernel, sched-ext, bpf, Qiliang Yuan

Implement an AI-focused NUMA-aware scheduler that optimizes task dispatch for
GPU-accelerated AI training. The scheduler maintains per-NUMA-node dispatch
queues to preserve L3 cache warmth and minimize remote DRAM accesses that
would stall GPU kernel launches waiting on CPU preprocessing.

Key features:
- Per-NUMA-node DSQs (dispatch queues) to maintain cache locality
- Idle fast path that bypasses DSQ for minimum latency
- Per-task NUMA affinity tracking to remember task placement
- Work stealing across nodes to prevent starvation during load imbalance

The BPF component (scx_ai_numa.bpf.c) implements the core scheduler
callbacks, while the userspace loader (scx_ai_numa.c) detects NUMA
topology, installs the BPF program, and reports per-node dispatch
statistics every second.

This scheduler is suitable for AI training workloads where GPU command
launches depend on rapid CPU preprocessing with minimal scheduling latency.

Signed-off-by: Qiliang Yuan <realwujing@gmail.com>
---
 tools/sched_ext/Makefile          |   2 +-
 tools/sched_ext/scx_ai_numa.bpf.c | 200 ++++++++++++++++++++++++++++++++++++++
 tools/sched_ext/scx_ai_numa.c     | 126 ++++++++++++++++++++++++
 3 files changed, 327 insertions(+), 1 deletion(-)

diff --git a/tools/sched_ext/Makefile b/tools/sched_ext/Makefile
index 21554f0896923..a639b5bf4f542 100644
--- a/tools/sched_ext/Makefile
+++ b/tools/sched_ext/Makefile
@@ -191,7 +191,7 @@ $(INCLUDE_DIR)/%.bpf.skel.h: $(SCXOBJ_DIR)/%.bpf.o $(INCLUDE_DIR)/vmlinux.h $(BP
 
 SCX_COMMON_DEPS := include/scx/common.h include/scx/user_exit_info.h | $(BINDIR)
 
-c-sched-targets = scx_simple scx_cpu0 scx_qmap scx_central scx_flatcg scx_userland scx_pair scx_sdt
+c-sched-targets = scx_simple scx_cpu0 scx_qmap scx_central scx_flatcg scx_userland scx_pair scx_sdt scx_ai_numa
 
 $(addprefix $(BINDIR)/,$(c-sched-targets)): \
 	$(BINDIR)/%: \
diff --git a/tools/sched_ext/scx_ai_numa.bpf.c b/tools/sched_ext/scx_ai_numa.bpf.c
new file mode 100644
index 0000000000000..89d3b7dd3d474
--- /dev/null
+++ b/tools/sched_ext/scx_ai_numa.bpf.c
@@ -0,0 +1,200 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * scx_ai_numa - AI NUMA-aware scheduler (BPF side)
+ *
+ * Scheduling policy optimized for AI training workloads:
+ *
+ * 1. Per-NUMA-node DSQs: each NUMA node owns a dedicated dispatch queue.
+ *    Tasks are steered to the DSQ of the NUMA node they last ran on,
+ *    preserving L3 cache warmth and reducing remote DRAM accesses that
+ *    stall GPU kernel launches waiting on CPU preprocessing.
+ *
+ * 2. Idle fast path: when an idle CPU is found, bypass the per-node DSQ
+ *    and insert directly into SCX_DSQ_LOCAL for minimum latency.
+ *
+ * 3. Task NUMA affinity: per-task storage tracks the preferred NUMA node
+ *    (updated every time select_cpu() sees the task's prev_cpu).
+ *
+ * 4. Work stealing: if a node's DSQ is empty, try remote nodes in order
+ *    to prevent CPU starvation during load imbalance (e.g., bursty GPU
+ *    command submissions landing on a single NUMA node).
+ */
+#include <scx/common.bpf.h>
+
+char _license[] SEC("license") = "GPL";
+
+UEI_DEFINE(uei);
+
+#define MAX_NUMA_NODES 16
+
+/*
+ * One DSQ per NUMA node, IDs 0 .. MAX_NUMA_NODES-1.
+ * NOTE(review): small raw integers are used as custom DSQ IDs here;
+ * confirm they cannot collide with the built-in SCX_DSQ_* ID space.
+ */
+#define NUMA_DSQ(node) ((u64)(node))
+
+/* Per-task context: remember which NUMA node this task prefers */
+struct task_ctx {
+	u32 preferred_node;	/* last node seen in select_cpu(), clamped to nr_nodes */
+};
+
+/* Task-local storage mapping each task to its struct task_ctx. */
+struct {
+	__uint(type, BPF_MAP_TYPE_TASK_STORAGE);
+	__uint(map_flags, BPF_F_NO_PREALLOC);
+	__type(key, int);
+	__type(value, struct task_ctx);
+} task_ctx_stor SEC(".maps");
+
+/*
+ * Per-node counters (per-CPU to avoid false sharing).
+ * Field order and sizes must stay in sync with the struct node_stat
+ * mirror in scx_ai_numa.c, which copies raw map bytes into it.
+ */
+struct node_stat {
+	__u64 local_dsq;	/* fast-path: direct SCX_DSQ_LOCAL insert */
+	__u64 numa_dsq;		/* enqueued to per-node DSQ */
+	__u64 steal;		/* dispatched from a remote node's DSQ */
+};
+
+/* node_stats[node] holds one struct node_stat slot per possible CPU. */
+struct {
+	__uint(type, BPF_MAP_TYPE_PERCPU_ARRAY);
+	__uint(key_size, sizeof(u32));
+	__uint(value_size, sizeof(struct node_stat));
+	__uint(max_entries, MAX_NUMA_NODES);
+} node_stats SEC(".maps");
+
+/* Set by userspace after detecting the number of NUMA nodes */
+const volatile u32 nr_nodes = 1;
+
+/* Map a CPU to its NUMA node via the compat wrapper. */
+static __always_inline u32 cpu_to_node(s32 cpu)
+{
+	return __COMPAT_scx_bpf_cpu_node(cpu);
+}
+
+/*
+ * Stat helpers: each bumps one field of the calling CPU's slot in the
+ * per-CPU node_stats map.  The lookup yields this CPU's private copy,
+ * so plain (non-atomic) increments are used; assumes the sched_ext
+ * callback cannot be interrupted mid-increment on the same slot —
+ * TODO confirm for all callback contexts.
+ */
+static __always_inline void stat_inc_local(u32 node)
+{
+	struct node_stat *s = bpf_map_lookup_elem(&node_stats, &node);
+
+	if (s)
+		s->local_dsq++;
+}
+
+static __always_inline void stat_inc_numa(u32 node)
+{
+	struct node_stat *s = bpf_map_lookup_elem(&node_stats, &node);
+
+	if (s)
+		s->numa_dsq++;
+}
+
+static __always_inline void stat_inc_steal(u32 node)
+{
+	struct node_stat *s = bpf_map_lookup_elem(&node_stats, &node);
+
+	if (s)
+		s->steal++;
+}
+
+/*
+ * Pick a CPU for a waking task.
+ *
+ * Records the task's NUMA node (derived from @prev_cpu) in task-local
+ * storage so enqueue() can route it to the matching per-node DSQ, then
+ * defers the actual CPU choice to scx_bpf_select_cpu_dfl().  When an
+ * idle CPU is found, the task is inserted directly into SCX_DSQ_LOCAL
+ * (per sched_ext direct-dispatch semantics, ops.enqueue() is then
+ * presumably skipped for this wakeup — TODO confirm).
+ */
+s32 BPF_STRUCT_OPS(ai_numa_select_cpu, struct task_struct *p, s32 prev_cpu, u64 wake_flags)
+{
+	struct task_ctx *tctx;
+	bool is_idle = false;
+	u32 node;
+	s32 cpu;
+
+	/* Update task's preferred NUMA node from prev_cpu */
+	tctx = bpf_task_storage_get(&task_ctx_stor, p, 0,
+				     BPF_LOCAL_STORAGE_GET_F_CREATE);
+	if (tctx) {
+		node = cpu_to_node(prev_cpu);
+		/* Clamp out-of-range node IDs to node 0. */
+		tctx->preferred_node = node < nr_nodes ? node : 0;
+	}
+
+	/*
+	 * Default selection tries prev_cpu first (same LLC), which preserves
+	 * L1/L2/L3 cache across AI loop iterations without extra policy code.
+	 */
+	cpu = scx_bpf_select_cpu_dfl(p, prev_cpu, wake_flags, &is_idle);
+	if (is_idle) {
+		/* Idle CPU found: bypass DSQ for minimum latency */
+		node = cpu_to_node(cpu);
+		stat_inc_local(node);
+		scx_bpf_dsq_insert(p, SCX_DSQ_LOCAL, SCX_SLICE_DFL, 0);
+	}
+
+	return cpu;
+}
+
+/*
+ * Queue a runnable task on its preferred NUMA node's DSQ.
+ *
+ * Tasks with no task-local storage (allocation failed, or the task was
+ * never seen by select_cpu()) fall back to node 0.
+ */
+void BPF_STRUCT_OPS(ai_numa_enqueue, struct task_struct *p, u64 enq_flags)
+{
+	struct task_ctx *tctx;
+	u32 node = 0;
+
+	/*
+	 * Route to the task's preferred NUMA node DSQ.
+	 * Keeping AI tasks on the same NUMA node as their GPU's host memory
+	 * reduces cross-node DRAM traffic and PCIe DMA stalls.
+	 */
+	tctx = bpf_task_storage_get(&task_ctx_stor, p, 0, 0);
+	if (tctx) {
+		node = tctx->preferred_node;
+		/* Defensive clamp; select_cpu() already bounds the value. */
+		if (node >= nr_nodes)
+			node = 0;
+	}
+
+	stat_inc_numa(node);
+	scx_bpf_dsq_insert(p, NUMA_DSQ(node), SCX_SLICE_DFL, enq_flags);
+}
+
+/*
+ * Refill @cpu's local DSQ when it runs out of work.
+ *
+ * Consumes from this CPU's own NUMA-node DSQ first; only if that is
+ * empty does it steal from other nodes.
+ */
+void BPF_STRUCT_OPS(ai_numa_dispatch, s32 cpu, struct task_struct *prev)
+{
+	u32 my_node = cpu_to_node(cpu);
+	u32 i;
+
+	/* First: consume from our own NUMA node — zero cross-node traffic */
+	if (scx_bpf_dsq_move_to_local(NUMA_DSQ(my_node), 0))
+		return;
+
+	/*
+	 * Work steal from other nodes in order.
+	 * Prevents CPU starvation when one GPU's launch bursts all tasks
+	 * onto a single NUMA node while other nodes sit idle.
+	 *
+	 * NOTE(review): the scan always starts at node 0, so low-numbered
+	 * nodes are raided first; starting at (my_node + 1) % nr_nodes
+	 * would spread steal pressure more evenly.
+	 */
+	for (i = 0; i < MAX_NUMA_NODES; i++) {
+		u32 node = i;
+
+		/* Constant loop bound for the BPF verifier; cut at real count. */
+		if (node >= nr_nodes)
+			break;
+		if (node == my_node)
+			continue;
+		if (scx_bpf_dsq_move_to_local(NUMA_DSQ(node), 0)) {
+			stat_inc_steal(my_node);
+			return;
+		}
+	}
+}
+
+/*
+ * Create one DSQ per detected NUMA node (IDs 0 .. nr_nodes-1).
+ * The loop is bounded by MAX_NUMA_NODES for the BPF verifier and cut
+ * short at nr_nodes.  The -1 second argument leaves the DSQ itself
+ * without a NUMA placement hint (presumably NUMA_NO_NODE — TODO
+ * confirm against scx_bpf_create_dsq()).
+ */
+s32 BPF_STRUCT_OPS_SLEEPABLE(ai_numa_init)
+{
+	u32 i;
+	int ret;
+
+	for (i = 0; i < MAX_NUMA_NODES; i++) {
+		if (i >= nr_nodes)
+			break;
+		ret = scx_bpf_create_dsq(NUMA_DSQ(i), -1);
+		if (ret) {
+			scx_bpf_error("failed to create DSQ for node %u: %d",
+				      i, ret);
+			return ret;
+		}
+	}
+
+	return 0;
+}
+
+/* Record the exit reason so the userspace loader can report it and
+ * decide whether to restart. */
+void BPF_STRUCT_OPS(ai_numa_exit, struct scx_exit_info *ei)
+{
+	UEI_RECORD(uei, ei);
+}
+
+/* Wire the callbacks above into the sched_ext ops table. */
+SCX_OPS_DEFINE(ai_numa_ops,
+	       .select_cpu	= (void *)ai_numa_select_cpu,
+	       .enqueue		= (void *)ai_numa_enqueue,
+	       .dispatch	= (void *)ai_numa_dispatch,
+	       .init		= (void *)ai_numa_init,
+	       .exit		= (void *)ai_numa_exit,
+	       .name		= "ai_numa");
diff --git a/tools/sched_ext/scx_ai_numa.c b/tools/sched_ext/scx_ai_numa.c
new file mode 100644
index 0000000000000..58c7bb1bd6bb6
--- /dev/null
+++ b/tools/sched_ext/scx_ai_numa.c
@@ -0,0 +1,126 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * scx_ai_numa - AI NUMA-aware scheduler (userspace loader)
+ *
+ * Detects NUMA topology, configures the BPF scheduler, and prints
+ * per-node dispatch statistics every second.
+ */
+#include <stdio.h>
+#include <unistd.h>
+#include <signal.h>
+#include <assert.h>
+#include <libgen.h>
+#include <sys/stat.h>
+#include <bpf/bpf.h>
+#include <scx/common.h>
+#include "scx_ai_numa.bpf.skel.h"
+
+/*
+ * Userspace mirror of the BPF-side per-node counters.  Field order and
+ * sizes must match struct node_stat in scx_ai_numa.bpf.c exactly:
+ * raw per-CPU map bytes are copied straight into this layout.
+ */
+struct node_stat {
+	__u64 local_dsq;	/* direct SCX_DSQ_LOCAL fast-path inserts */
+	__u64 numa_dsq;		/* enqueues onto a per-node DSQ */
+	__u64 steal;		/* dispatches stolen from a remote node */
+};
+
+/* Must match MAX_NUMA_NODES in scx_ai_numa.bpf.c. */
+#define MAX_NUMA_NODES 16
+
+/*
+ * Shutdown flag written from the signal handler.  C11 7.14.1.1
+ * guarantees defined behavior only for writes to objects of type
+ * volatile sig_atomic_t from a signal handler, so use that type
+ * instead of a plain volatile int.
+ */
+static volatile sig_atomic_t exit_req;
+
+/* SIGINT/SIGTERM handler: request a clean exit from the main loop. */
+static void sigint_handler(int sig)
+{
+	exit_req = 1;
+}
+
+/*
+ * Detect the NUMA node count by probing /sys/devices/system/node/node<N>.
+ *
+ * Node IDs can be sparse (e.g. with memory hotplug or CPU-less nodes),
+ * so scan the full 0..MAX_NUMA_NODES-1 range instead of stopping at the
+ * first gap, and report highest-present-ID + 1.  Returns at least 1 so
+ * the BPF side always has one valid DSQ.
+ */
+static __u32 detect_nr_nodes(void)
+{
+	struct stat st;
+	char path[64];
+	__u32 i, count = 0;
+
+	for (i = 0; i < MAX_NUMA_NODES; i++) {
+		snprintf(path, sizeof(path),
+			 "/sys/devices/system/node/node%u", i);
+		if (stat(path, &st) == 0 && S_ISDIR(st.st_mode))
+			count = i + 1;
+	}
+	return count ? count : 1;
+}
+
+/*
+ * Aggregate the per-CPU node_stats map across all possible CPUs and
+ * print one totals row per NUMA node.
+ */
+static void print_stats(struct scx_ai_numa *skel, __u32 nr_nodes)
+{
+	int nr_cpus = libbpf_num_possible_cpus();
+	int map_fd  = bpf_map__fd(skel->maps.node_stats);
+
+	/*
+	 * libbpf_num_possible_cpus() returns a negative errno code on
+	 * failure; using it unchecked as a VLA bound below would be
+	 * undefined behavior, so skip this stats round instead.
+	 */
+	if (nr_cpus <= 0 || map_fd < 0)
+		return;
+
+	printf("\n%-6s %14s %14s %14s\n",
+	       "Node", "Local-DSQ", "NUMA-DSQ", "Steals");
+	printf("------+--------------+--------------+--------------\n");
+
+	for (__u32 node = 0; node < nr_nodes; node++) {
+		/* Per-CPU lookup fills one struct per possible CPU. */
+		struct node_stat per_cpu[nr_cpus];
+		struct node_stat total = {};
+		__u32 key = node;
+		int i;
+
+		if (bpf_map_lookup_elem(map_fd, &key, per_cpu) < 0)
+			continue;
+
+		for (i = 0; i < nr_cpus; i++) {
+			total.local_dsq += per_cpu[i].local_dsq;
+			total.numa_dsq  += per_cpu[i].numa_dsq;
+			total.steal     += per_cpu[i].steal;
+		}
+
+		printf("%-6u %14llu %14llu %14llu\n", node,
+		       total.local_dsq, total.numa_dsq, total.steal);
+	}
+}
+
+/*
+ * Entry point: detect NUMA topology, load and attach the BPF scheduler,
+ * then print per-node dispatch statistics once a second until a signal
+ * arrives or the BPF scheduler exits.  If the recorded exit code
+ * requests a restart, the whole open/load/attach cycle is repeated.
+ */
+int main(int argc, char **argv)
+{
+	struct scx_ai_numa *skel;
+	struct bpf_link *link;
+	__u64 ecode;
+	__u32 nr_nodes;
+
+	signal(SIGINT, sigint_handler);
+	signal(SIGTERM, sigint_handler);
+
+	nr_nodes = detect_nr_nodes();
+	printf("scx_ai_numa: detected %u NUMA node(s)\n", nr_nodes);
+
+restart:
+	/*
+	 * Open the skeleton by hand rather than via SCX_OPS_OPEN().
+	 * NOTE(review): the stated rationale is that SCX_OPS_OPEN()
+	 * references sub_attach/sub_detach/sub_cgroup_id fields at compile
+	 * time which may be missing on some supported kernels — verify
+	 * against the in-tree scx headers.
+	 */
+	skel = scx_ai_numa__open();
+	SCX_BUG_ON(!skel, "Could not open scx_ai_numa");
+	/* Record the current hotplug sequence, presumably so attach fails
+	 * cleanly if CPU topology changes in between — TODO confirm. */
+	skel->struct_ops.ai_numa_ops->hotplug_seq = scx_hotplug_seq();
+	SCX_ENUM_INIT(skel);
+
+	/* Pass NUMA topology to the BPF program via rodata */
+	skel->rodata->nr_nodes = nr_nodes;
+
+	SCX_OPS_LOAD(skel, ai_numa_ops, scx_ai_numa, uei);
+	link = SCX_OPS_ATTACH(skel, ai_numa_ops, scx_ai_numa);
+
+	printf("scx_ai_numa: running (Ctrl-C to stop)\n");
+
+	/* Stats loop: exits on signal or when the BPF scheduler is gone. */
+	while (!exit_req && !UEI_EXITED(skel, uei)) {
+		print_stats(skel, nr_nodes);
+		fflush(stdout);
+		sleep(1);
+	}
+
+	bpf_link__destroy(link);
+	ecode = UEI_REPORT(skel, uei);
+	scx_ai_numa__destroy(skel);
+
+	/* The BPF side can request a full reload via its exit code. */
+	if (UEI_ECODE_RESTART(ecode))
+		goto restart;
+	return 0;
+}

---
base-commit: 8ab992f815d6736b5c7a6f5fd7bfe7bc106bb3dc
change-id: 20260508-feat-scx_ai_example-8e1384942646

Best regards,
-- 
Qiliang Yuan <realwujing@gmail.com>


^ permalink raw reply related	[flat|nested] 4+ messages in thread

end of thread, other threads:[~2026-05-08  9:37 UTC | newest]

Thread overview: 4+ messages (download: mbox.gz / follow: Atom feed)
-- links below jump to the message on this page --
2026-05-08  7:51 [PATCH] sched_ext: Add scx_ai_numa scheduler example for AI workloads Qiliang Yuan
2026-05-08  7:56 ` Andrea Righi
2026-05-08  9:29   ` Christian Loehle
2026-05-08  9:37     ` Andrea Righi

This is a public inbox, see mirroring instructions
for how to clone and mirror all data and code used for this inbox