mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
Extend fast_idiv to negative ints (#11632)
* fast idiv for signed ints
* Add rule and test
* fix tests
* redo fuzz_fast_idiv to do negative ints as well
* adjust comments
* remove unused imports
This commit is contained in:
43
test/external/fuzz_fast_idiv.py
vendored
43
test/external/fuzz_fast_idiv.py
vendored
@@ -1,33 +1,26 @@
|
||||
import random
|
||||
from z3 import Int, Solver, sat
|
||||
from tinygrad import dtypes, Device
|
||||
from tinygrad.uop.ops import UOp, Ops, UPat, graph_rewrite, PatternMatcher
|
||||
import z3
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.uop.spec import z3_renderer, z3_cdiv
|
||||
from tinygrad.uop.ops import UOp, graph_rewrite
|
||||
from tinygrad.uop.transcendental import fast_idiv
|
||||
random.seed(42)
|
||||
|
||||
z3_renderer = PatternMatcher([
|
||||
(UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
|
||||
# Because fast_idiv only works for non-negative integers we can emulate machine arithmetic with modulo operations.
|
||||
(UPat(Ops.SHR, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(({x.src[0].arg}/(2**{x.src[1].arg}))%{dtypes.max(x.dtype)+1})")),
|
||||
(UPat(Ops.MUL, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(({x.src[0].arg}*{x.src[1].arg})%{dtypes.max(x.dtype)+1})")),
|
||||
(UPat((Ops.CONST, Ops.VCONST), name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
|
||||
(UPat(Ops.CAST, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}")),
|
||||
])
|
||||
|
||||
def render(self) -> str:
|
||||
ret = graph_rewrite(self.simplify(), z3_renderer)
|
||||
return ret.arg if ret.op is Ops.NOOP else str(ret)
|
||||
|
||||
powers_of_two = [2**i for i in range(64)]
|
||||
if __name__ == "__main__":
|
||||
x = Int('x')
|
||||
for _ in range(10_000):
|
||||
for i in range(10_000):
|
||||
if i % 1000 == 0:
|
||||
print(f"Progress: {i}")
|
||||
dt = random.choice(dtypes.ints)
|
||||
u = UOp(Ops.DEFINE_VAR, dt, arg=('x', 0, random.randint(1, dtypes.max(dt))), src=())
|
||||
u = UOp.variable('x', random.randint(dt.min, 0), random.randint(1, dt.max), dtype=dt)
|
||||
d = random.randint(1, max(1, u.arg[2]))
|
||||
|
||||
expr = fast_idiv(Device[Device.DEFAULT].renderer, u, d)
|
||||
if d in powers_of_two: continue
|
||||
expr = fast_idiv(None, u, d)
|
||||
if expr is None: continue
|
||||
solver = Solver()
|
||||
solver.add(x>=u.arg[1], x<=u.arg[2])
|
||||
if solver.check(eval(render(expr)) != x/d) == sat:
|
||||
assert False, f"Failed: {render(expr)} != x//{d} at x={solver.model()[x]}\nx={u}\nd={d}"
|
||||
|
||||
solver = z3.Solver()
|
||||
z3_sink = graph_rewrite(expr.sink(u), z3_renderer, ctx=(solver, {}))
|
||||
z3_expr, x = z3_sink.src[0].arg, z3_sink.src[1].arg
|
||||
|
||||
if solver.check(z3_expr != z3_cdiv(x, d)) == z3.sat:
|
||||
assert False, f"Failed: {expr.render()} != x//{d} at x={solver.model()}\nx={u}\nd={d}\n{z3_expr=}\n{x/d=}"
|
||||
|
||||
@@ -280,12 +280,15 @@ def magicgu(vmax:int, d:int) -> tuple[int,int]:
|
||||
assert False
|
||||
|
||||
def fast_idiv(device: str, x: UOp, d: int) -> UOp|None:
|
||||
# idiv is truncated division, but arithmetic shift is floored division, so can only do non-negative numbers!
|
||||
if x.vmin<0: return None
|
||||
sign = 1 if d > 0 else -1
|
||||
m,s = magicgu(vmax := min(x.vmax, dtypes.max(x.dtype)), abs(d))
|
||||
if m * vmax <= dtypes.max(x.dtype): return sign * ((x*m) >> s)
|
||||
# promo_lattice needs to return an unsigned type
|
||||
# If d is a power of two this is not valid for signed ints!
|
||||
is_unsigned = True if x.vmin>=0 or x.dtype in dtypes.uints else False
|
||||
assert d>0, "Sign should have been taken out of divisor"
|
||||
vmin,vmax = max(x.vmin, x.dtype.min), min(x.vmax, x.dtype.max)
|
||||
m,s = magicgu(max(vmax, abs(vmin)), d)
|
||||
if m*vmin >= dtypes.min(x.dtype) and m*vmax <= dtypes.max(x.dtype):
|
||||
return ((x*m) >> s) if is_unsigned else ((x*m) >> s) + (x<0).where(x.ufix(1), 0)
|
||||
# promo_lattice needs to return an unsigned type if the type is unsigned
|
||||
if dtypes.is_int(next_dtype := promo_lattice[x.dtype][-1]) and is_dtype_supported(next_dtype, None if device=='' else device):
|
||||
if m * vmax <= dtypes.max(next_dtype): return sign * ((x.cast(next_dtype)*m) >> s).cast(x.dtype)
|
||||
if m*vmin >= dtypes.min(next_dtype) and m*vmax <= dtypes.max(next_dtype):
|
||||
return ((x.cast(next_dtype)*m) >> s).cast(x.dtype) if is_unsigned else ((x.cast(next_dtype)*m) >> s).cast(x.dtype) + (x<0).where(x.ufix(1), 0)
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user