From dbe8f034a7a5c359dd3fbf90636243fbe4bae79b Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 29 Jan 2026 11:11:47 -0500 Subject: [PATCH] pass z3.Context in validate ctx [pr] (#14423) does not need to pass the whole solver --- tinygrad/uop/validate.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tinygrad/uop/validate.py b/tinygrad/uop/validate.py index c81da5a48e..ae80f5a831 100644 --- a/tinygrad/uop/validate.py +++ b/tinygrad/uop/validate.py @@ -16,8 +16,8 @@ def z3_xor(a:z3.BoolRef, b:z3.BoolRef) -> z3.BoolRef: z3_alu: dict[Ops, Callable[..., z3.ExprRef]] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, Ops.SHR: lambda a,b: a/(2**b.as_long()), Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: z3.If, Ops.XOR: z3_xor, Ops.MAX: lambda a,b: z3.If(a tuple[z3.ArithRef, z3.BoolRef]: - return (s:=z3.Int(name, ctx=solver.ctx)), (vmin <= s)&(s <= vmax) +def create_bounded(name:str, vmin:int, vmax:int, z3ctx:z3.Context) -> tuple[z3.ArithRef, z3.BoolRef]: + return (s:=z3.Int(name, ctx=z3ctx)), (vmin <= s)&(s <= vmax) z3_renderer = PatternMatcher([ (UPat.var("cond").where(UPat.var("x"), UPat.const(dtypes.index, Invalid)), lambda x,cond,ctx: (ctx[1][x], ctx[1][cond])), @@ -27,16 +27,16 @@ z3_renderer = PatternMatcher([ (UPat(Ops.RANGE, name="x"), lambda x,ctx: create_bounded(x.render(simplify=False), 0, ctx[1][x.src[0]]-1, ctx[0])), # loads are variables bounded by the min/max of the dtype (UPat(Ops.LOAD, dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx: create_bounded(f"load{len(ctx[1])}", x.dtype.min, x.dtype.max, ctx[0])), - (UPat(Ops.LOAD, dtypes.bool, name="x"), lambda x,ctx: (z3.Bool(f"load{len(ctx[1])}", ctx=ctx[0].ctx), None)), + (UPat(Ops.LOAD, dtypes.bool, name="x"), lambda x,ctx: (z3.Bool(f"load{len(ctx[1])}", ctx=ctx[0]), None)), # constants - (UPat(Ops.CONST, arg=Invalid, name="x"), lambda x,ctx: (z3.Int("Invalid", ctx=ctx[0].ctx), None)), - (UPat(Ops.CONST, dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx: (z3.IntVal(x.arg, ctx=ctx[0].ctx), None)), - (UPat(Ops.CONST, dtypes.bool, name="x"), lambda x,ctx: (z3.BoolVal(x.arg, ctx=ctx[0].ctx), None)), + (UPat(Ops.CONST, arg=Invalid, name="x"), lambda x,ctx: (z3.Int("Invalid", ctx=ctx[0]), None)), + (UPat(Ops.CONST, dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx: (z3.IntVal(x.arg, ctx=ctx[0]), None)), + (UPat(Ops.CONST, dtypes.bool, name="x"), lambda x,ctx: (z3.BoolVal(x.arg, ctx=ctx[0]), None)), # casts from floats create new variables (UPat(Ops.CAST, dtypes.ints+(dtypes.index,), src=(UPat(dtype=dtypes.floats),), name="x"), lambda x,ctx: create_bounded(f"cast{len(ctx[1])}", x.dtype.min, x.dtype.max, ctx[0])), # A comparison between floats introduces a new bool variable - (UPat(GroupOp.Comparison, src=UPat(dtype=dtypes.floats), name="x"), lambda x,ctx: (z3.Bool(f"float_cmp{len(ctx[1])}", ctx=ctx[0].ctx), None)), + (UPat(GroupOp.Comparison, src=UPat(dtype=dtypes.floats), name="x"), lambda x,ctx: (z3.Bool(f"float_cmp{len(ctx[1])}", ctx=ctx[0]), None)), # casts from bool/int to int/bool (UPat(Ops.CAST, dtypes.ints+(dtypes.index,),src=(UPat.var("x", dtypes.bool),), name="c"), lambda x,c,ctx: (z3.If(ctx[1][x], 1, 0), None)), (UPat(Ops.CAST, dtypes.ints+(dtypes.index,), src=(UPat.var("x", dtypes.ints+(dtypes.index,)),), name="c"), lambda x,c,ctx: (ctx[1][x], None)), @@ -48,7 +48,7 @@ def uops_to_z3(solver:z3.Solver, *uops: UOp) -> list[z3.ExprRef]: lst = list(UOp.sink(*uops).toposort(gate=lambda x: x.dtype.scalar() in dtypes.ints+(dtypes.bool, dtypes.index) or x.op is Ops.SINK))[:-1] z3map: dict[UOp, z3.ExprRef] = {} for u in lst: - z3_rewritten = z3_renderer.rewrite(u, ctx=(solver, z3map)) + z3_rewritten = z3_renderer.rewrite(u, ctx=(solver.ctx, z3map)) if z3_rewritten is None: raise NotImplementedError(f"{u.op} is not supported by z3") new_u, constraint = cast(tuple[z3.ArithRef, z3.BoolRef|None], z3_rewritten) if constraint is not None: solver.add(constraint)