mirror of https://github.com/tinygrad/tinygrad.git

clean up the spec (#12868)

* tighten up the spec
* move validate into a different file
* that moved to validate
* after(barr)
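A rough sketch of the import surface after this change, inferred from the hunks below (module paths and names follow the diff; nothing else here is taken from the repo):

```python
# Inferred layout after this commit (sketch, not an excerpt from the repo):
#   tinygrad/uop/spec.py      -> shared_spec, tensor_spec, program_spec, full_spec, type_verify(uops, spec)
#   tinygrad/uop/validate.py  -> uops_to_z3, z3_cdiv, validate_index (new file)
from tinygrad.uop.spec import type_verify, program_spec, tensor_spec
from tinygrad.uop.validate import uops_to_z3, z3_cdiv, validate_index
```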
test/external/external_benchmark_schedule.py (vendored, 4 changed lines)
@@ -4,7 +4,7 @@ from tinygrad.helpers import Profiling, Timing, getenv
from tinygrad.uop.ops import Ops
from tinygrad.codegen import full_rewrite_to_sink
from tinygrad.codegen.late.control_flow import linearize
from tinygrad.uop.spec import type_verify
from tinygrad.uop.spec import type_verify, program_spec

if __name__ == "__main__":
mdl = ResNet50()
@@ -41,5 +41,5 @@ if __name__ == "__main__":
for u in rewritten_uops:
uops_line.append(linearize(u))
with Timing("***** model verify in "):
for u in uops_line: type_verify(u)
for u in uops_line: type_verify(u, program_spec)
print(sum(len(u) for u in uops_line))

test/external/fuzz_fast_idiv.py (vendored, 2 changed lines)
@@ -1,7 +1,7 @@
import random
import z3
from tinygrad import dtypes
from tinygrad.uop.spec import uops_to_z3, z3_cdiv
from tinygrad.uop.validate import uops_to_z3, z3_cdiv
from tinygrad.uop.ops import UOp
from tinygrad.uop.decompositions import fast_idiv
random.seed(42)

test/external/fuzz_symbolic.py (vendored, 2 changed lines)
@@ -2,7 +2,7 @@ import random, operator
import z3
from tinygrad import Variable, dtypes
from tinygrad.uop.ops import UOp
from tinygrad.uop.spec import uops_to_z3
from tinygrad.uop.validate import uops_to_z3
from tinygrad.helpers import DEBUG, Context

seed = random.randint(0, 100)

@@ -639,13 +639,13 @@ class TestUOpGraph(unittest.TestCase):
lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0")
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int)))
barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
ld0 = UOp(Ops.LOAD, dtypes.int, (smem.index(UOp.invalid()), barrier))
ld1 = UOp(Ops.LOAD, dtypes.int, (smem.index(lidx+2, UOp.const(dtypes.bool, True)), barrier))
ld0 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(UOp.invalid()),))
ld1 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(lidx+2, UOp.const(dtypes.bool, True)),))
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))])

ld0 = uops[-1].src[-1]
# the gate and invalid value are deleted from ld1
self.assertEqual(ld0.src[0], smem.index(lidx+2))
self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2))

def test_fold_gated_store(self):
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
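The hunk above is the behavioral core of the commit ("after(barr)"): a LOAD no longer carries the BARRIER as an extra source; the barrier is attached to the buffer with .after(), so the INDEX itself carries the ordering. A rough, untested sketch of the two shapes, assembled from the constructors used in these tests (the smem/lidx setup here is an illustrative assumption, not lines from the test file):

```python
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.uop.ops import Ops, UOp

# illustrative setup, mirroring the surrounding tests (assumed)
smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=16, addrspace=AddrSpace.LOCAL), (), "smem")
lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0")
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.const(dtypes.int, 1)))
barrier = UOp(Ops.BARRIER, dtypes.void, (st,))

# old style: the barrier rides along as a second LOAD source
ld_old = UOp(Ops.LOAD, dtypes.int, (smem.index(lidx), barrier))
# new style: the buffer is wrapped in after(barrier), so the LOAD has a single source
ld_new = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(lidx),))
```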
@@ -6,7 +6,7 @@ from tinygrad.helpers import CI, DEBUG, getenv, Timing
from tinygrad.dtype import dtypes, DType, AddrSpace
from tinygrad.device import Buffer, Device
from tinygrad.uop.ops import Ops, UOp, UPat, KernelInfo, exec_alu # noqa F401
from tinygrad.uop.spec import spec
from tinygrad.uop.spec import shared_spec
from tinygrad.renderer import ProgramSpec
from tinygrad.engine.realize import CompiledRunner, get_program
from tinygrad.codegen import full_rewrite
@@ -332,7 +332,7 @@ class TestLocalAccess(unittest.TestCase):
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.float32.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.float32, (), 42.0)))
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr))
sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0)),))
self.assertEqual(_test_uops_result(dtypes.float32, uops, sres), 42)

# NOTE: webgpu specific, since only webgpu performs bitpacking
@@ -342,7 +342,7 @@ class TestLocalAccess(unittest.TestCase):
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.uint8, (), 42)))
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
sres = uop(uops, Ops.LOAD, dtypes.uint8, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr))
sres = uop(uops, Ops.LOAD, dtypes.uint8, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0)),))
self.assertEqual(_test_uops_result(dtypes.uint8, uops, sres), 42)

# NOTE: webgpu specific, since only webgpu performs bitpacking
@@ -513,7 +513,7 @@ class TestUOpStr(unittest.TestCase):
class TestUPatHelpers(unittest.TestCase):
def test_location(self):
self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "symbolic.py")
self.assertEqual(spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "spec.py")
self.assertEqual(shared_spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "spec.py")
test_upat = UPat(Ops.CONST, dtypes.bool)
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])
test_upat_named = test_upat.named("test_name")
@@ -7,7 +7,7 @@ from tinygrad.codegen import full_rewrite
from tinygrad.helpers import Context
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer
from tinygrad.uop.symbolic import sym, commutative
from tinygrad.uop.spec import uops_to_z3
from tinygrad.uop.validate import uops_to_z3

def check_uop_against_string(self, v:UOp, s:str):
sym_vars = {v.render():v for v in v.toposort() if v.op in (Ops.DEFINE_VAR, Ops.RANGE, Ops.SPECIAL)}
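uops_to_z3 now lives in tinygrad.uop.validate. Its signature in the new file below (a solver plus any number of UOps, returning the matching z3 expressions) suggests usage along these lines; this is a hedged sketch modeled on the fuzzers above, not a verbatim excerpt, and the Variable/expression details are assumptions:

```python
import z3
from tinygrad import Variable
from tinygrad.uop.validate import uops_to_z3

v = Variable("v", 0, 100)            # a bounded symbolic int, represented as a UOp
solver = z3.Solver()
# render both UOp expressions into z3 terms tied to the same solver
lhs, rhs = uops_to_z3(solver, (v * 2 + 6) // 2, v + 3)
# if the two renderings can never differ under v's bounds, they are equivalent
assert solver.check(lhs != rhs) == z3.unsat
```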
@@ -1,6 +1,6 @@
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype
from tinygrad.uop.spec import type_verify
from tinygrad.uop.spec import type_verify, program_spec
from tinygrad.renderer import Renderer

# import all pattern matchers here
@@ -98,5 +98,5 @@ def full_rewrite(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
"""

lst = linearize(full_rewrite_to_sink(sink, ren, optimize=sink.tag is None))
if __debug__: type_verify(lst)
if __debug__: type_verify(lst, program_spec)
return lst
@@ -11,7 +11,7 @@ from tinygrad.helpers import suppress_finalizing
from tinygrad.gradient import compute_gradient
from tinygrad.uop.mathtraits import MathTrait
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, srender
from tinygrad.uop.spec import tensor_uop_spec, type_verify
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Device, Buffer
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.memory import memory_planner
@@ -229,7 +229,7 @@ class Tensor(MathTrait):
big_sink = UOp.sink(*[x.uop for x in (self,)+lst])

# verify Tensors match the spec
if __debug__: type_verify(list(big_sink.toposort()), tensor_uop_spec)
if __debug__: type_verify(list(big_sink.toposort()), tensor_spec)

if any(isinstance(x._device, tuple) for x in big_sink.toposort()):
_apply_map_to_tensors(get_multi_map(big_sink), "Apply Multi Map")
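tensor_uop_spec is renamed to tensor_spec, and type_verify now takes the spec explicitly. A hedged standalone sketch of what the guarded call above does (assumes a default device is available; the Tensor expressions are arbitrary):

```python
from tinygrad import Tensor
from tinygrad.uop.ops import UOp
from tinygrad.uop.spec import type_verify, tensor_spec

a, b = Tensor.ones(4), Tensor.arange(4)
big_sink = UOp.sink(*[t.uop for t in (a + b, a * b)])
# walks every UOp in the graph and fails loudly if one does not match tensor_spec
type_verify(list(big_sink.toposort()), tensor_spec)
```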
@@ -1,60 +1,45 @@
from typing import cast, Callable
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, python_alu, graph_rewrite, AxisType
from typing import cast
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid
from tinygrad.helpers import all_same, prod, DEBUG, IGNORE_OOB, Context, cpu_profile
try:
import z3
# older versions of z3 dont have some operators like & overloaded
if z3.get_version() < (4, 12, 4, 0): raise ImportError
from tinygrad.helpers import DEBUG, Context
from tinygrad.uop.validate import validate_index

# IDIV is truncated division but z3 does euclidian division (floor if b>0 ceil otherwise); mod by power of two sometimes uses Ops.AND
def z3_cdiv(a, b):return z3.If((a<0), z3.If(0<b, (a+(b-1))/b, (a-(b+1))/b), a/b)
def z3_xor(a,b):
if isinstance(a, z3.BoolRef): return a^b
assert a==-1 or b==-1, "xor can only be used in indexing if one of the aruments is -1"
return -a-1 if b==-1 else -b-1
z3_alu: dict[Ops, Callable] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, Ops.SHR: lambda a,b: a/(2**b.as_long()),
Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: z3.If, Ops.XOR: z3_xor,
Ops.MAX: lambda a,b: z3.If(a<b, b, a), Ops.TRUNC: lambda a: a if a.is_int() else z3.ToReal(z3.If(a >= 0, z3.ToInt(a), -z3.ToInt(-a)))}
def create_bounded(name:str, vmin, vmax, solver:z3.Solver) -> z3.ArithRef:
s = z3.Int(name, ctx=solver.ctx)
solver.add(vmin <= s, s <= vmax)
return s
# four specs:
# shared_spec -- usable anywhere
# tensor_spec -- usable in tensor graph
# program_spec -- usable in linearized program
# full_spec -- all uops ever created

# ctx is (solver, load_number_dict)
# each uop gets rewritten to NOOP(arg=(solver, z3_object)), the arg has the solver first due to UOpMetaClass caching. z3 objects from different
# contexts can have the same hash but error on comparison
z3_renderer = PatternMatcher([
(UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg, 0, x.src[0].arg[1]-1, ctx[0])))),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])))),
(UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"ridx{x.arg}", 0, x.src[0].arg[1]-1, ctx[0])))),
# loaded bools become a z3 int with min max of 0-1
(UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), name="x"), lambda x,ctx:
UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0]))).cast(x.dtype)),
(UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,dtypes.index), name="x"),
lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx)))),
# z3 can cast from bool to int automatically
(UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
(UPat(Ops.CAST, dtype=dtypes.bool, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], x.src[0].arg[1]!=0))),
# if the source of the cast is not a noop it means that it is a float and so we create a new variable
(UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx:
UOp(Ops.NOOP, arg=(ctx[0], create_bounded(f"cast{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0])))),
(UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x,ctx:
UOp(Ops.NOOP, arg=(ctx[0], z3.Bool(f"cast{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx)))),
(UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], z3_alu[x.op](*(s.arg[1] for s in x.src))))),
# A comparison between floats introduces a new bool variable
(UPat(GroupOp.Comparison, src=UPat(dtype=dtypes.floats), name="x"), lambda x,ctx:
UOp(Ops.NOOP, arg=(ctx[0], z3.Bool(f"float_cmp{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx)))),
])
# *** these uops work anywhere ***

def uops_to_z3(solver, *uops: UOp) -> 'list[z3.ExprRef]':
with Context(TRACK_MATCH_STATS=0, SPEC=0): # cant pickle z3 objects, and these UOps don't follow spec
return [s.arg[1] for s in graph_rewrite(uops[0].sink(*uops[1:]), z3_renderer, ctx=(solver, {})).src]
shared_spec = PatternMatcher([
(UPat(Ops.SINK, dtypes.void), lambda: True), # NOTE: for testing, we let sinks be anything

z3_imported = True
except (ImportError, AttributeError): z3_imported = False
# CONST/DEFINE_VAR are everywhere
(UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),

buffer_spec = PatternMatcher([
# ALUs: most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
(UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat.var("x"), UPat.var("y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
(UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), dtype=dtypes.bool, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x.dtype.base == y.dtype.base),
# and SHL/SHR, the shift distance can be an int
(UPat((Ops.SHL, Ops.SHR), src=(UPat.var("x"), UPat.var("y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
(UPat((Ops.IDIV, Ops.MOD), name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
(UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)),

# CAST
(UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),

# RANGE can be in the big graph now
(UPat(Ops.RANGE, src=(UPat.var("x"),), allow_any_len=True, name="rng"), lambda rng,x:
rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
])

# ***** UOp spec in the Tensor graph *****

tensor_spec = PatternMatcher([
# buffer spec
(UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
(UPat(Ops.DEVICE, dtypes.void, (), name="d"), lambda d:
isinstance(d.arg, str) or (isinstance(d.arg, tuple) and all(isinstance(s, str) for s in d.arg))),
@@ -63,9 +48,7 @@ buffer_spec = PatternMatcher([
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"),
lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all(isinstance(arg, (int, UOp)) for arg in buf_view.arg)),
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.MSTACK, src=UPat(Ops.BUFFER)),)), lambda: True),
])

assign_spec = PatternMatcher([
# KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),

@@ -77,11 +60,7 @@ assign_spec = PatternMatcher([

# MSTACK combines buffers into multi
(UPat(Ops.MSTACK, name="x"), lambda x: all(isinstance(x.device, str) for x in x.src)),
])

# *** this is the spec of a Tensor in UOp ***

tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
(UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True),
(UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index), UPat(dtype=dtypes.index))), lambda mv,x: True),
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat.var("x"),)), lambda mv,x: isinstance(mv.arg, tuple)),
@@ -109,127 +88,69 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
(UPat(Ops.ALLREDUCE, name="red", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda red,x: red.dtype == x.dtype and isinstance(red.arg, Ops)),
(UPat(Ops.MULTI, name="multi"), lambda multi: all(x.dtype == multi.dtype for x in multi.src) and isinstance(multi.arg, int)),

# REDUCE_AXIS is the reduce in the tensor graph
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),

# REDUCE with an outerworld range
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),

# AFTER if things were kernelized
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True)
])
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),
])+shared_spec

# ***** uop type spec *****
# ***** UOp spec in linearized programs *****

def validate_index(idx:UOp, gate:UOp|None=None):
if gate is None: gate = UOp.const(dtypes.bool, True)
# TODO: check for overflow
if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := idx.src[0].ptrdtype.size) == -1: return True
# We can use UOp min/max to do a faster check, but it can give false positive since its not an exact bound and doesn't consider the mask
if 0<=idx.src[1].vmin and idx.src[1].vmax<sz: return True
mask = idx.src[2]&gate if len(idx.src)==3 else gate

# WEBGPU has a BITCAST in the index. TODO: fix
if any(x.op is Ops.BITCAST for x in idx.toposort()): return True

if not z3_imported: raise ImportError("z3 >= 4.12.4 is required for bounds checking, try IGNORE_OOB=0 or \"pip install 'z3-solver>=4.12.4\"")
solver = z3.Solver(ctx=z3.Context())
z3_idx, z3_mask = uops_to_z3(solver, idx.src[1], mask)
solver.add(z3_mask)
with cpu_profile("validate index with z3", "TINY"):
if solver.check((z3_idx<0)|(sz<=z3_idx)) == z3.sat:
print(f"idx={idx.src[1].render(simplify=False)}")
print(f"mask & gate={mask.render(simplify=False)}")
print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}")
return False
return True

def validate_store(idx:UOp, val:UOp, gate:UOp|None=None):
if gate is None: gate = UOp.const(dtypes.bool, True)
if gate.op is Ops.IF: gate = gate.src[0]
# we need to find the implicit gates, inverse of delete_redundant_gates
for u in val.toposort():
if u.op is Ops.IF: gate &= u.src[0]
return validate_index(idx, gate)

index_pat = UPat(Ops.INDEX, name="idx").or_casted()

# this is the matcher for the final rendered UOps
# matcher functions returns True or False (or None to not match)
spec = PatternMatcher([
program_spec = PatternMatcher([
# DEFINEs
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and x.dtype.addrspace == AddrSpace.GLOBAL),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.addrspace == AddrSpace.LOCAL),
(UPat(Ops.DEFINE_REG, src=()), lambda: True),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),

(UPat(Ops.RANGE, src=(UPat.var("x"),), allow_any_len=True, name="rng"), lambda rng,x:
rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
(UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)),

(UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),

# allow AFTER on buffers
(UPat(Ops.AFTER, src=(UPat(GroupOp.Defines),), allow_any_len=True), lambda: True),

# **** new style load/store ****
# INDEX is used in new style load/store
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat())), lambda: True),

# LOAD (idx, alt_value) / LOAD(idx) / STORE(idx, val)
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, name="idx").or_casted(), UPat((Ops.VECTORIZE, Ops.VCONST, Ops.CONST)))), validate_index),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, name="idx").or_casted(), )), validate_index),
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, name="idx").or_casted(), UPat())), validate_index),

# RANGE/SPECIAL define loops, END closes them
(UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)),
(UPat(Ops.END, src=(UPat(Ops.RANGE), UPat()), allow_any_len=True, arg=1, dtype=dtypes.void), lambda: True),

# make sure all index dtypes have been lowered
(UPat(GroupOp.All, dtype=dtypes.index), lambda: False),
(UPat(Ops.CONST, arg=Invalid), lambda: False),
(UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.src)),

# INDEX is used in new style load/store
# INDEX takes a <buf, alu, gate?>
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat())), lambda: True),
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat(), UPat(dtype=dtypes.bool))), lambda: True),

# LOAD takes a <bufidx, alt?, barrier?>
(UPat(Ops.LOAD, src=(index_pat, UPat(Ops.IF, name="cond")), allow_any_len=True), lambda idx,cond: validate_index(idx,cond.src[0])),
(UPat(Ops.LOAD, src=(index_pat,), allow_any_len=True), validate_index),

# STORE takes a <bufidx, val, ranges...>
(UPat(Ops.STORE, src=(index_pat, UPat(name="val")), allow_any_len=True), validate_store),

# most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
(UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat.var("x"), UPat.var("y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
(UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), dtype=dtypes.bool, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x.dtype.base == y.dtype.base),
# and SHL/SHR, the shift distance can be an int
(UPat((Ops.SHL, Ops.SHR), src=(UPat.var("x"), UPat.var("y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
(UPat((Ops.IDIV, Ops.MOD), name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
(UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)),

(UPat(Ops.END, dtype=dtypes.void), lambda: True),

# WMMA has a <a, b, acc>
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),

# if has a <gate, barrier?>
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),), allow_any_len=True), lambda: True),
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),), allow_any_len=True), lambda: True),
# if has a <gate, index_for_dedup>
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX)))), lambda: True),
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),

(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
# VECTORIZE/GEP
(UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
(UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),

# BARRIER
(UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
(UPat(Ops.BARRIER, dtypes.void), lambda: True), # BARRIERs can also happen at the end of loops

# NOTE: for testing, we let sinks be anything
#(UPat(Ops.SINK, src=UPat(Ops.STORE)), lambda: True),
(UPat(Ops.SINK, dtypes.void), lambda: True),
(UPat((Ops.NOOP, Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
])

# *** this is the UOp AST spec ***

ast_spec = PatternMatcher([
# all parent UOps must have the same shape
(UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: all_same([x.shape for x in root.src if x.st is not None])),
])
])+shared_spec

# *** this spec should match all UOps ever created ***

full_spec = PatternMatcher([
# any END
(UPat(Ops.END), lambda: True),

# SENTINEL should never be in the graph
(UPat(Ops.SENTINEL), lambda: False),

@@ -254,6 +175,8 @@ full_spec = PatternMatcher([
(UPat(Ops.ASSIGN, src=(UPat(), UPat(), UPat(GroupOp.Movement))), lambda: True),

# expander: unroll/contract/gep/ptrcat/cat
#(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
#(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
(UPat((Ops.UNROLL, Ops.CONTRACT), src=(UPat(),)), lambda: True),
# GEP multi is supported here
(UPat(Ops.GEP, name="gep"), lambda gep: gep.dtype is dtypes.void or gep.dtype.vcount == len(gep.arg)),
@@ -281,12 +204,11 @@ full_spec = PatternMatcher([
(UPat(Ops.RESHAPE, src=(UPat(Ops.STORE),)), lambda: True),
# allow any AFTER
(UPat(Ops.AFTER, src=(UPat(),), allow_any_len=True), lambda: True),
])+tensor_uop_spec+spec
])+tensor_spec+program_spec

# ***** uop helpers *****

def type_verify(uops:list[UOp], extra_spec:PatternMatcher|None=None):
check_spec = (extra_spec+spec) if extra_spec is not None else spec
def type_verify(uops:list[UOp], check_spec:PatternMatcher):
for i,u in enumerate(uops):
with Context(TRACK_MATCH_STATS=0): ret = check_spec.rewrite(u)
if cast(bool|None, ret) is not True:
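type_verify loses its implicit base spec: callers now pass the exact PatternMatcher they want, and any UOp whose rewrite does not evaluate to True fails. A toy stand-in (plain dicts instead of UOps/UPat) to illustrate the True/False/None convention the matcher functions follow; nothing here is tinygrad code:

```python
def toy_rule(u: dict):
    # True = valid, False = spec violation, None = no rule matched
    if u["op"] == "CONST": return isinstance(u.get("arg"), (int, float, bool))
    if u["op"] == "SINK": return True
    return None

def toy_type_verify(uops: list[dict], check_rule) -> None:
    for i, u in enumerate(uops):
        # mirrors the `ret is not True` check above: an unmatched uop is also an error
        if check_rule(u) is not True:
            raise RuntimeError(f"UOp {i} does not match the spec: {u}")

toy_type_verify([{"op": "CONST", "arg": 1}, {"op": "SINK"}], toy_rule)
```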
tinygrad/uop/validate.py (new file, 79 lines)
@@ -0,0 +1,79 @@
from typing import Callable
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, python_alu, graph_rewrite
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.helpers import IGNORE_OOB, Context, cpu_profile

try:
import z3
# older versions of z3 dont have some operators like & overloaded
if z3.get_version() < (4, 12, 4, 0): raise ImportError

# IDIV is truncated division but z3 does euclidian division (floor if b>0 ceil otherwise); mod by power of two sometimes uses Ops.AND
def z3_cdiv(a, b):return z3.If((a<0), z3.If(0<b, (a+(b-1))/b, (a-(b+1))/b), a/b)
def z3_xor(a,b):
if isinstance(a, z3.BoolRef): return a^b
assert a==-1 or b==-1, "xor can only be used in indexing if one of the aruments is -1"
return -a-1 if b==-1 else -b-1
z3_alu: dict[Ops, Callable] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, Ops.SHR: lambda a,b: a/(2**b.as_long()),
Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: z3.If, Ops.XOR: z3_xor,
Ops.MAX: lambda a,b: z3.If(a<b, b, a), Ops.TRUNC: lambda a: a if a.is_int() else z3.ToReal(z3.If(a >= 0, z3.ToInt(a), -z3.ToInt(-a)))}
def create_bounded(name:str, vmin, vmax, solver:z3.Solver) -> z3.ArithRef:
s = z3.Int(name, ctx=solver.ctx)
solver.add(vmin <= s, s <= vmax)
return s

# ctx is (solver, load_number_dict)
# each uop gets rewritten to NOOP(arg=(solver, z3_object)), the arg has the solver first due to UOpMetaClass caching. z3 objects from different
# contexts can have the same hash but error on comparison
z3_renderer = PatternMatcher([
(UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg, 0, x.src[0].arg[1]-1, ctx[0])))),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])))),
(UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"ridx{x.arg}", 0, x.src[0].arg[1]-1, ctx[0])))),
# loaded bools become a z3 int with min max of 0-1
(UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), name="x"), lambda x,ctx:
UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0]))).cast(x.dtype)),
(UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,dtypes.index), name="x"),
lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx)))),
# z3 can cast from bool to int automatically
(UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
(UPat(Ops.CAST, dtype=dtypes.bool, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], x.src[0].arg[1]!=0))),
# if the source of the cast is not a noop it means that it is a float and so we create a new variable
(UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx:
UOp(Ops.NOOP, arg=(ctx[0], create_bounded(f"cast{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0])))),
(UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x,ctx:
UOp(Ops.NOOP, arg=(ctx[0], z3.Bool(f"cast{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx)))),
(UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], z3_alu[x.op](*(s.arg[1] for s in x.src))))),
# A comparison between floats introduces a new bool variable
(UPat(GroupOp.Comparison, src=UPat(dtype=dtypes.floats), name="x"), lambda x,ctx:
UOp(Ops.NOOP, arg=(ctx[0], z3.Bool(f"float_cmp{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx)))),
])

def uops_to_z3(solver, *uops: UOp) -> 'list[z3.ExprRef]':
with Context(TRACK_MATCH_STATS=0, SPEC=0): # cant pickle z3 objects, and these UOps don't follow spec
return [s.arg[1] for s in graph_rewrite(uops[0].sink(*uops[1:]), z3_renderer, ctx=(solver, {})).src]

z3_imported = True
except (ImportError, AttributeError): z3_imported = False

def validate_index(idx:UOp, gate:UOp|None=None):
if gate is None: gate = UOp.const(dtypes.bool, True)
# TODO: check for overflow
if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := idx.src[0].ptrdtype.size) == -1: return True
# We can use UOp min/max to do a faster check, but it can give false positive since its not an exact bound and doesn't consider the mask
if 0<=idx.src[1].vmin and idx.src[1].vmax<sz: return True
mask = idx.src[2]&gate if len(idx.src)==3 else gate

# WEBGPU has a BITCAST in the index. TODO: fix
if any(x.op is Ops.BITCAST for x in idx.toposort()): return True

if not z3_imported: raise ImportError("z3 >= 4.12.4 is required for bounds checking, try IGNORE_OOB=0 or \"pip install 'z3-solver>=4.12.4\"")
solver = z3.Solver(ctx=z3.Context())
z3_idx, z3_mask = uops_to_z3(solver, idx.src[1], mask)
solver.add(z3_mask)
with cpu_profile("validate index with z3", "TINY"):
if solver.check((z3_idx<0)|(sz<=z3_idx)) == z3.sat:
print(f"idx={idx.src[1].render(simplify=False)}")
print(f"mask & gate={mask.render(simplify=False)}")
print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}")
return False
return True
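validate_index renders the index expression and its mask into z3 and asks whether an out-of-bounds value is still satisfiable. The same idea as a standalone z3 snippet (plain z3-solver; the bounds and mask here are made up for illustration):

```python
import z3

size = 16
ridx = z3.Int("ridx")
solver = z3.Solver()
solver.add(0 <= ridx, ridx <= 19)   # loop counter that can run past the buffer

oob = z3.Or(ridx < 0, size <= ridx)
# unmasked access: the solver finds a counterexample, i.e. a possible OOB read
assert solver.check(oob) == z3.sat

solver.add(ridx < 16)               # the gate/mask protecting the access
# with the mask asserted, no assignment can reach an out-of-bounds index
assert solver.check(oob) == z3.unsat
```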