Extend fast_idiv to negative ints (#11632)

* fast idiv for signed ints

* Add rule and test

* fix tests

* redo fuzz_fast_idiv to do negative ints as well

* adjust comments

* remove unused imports
This commit is contained in:
Sieds Lykles
2025-08-12 19:34:49 +02:00
committed by GitHub
parent 17adbe86d8
commit 4d6e407eb0
2 changed files with 28 additions and 32 deletions

View File

@@ -1,33 +1,26 @@
import random
from z3 import Int, Solver, sat
from tinygrad import dtypes, Device
from tinygrad.uop.ops import UOp, Ops, UPat, graph_rewrite, PatternMatcher
import z3
from tinygrad import dtypes
from tinygrad.uop.spec import z3_renderer, z3_cdiv
from tinygrad.uop.ops import UOp, graph_rewrite
from tinygrad.uop.transcendental import fast_idiv
# Fixed seed so that every fuzz run draws the same sequence of (dtype, range, divisor) cases.
random.seed(42)

# Divisors that are powers of two are skipped by the fuzzer: fast_idiv's signed path is
# documented as not valid for them.
powers_of_two = [2**i for i in range(64)]

if __name__ == "__main__":
  # Fuzz fast_idiv: for random integer dtypes, value ranges (negative lower bounds included)
  # and divisors, ask z3 for an x inside the range where the rewritten expression differs
  # from the reference division z3_cdiv(x, d). Any model z3 finds is a counterexample.
  # z3_renderer and z3_cdiv are the ones imported from tinygrad.uop.spec.
  for i in range(10_000):
    if i % 1000 == 0:
      print(f"Progress: {i}")
    dt = random.choice(dtypes.ints)
    # lower bound in [dt.min, 0], upper bound in [1, dt.max]
    u = UOp.variable('x', random.randint(dt.min, 0), random.randint(1, dt.max), dtype=dt)
    # u.arg[2] is the variable's upper bound; the divisor is drawn from [1, upper bound]
    d = random.randint(1, max(1, u.arg[2]))
    if d in powers_of_two: continue
    expr = fast_idiv(None, u, d)
    # fast_idiv returns None when it has no rewrite for this case; nothing to verify then
    if expr is None: continue

    solver = z3.Solver()
    # Render the rewritten expression and the variable together so both come back as z3 terms.
    z3_sink = graph_rewrite(expr.sink(u), z3_renderer, ctx=(solver, {}))
    z3_expr, x = z3_sink.src[0].arg, z3_sink.src[1].arg

    if solver.check(z3_expr != z3_cdiv(x, d)) == z3.sat:
      # raise explicitly (rather than `assert False`) so the failure survives `python -O`
      raise AssertionError(f"Failed: {expr.render()} != x//{d} at x={solver.model()}\nx={u}\nd={d}\n{z3_expr=}\n{x/d=}")

View File

@@ -280,12 +280,15 @@ def magicgu(vmax:int, d:int) -> tuple[int,int]:
assert False
def fast_idiv(device: str, x: UOp, d: int) -> UOp|None:
  """Rewrite the truncating integer division of `x` by the constant `d` as a multiply and a right shift.

  Args:
    device: device name forwarded to `is_dtype_supported` when the wider dtype is considered
      (an empty string is forwarded as None).
    x: the dividend; its `vmin`/`vmax` bounds (clamped to the dtype's range) decide whether the
      multiplication can overflow.
    d: the divisor, which must be positive - the caller is expected to have factored out the sign.

  Returns:
    A UOp computing the quotient in `x.dtype`, or None when the multiplication fits neither in
    `x.dtype` nor in the last entry of `promo_lattice[x.dtype]` (or that dtype is not a supported int).
  """
  assert d>0, "Sign should have been taken out of divisor"
  # If d is a power of two this is not valid for signed ints!
  is_unsigned = x.vmin>=0 or x.dtype in dtypes.uints
  vmin,vmax = max(x.vmin, x.dtype.min), min(x.vmax, x.dtype.max)
  # the magic multiplier/shift must cover the largest magnitude x can take on either side of zero
  m,s = magicgu(max(vmax, abs(vmin)), d)

  def fix_negative(q: UOp) -> UOp:
    # idiv is truncated division, but arithmetic shift is floored division:
    # add 1 back to the shifted result whenever x is negative
    return q if is_unsigned else q + (x<0).where(x.ufix(1), 0)

  # the product stays inside x.dtype for every value x can take: no cast needed
  if m*vmin >= dtypes.min(x.dtype) and m*vmax <= dtypes.max(x.dtype):
    return fix_negative((x*m) >> s)
  # promo_lattice needs to return an unsigned type if the type is unsigned
  if dtypes.is_int(next_dtype := promo_lattice[x.dtype][-1]) and is_dtype_supported(next_dtype, None if device=='' else device):
    if m*vmin >= dtypes.min(next_dtype) and m*vmax <= dtypes.max(next_dtype):
      # multiply and shift in the wider dtype, then cast the quotient back
      return fix_negative(((x.cast(next_dtype)*m) >> s).cast(x.dtype))
  return None