mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
types for validate.py (#14422)
This commit is contained in:
@@ -9,14 +9,14 @@ if z3.get_version() < (4, 12, 4, 0):
|
||||
raise ImportError("bounds checking requires z3 >= 4.12.4, use CHECK_OOB=0 to disable, or \"pip install 'z3-solver>=4.12.4\"")
|
||||
|
||||
# IDIV is truncated division but z3 does euclidian division (floor if b>0 ceil otherwise); mod by power of two sometimes uses Ops.AND
|
||||
def z3_cdiv(a, b):return z3.If((a<0), z3.If(0<b, (a+(b-1))/b, (a-(b+1))/b), a/b)
|
||||
def z3_xor(a,b):
|
||||
def z3_cdiv(a:z3.ArithRef, b:z3.ArithRef) -> z3.ArithRef:return z3.If((a<0), z3.If(0<b, (a+(b-1))/b, (a-(b+1))/b), a/b)
|
||||
def z3_xor(a:z3.BoolRef, b:z3.BoolRef) -> z3.BoolRef:
|
||||
assert isinstance(a, z3.BoolRef), f"{type(a)=}, {a=}"
|
||||
return a^b
|
||||
z3_alu: dict[Ops, Callable] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, Ops.SHR: lambda a,b: a/(2**b.as_long()),
|
||||
Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: z3.If, Ops.XOR: z3_xor,
|
||||
Ops.MAX: lambda a,b: z3.If(a<b, b, a),}
|
||||
def create_bounded(name:str, vmin, vmax, solver:z3.Solver) -> tuple[z3.ArithRef, z3.BoolRef]:
|
||||
z3_alu: dict[Ops, Callable[..., z3.ExprRef]] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv,
|
||||
Ops.SHR: lambda a,b: a/(2**b.as_long()), Ops.SHL: lambda a,b: a*(2**b.as_long()),
|
||||
Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: z3.If, Ops.XOR: z3_xor, Ops.MAX: lambda a,b: z3.If(a<b, b, a),}
|
||||
def create_bounded(name:str, vmin:int, vmax:int, solver:z3.Solver) -> tuple[z3.ArithRef, z3.BoolRef]:
|
||||
return (s:=z3.Int(name, ctx=solver.ctx)), (vmin <= s)&(s <= vmax)
|
||||
|
||||
z3_renderer = PatternMatcher([
|
||||
@@ -44,10 +44,10 @@ z3_renderer = PatternMatcher([
|
||||
(UPat(GroupOp.ALU, name="x"), lambda x,ctx: (z3_alu[x.op](*(ctx[1][s] for s in x.src)), None)),
|
||||
])
|
||||
|
||||
def uops_to_z3(solver, *uops: UOp) -> list[z3.ExprRef]:
|
||||
def uops_to_z3(solver:z3.Solver, *uops: UOp) -> list[z3.ExprRef]:
|
||||
lst = list(UOp.sink(*uops).toposort(gate=lambda x: x.dtype.scalar() in dtypes.ints+(dtypes.bool, dtypes.index) or x.op is Ops.SINK))[:-1]
|
||||
z3map: dict[UOp, z3.ExprRef] = {}
|
||||
for i,u in enumerate(lst):
|
||||
for u in lst:
|
||||
z3_rewritten = z3_renderer.rewrite(u, ctx=(solver, z3map))
|
||||
if z3_rewritten is None: raise NotImplementedError(f"{u.op} is not supported by z3")
|
||||
new_u, constraint = cast(tuple[z3.ArithRef, z3.BoolRef|None], z3_rewritten)
|
||||
|
||||
Reference in New Issue
Block a user