* [PATCH] bpf: Simplify tnum_step()
@ 2026-03-18 17:19 Hao Sun
  2026-03-19  5:22 ` Shung-Hsi Yu
                   ` (2 more replies)
  0 siblings, 3 replies; 13+ messages in thread
From: Hao Sun @ 2026-03-18 17:19 UTC (permalink / raw)
  To: bpf
  Cc: ast, daniel, andrii, eddyz87, john.fastabend, martin.lau,
	linux-kernel, sunhao.th

Simplify tnum_step() from a 10-variable algorithm into a straight-line
sequence of arithmetic and bitwise operations.

tnum_step(): Given a tnum `(tval, tmask)` where `tval & tmask == 0`,
and a value `z` with `tval ≤ z < (tval | tmask)`, find the smallest
tnum-satisfying value `r > z`, i.e., the smallest `r > z` with
`r & ~tmask == tval`.
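
For reference, the specification above can be transcribed directly into
a naive search. This is a userspace sketch only: `tnum_step_ref` is a
hypothetical name, `u64` is typedef'd from `uint64_t`, and termination
relies on the stated precondition (tval | tmask always qualifies, so the
scan stops there at the latest).

```c
#include <assert.h>
#include <stdint.h>

typedef uint64_t u64;

/* Hypothetical brute-force reference for tnum_step(): scan upward from
 * z + 1 until a value matches every known bit of (tval, tmask). Only
 * practical for small masks; intended as a test oracle. */
static u64 tnum_step_ref(u64 tval, u64 tmask, u64 z)
{
	u64 r;

	assert((tval & tmask) == 0);             /* value and mask disjoint */
	assert(tval <= z && z < (tval | tmask)); /* precondition from the text */

	r = z + 1;
	while ((r & ~tmask) != tval)             /* a known bit does not match */
		r++;
	return r;
}
```

With tval = 0b00001001 and tmask = 0b11010110, for example,
tnum_step_ref(9, 214, 171) stops at 201 = 0b11001001.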

Every tnum-satisfying value has the form tval | s where s is a subset
of tmask bits (s & ~tmask == 0).  Since tval and tmask are disjoint:

    tval | s  =  tval + s

Likewise, write z = tval + d where d = z - tval (d ≥ 0 since z ≥ tval),
so r > z becomes:

    tval + s  >  tval + d
    s > d

The problem reduces to: find the smallest s, a subset of tmask, such
that s > d.
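
A naive userspace sketch of the reduced problem, useful for spot-checking
the reduction (both function names are hypothetical). It searches s
directly and asserts that, with disjoint bits, tval | s equals tval + s.
The search is bounded because d < tmask, which follows from
z < (tval | tmask).

```c
#include <assert.h>
#include <stdint.h>

typedef uint64_t u64;

/* Hypothetical naive solver: the smallest s with (s & ~tmask) == 0 and
 * s > d. Requires d < tmask so that s = tmask bounds the search. */
static u64 smallest_subset_above(u64 tmask, u64 d)
{
	u64 s;

	assert(d < tmask);
	s = d + 1;
	while (s & ~tmask)	/* skip values with a fixed bit set */
		s++;
	return s;
}

/* tnum_step() expressed through the reduction r = tval + s. */
static u64 tnum_step_reduced(u64 tval, u64 tmask, u64 z)
{
	u64 s = smallest_subset_above(tmask, z - tval);

	assert((tval | s) == tval + s); /* disjoint bits: OR equals ADD */
	return tval + s;
}
```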

Note that s must be a subset of tmask, but d need not be: d can also
have 1s at fixed-bit (non-mask) positions.

The mask bits of `d` form a "counter" that we want to increment by one,
but the counter has gaps at the fixed-bit positions.  A normal +1 would
stop at the first 0-bit it meets; we need it to skip over fixed-bit
gaps and land on the next mask bit.
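
In the special case where d contains only mask bits, this gap-skipping
increment is the familiar "next subset" trick: force every fixed-bit
position to 1 so the carry ripples through it, add one, then clear those
positions again. A minimal sketch of that special case
(`gapped_increment` is a hypothetical name; it does not handle non-mask
1s in d):

```c
#include <assert.h>
#include <stdint.h>

typedef uint64_t u64;

/* Increment a counter whose digits are the set bits of tmask. Fixed-bit
 * positions (~tmask) are forced to 1 so the carry passes over them.
 * Assumes d is itself a subset of tmask; wraps to 0 after d == tmask. */
static u64 gapped_increment(u64 tmask, u64 d)
{
	assert((d & ~tmask) == 0);
	return ((d | ~tmask) + 1) & tmask;
}
```

Starting from 0 with tmask = 0b1010, repeated calls visit 0b0010, 0b1000,
0b1010 and then wrap back to 0, skipping every value with a fixed bit
set.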

Step 1 -- plug the gaps:

    d | carry_mask | ~tmask

  - ~tmask fills all fixed-bit positions with 1.
  - carry_mask = (1 << fls64(d & ~tmask)) - 1 sets all positions
    (including mask positions) at or below the highest non-mask bit of d.

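After this, the only remaining 0s are mask bits above the highest
non-mask bit of d where d is also 0 -- exactly the positions where
the carry can validly land.  The carry has to land above that bit: any
subset s that matches d on all higher bits has a 0 at that bit, where d
has a 1, and is therefore smaller than d.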
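
Step 1 in isolation, as a userspace sketch. The kernel's fls64() is not
available outside the kernel, so a loop-based stand-in with the same
contract (1-based index of the highest set bit, 0 for an input of 0) is
assumed, and `plug_gaps` is a hypothetical name. Because d < tmask, any
1 in bit 63 of d is a mask bit, so fls64(d & ~tmask) is at most 63 and
the shift never becomes the undefined 1ULL << 64.

```c
#include <assert.h>
#include <stdint.h>

typedef uint64_t u64;

/* Stand-in for the kernel's fls64(): 1-based position of the most
 * significant set bit, or 0 if x == 0. */
static unsigned int fls64(u64 x)
{
	unsigned int n = 0;

	while (x) {
		n++;
		x >>= 1;
	}
	return n;
}

/* Step 1: fill fixed-bit positions and everything at or below the
 * highest non-mask 1 of d, leaving 0s only where the carry may land. */
static u64 plug_gaps(u64 tmask, u64 d)
{
	u64 carry_mask;

	assert(d < tmask);
	carry_mask = (1ULL << fls64(d & ~tmask)) - 1;
	return d | carry_mask | ~tmask;
}
```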

Step 2 -- increment:

    (d | carry_mask | ~tmask) + 1

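Adding 1 clears all trailing 1s and sets the lowest 0 bit.  Since every
gap has been plugged, that bit is guaranteed to be a mask bit above all
non-mask bits of d.  Such a 0 always exists: z < (tval | tmask) implies
d < tmask, so the carry never runs off the top of the word.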
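
The claim about where the carry lands can be checked mechanically: for
any x, ~x & (x + 1) isolates the lowest 0 bit of x, which is exactly the
bit that + 1 sets. A userspace sketch with hypothetical helper names and
a loop-based stand-in for the kernel's fls64():

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

typedef uint64_t u64;

/* Loop-based userspace stand-in for the kernel's fls64(). */
static unsigned int fls64(u64 x)
{
	unsigned int n = 0;

	while (x) {
		n++;
		x >>= 1;
	}
	return n;
}

/* Steps 1 and 2: return the single bit that "+ 1" turns from 0 into 1. */
static u64 carry_landing_bit(u64 tmask, u64 d)
{
	u64 plugged = d | ((1ULL << fls64(d & ~tmask)) - 1) | ~tmask;

	/* ~x & (x + 1) isolates the lowest 0 bit of x. */
	return ~plugged & (plugged + 1);
}

/* The claim from the text: the landing bit exists, is a mask bit, and
 * lies above every non-mask 1 of d (for a single bit, "above" is the
 * same as being numerically greater than d & ~tmask). */
static bool landing_bit_ok(u64 tmask, u64 d)
{
	u64 bit;

	assert(d < tmask);
	bit = carry_landing_bit(tmask, d);
	return bit != 0 && (bit & tmask) != 0 && bit > (d & ~tmask);
}
```

Looping d over every value below a small tmask and asserting
landing_bit_ok() is a cheap exhaustive spot-check of that statement.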

Step 3 -- mask:

    ((d | carry_mask | ~tmask) + 1) & tmask

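Strip the scaffolding, keeping only mask bits.  Call the result inc.
It keeps d's mask bits above the bit where the carry landed, sets that
bit, and clears everything below it, so it is the smallest subset of
tmask that exceeds d.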
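
Step 3 can also be checked mechanically. The sketch below (hypothetical
names, loop-based stand-in for fls64()) verifies that inc is a subset of
tmask, exceeds d, and matches d's mask bits above the bit where the
carry landed, with everything below that bit cleared.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

typedef uint64_t u64;

/* Loop-based userspace stand-in for the kernel's fls64(). */
static unsigned int fls64(u64 x)
{
	unsigned int n = 0;

	while (x) {
		n++;
		x >>= 1;
	}
	return n;
}

/* Steps 1-3: compute inc from d. Assumes d < tmask. */
static u64 step_inc(u64 tmask, u64 d)
{
	u64 carry_mask;

	assert(d < tmask);
	carry_mask = (1ULL << fls64(d & ~tmask)) - 1;
	return ((d | carry_mask | ~tmask) + 1) & tmask;
}

/* Check the shape of inc: a subset of tmask, greater than d, and equal
 * to d's mask bits strictly above its lowest set bit (the landing bit). */
static bool inc_ok(u64 tmask, u64 d)
{
	u64 inc = step_inc(tmask, d);
	u64 landing = inc & -inc;          /* lowest set bit of inc */
	u64 above = ~((landing << 1) - 1); /* positions strictly above it */

	return (inc & ~tmask) == 0 && inc > d &&
	       (inc & above) == (d & tmask & above);
}
```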

Step 4 -- result:

    tval | inc

Reattach the fixed bits.
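
Putting the four steps together, a userspace sketch of the new core of
tnum_step(). It covers only the precondition range
tval ≤ z < (tval | tmask); the range checks the kernel function performs
before this point are not reproduced. The struct mirrors the kernel's
struct tnum (value, mask), fls64() is a loop-based stand-in, and
`tnum_step_core` is a hypothetical name.

```c
#include <assert.h>
#include <stdint.h>

typedef uint64_t u64;

/* Userspace mirror of the kernel's struct tnum. */
struct tnum {
	u64 value;
	u64 mask;
};

/* Loop-based userspace stand-in for the kernel's fls64(). */
static unsigned int fls64(u64 x)
{
	unsigned int n = 0;

	while (x) {
		n++;
		x >>= 1;
	}
	return n;
}

/* Valid only for t.value <= z < (t.value | t.mask). */
static u64 tnum_step_core(struct tnum t, u64 z)
{
	u64 d, carry_mask, inc;

	assert((t.value & t.mask) == 0);
	assert(t.value <= z && z < (t.value | t.mask));

	d = z - t.value;                                 /* z = tval + d */
	carry_mask = (1ULL << fls64(d & ~t.mask)) - 1;   /* step 1: smear */
	inc = ((d | carry_mask | ~t.mask) + 1) & t.mask; /* steps 1-3 */
	return t.value | inc;                            /* step 4 */
}
```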

A simple 8-bit example (bits 7 down to 0):

    bit:          7  6  5  4  3  2  1  0
    tmask:        1  1  0  1  0  1  1  0     (tmask = 214)
    d:            1  0  1  0  0  0  1  0     (d = 162)
                        ^
                        non-mask 1 at bit 5

With carry_mask = 0b00111111 (smeared down from bit 5):

    ~tmask        0  0  1  0  1  0  0  1
    d|carry|~tm   1  0  1  1  1  1  1  1
    + 1           1  1  0  0  0  0  0  0
    & tmask       1  1  0  0  0  0  0  0     (inc = 192 > d)
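
The example rows can be reproduced with a small trace. The table shows
only the low 8 bits: in the 64-bit computation, bits 8..63 of the
plugged value are also 1 (they are fixed-bit positions of an 8-bit
tmask) and survive the + 1; only the final & tmask clears them. If tval
is taken to be, say, 0b00001001 (a hypothetical choice; any value
confined to the fixed-bit positions works), then z = 171 and
r = 9 | 192 = 201. `trace_step` is a hypothetical name, and fls64() is a
loop-based userspace stand-in.

```c
#include <stdint.h>

typedef uint64_t u64;

/* Loop-based userspace stand-in for the kernel's fls64(). */
static unsigned int fls64(u64 x)
{
	unsigned int n = 0;

	while (x) {
		n++;
		x >>= 1;
	}
	return n;
}

/* One field per row of the worked example. */
struct step_rows {
	u64 carry_mask;  /* smear below the highest non-mask 1 of d */
	u64 plugged;     /* d | carry_mask | ~tmask */
	u64 incremented; /* plugged + 1 */
	u64 inc;         /* incremented & tmask */
};

static struct step_rows trace_step(u64 tmask, u64 d)
{
	struct step_rows rows;

	rows.carry_mask = (1ULL << fls64(d & ~tmask)) - 1;
	rows.plugged = d | rows.carry_mask | ~tmask;
	rows.incremented = rows.plugged + 1;
	rows.inc = rows.incremented & tmask;
	return rows;
}
```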

The patch passes my local tests: test_verifier, and test_progs with
`-t verifier` and `-t reg_bounds`.
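
For small widths, the formula can also be cross-checked exhaustively
against a brute-force search. This is a userspace sketch, not the kernel
selftests: every disjoint (tval, tmask) pair of a given width and every z
with tval ≤ z < (tval | tmask) is tried. All names are hypothetical, and
fls64() is a loop-based stand-in.

```c
#include <stdbool.h>
#include <stdint.h>

typedef uint64_t u64;

/* Loop-based userspace stand-in for the kernel's fls64(). */
static unsigned int fls64(u64 x)
{
	unsigned int n = 0;

	while (x) {
		n++;
		x >>= 1;
	}
	return n;
}

/* The formula from this patch, under the precondition on z. */
static u64 step_formula(u64 tval, u64 tmask, u64 z)
{
	u64 d = z - tval;
	u64 carry_mask = (1ULL << fls64(d & ~tmask)) - 1;

	return tval | (((d | carry_mask | ~tmask) + 1) & tmask);
}

/* Brute force: first r > z whose known bits match tval. */
static u64 step_brute(u64 tval, u64 tmask, u64 z)
{
	u64 r = z + 1;

	while ((r & ~tmask) != tval)
		r++;
	return r;
}

/* Returns true if both agree on every input of the given bit width. */
static bool cross_check(unsigned int width)
{
	u64 limit = 1ULL << width;
	u64 tval, tmask, z;

	for (tmask = 0; tmask < limit; tmask++) {
		for (tval = 0; tval < limit; tval++) {
			if (tval & tmask)	/* not a valid tnum */
				continue;
			for (z = tval; z < (tval | tmask); z++)
				if (step_formula(tval, tmask, z) !=
				    step_brute(tval, tmask, z))
					return false;
		}
	}
	return true;
}
```

At width 8 this covers 3^8 = 6561 valid (tval, tmask) pairs; the cost
grows exponentially with the width, so only small widths are practical.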

Signed-off-by: Hao Sun <hao.sun@inf.ethz.ch>
---
I did not find the original algorithm intuitive; please let me know if
you spot any inconsistency.

A Lean 4 proof of the algorithm's correctness is also available, in
case anyone is interested:
	[1] https://pastebin.com/raw/czHKiyY0

 kernel/bpf/tnum.c | 32 +++++---------------------------
 1 file changed, 5 insertions(+), 27 deletions(-)

diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c
index 4abc359b3db0..aa35d4355216 100644
--- a/kernel/bpf/tnum.c
+++ b/kernel/bpf/tnum.c
@@ -286,8 +286,7 @@ struct tnum tnum_bswap64(struct tnum a)
  */
 u64 tnum_step(struct tnum t, u64 z)
 {
-	u64 tmax, j, p, q, r, s, v, u, w, res;
-	u8 k;
+	u64 tmax, d, carry_mask, inc;
 
 	tmax = t.value | t.mask;
 
@@ -299,29 +298,8 @@ u64 tnum_step(struct tnum t, u64 z)
 	if (z < t.value)
 		return t.value;
 
-	/* keep t's known bits, and match all unknown bits to z */
-	j = t.value | (z & t.mask);
-
-	if (j > z) {
-		p = ~z & t.value & ~t.mask;
-		k = fls64(p); /* k is the most-significant 0-to-1 flip */
-		q = U64_MAX << k;
-		r = q & z; /* positions > k matched to z */
-		s = ~q & t.value; /* positions <= k matched to t.value */
-		v = r | s;
-		res = v;
-	} else {
-		p = z & ~t.value & ~t.mask;
-		k = fls64(p); /* k is the most-significant 1-to-0 flip */
-		q = U64_MAX << k;
-		r = q & t.mask & z; /* unknown positions > k, matched to z */
-		s = q & ~t.mask; /* known positions > k, set to 1 */
-		v = r | s;
-		/* add 1 to unknown positions > k to make value greater than z */
-		u = v + (1ULL << k);
-		/* extract bits in unknown positions > k from u, rest from t.value */
-		w = (u & t.mask) | t.value;
-		res = w;
-	}
-	return res;
+	d = z - t.value;
+	carry_mask = (1ULL << fls64(d & ~t.mask)) - 1;
+	inc = ((d | carry_mask | ~t.mask) + 1) & t.mask;
+	return t.value | inc;
 }
-- 
2.34.1



Thread overview: 13+ messages
2026-03-18 17:19 [PATCH] bpf: Simplify tnum_step() Hao Sun
2026-03-19  5:22 ` Shung-Hsi Yu
2026-03-19  9:01   ` Hao Sun
2026-03-19  9:35     ` Paul Chaignon
2026-03-19 13:12       ` Hao Sun
2026-03-19 13:24         ` Paul Chaignon
2026-03-20  2:04       ` Shung-Hsi Yu
2026-03-19  7:24 ` Eduard Zingerman
2026-03-19  9:02   ` Hao Sun
2026-03-19  8:17 ` Kumar Kartikeya Dwivedi
2026-03-19  9:06   ` Hao Sun
2026-03-19 17:38     ` Eduard Zingerman
2026-03-19 19:41       ` Hao Sun
