diff --git a/tinygrad/uop/validate.py b/tinygrad/uop/validate.py index 10b3d40faf..c81da5a48e 100644 --- a/tinygrad/uop/validate.py +++ b/tinygrad/uop/validate.py @@ -9,14 +9,14 @@ if z3.get_version() < (4, 12, 4, 0): raise ImportError("bounds checking requires z3 >= 4.12.4, use CHECK_OOB=0 to disable, or \"pip install 'z3-solver>=4.12.4\"") # IDIV is truncated division but z3 does euclidian division (floor if b>0 ceil otherwise); mod by power of two sometimes uses Ops.AND -def z3_cdiv(a, b):return z3.If((a<0), z3.If(0 z3.ArithRef:return z3.If((a<0), z3.If(0 z3.BoolRef: assert isinstance(a, z3.BoolRef), f"{type(a)=}, {a=}" return a^b -z3_alu: dict[Ops, Callable] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, Ops.SHR: lambda a,b: a/(2**b.as_long()), - Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: z3.If, Ops.XOR: z3_xor, - Ops.MAX: lambda a,b: z3.If(a tuple[z3.ArithRef, z3.BoolRef]: +z3_alu: dict[Ops, Callable[..., z3.ExprRef]] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, + Ops.SHR: lambda a,b: a/(2**b.as_long()), Ops.SHL: lambda a,b: a*(2**b.as_long()), + Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: z3.If, Ops.XOR: z3_xor, Ops.MAX: lambda a,b: z3.If(a tuple[z3.ArithRef, z3.BoolRef]: return (s:=z3.Int(name, ctx=solver.ctx)), (vmin <= s)&(s <= vmax) z3_renderer = PatternMatcher([ @@ -44,10 +44,10 @@ z3_renderer = PatternMatcher([ (UPat(GroupOp.ALU, name="x"), lambda x,ctx: (z3_alu[x.op](*(ctx[1][s] for s in x.src)), None)), ]) -def uops_to_z3(solver, *uops: UOp) -> list[z3.ExprRef]: +def uops_to_z3(solver:z3.Solver, *uops: UOp) -> list[z3.ExprRef]: lst = list(UOp.sink(*uops).toposort(gate=lambda x: x.dtype.scalar() in dtypes.ints+(dtypes.bool, dtypes.index) or x.op is Ops.SINK))[:-1] z3map: dict[UOp, z3.ExprRef] = {} - for i,u in enumerate(lst): + for u in lst: z3_rewritten = z3_renderer.rewrite(u, ctx=(solver, z3map)) if z3_rewritten is None: raise NotImplementedError(f"{u.op} is not supported by z3") new_u, constraint = cast(tuple[z3.ArithRef, z3.BoolRef|None], z3_rewritten)