mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
cleanups from if range branch (#12973)
This commit is contained in:
1
.github/workflows/test.yml
vendored
1
.github/workflows/test.yml
vendored
@@ -294,6 +294,7 @@ jobs:
|
||||
|
||||
spec:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
group: [1, 2]
|
||||
name: SPEC=2 (${{ matrix.group }})
|
||||
|
||||
@@ -85,7 +85,9 @@ def word_wrap(x, wrap=80):
|
||||
while len(ansistrip(x[:i])) < wrap and i < len(x): i += 1
|
||||
return x[:i] + "\n" + word_wrap(x[i:], wrap)
|
||||
def pad_bytes(b:bytes, align:int) -> bytes:
  """Return `b` with zero bytes appended until its length is a multiple of `align`."""
  # -len(b) % align is the gap between len(b) and the next multiple of align (0 when already aligned)
  padding = -len(b) % align
  return b + b'\x00' * padding
|
||||
def panic(e:Exception): raise e
|
||||
def panic(e:Exception|None=None):
|
||||
if e is None: raise RuntimeError("PANIC!")
|
||||
raise e
|
||||
|
||||
@functools.cache
|
||||
def canonicalize_strides(shape:tuple[T, ...], strides:tuple[T, ...]) -> tuple[T, ...]:
|
||||
|
||||
@@ -870,6 +870,7 @@ class UPat(MathTrait):
|
||||
def fuse(self):
  """Return the result of `self.alu` called with `Ops.FUSE`."""
  fused = self.alu(Ops.FUSE)
  return fused
|
||||
def broadcast(self, **kwargs):
  """Return a `UPat` for `Ops.VECTORIZE` with this pattern's dtype and this pattern as `src`.

  Extra keyword arguments are forwarded to the `UPat` constructor unchanged.
  """
  vectorized = UPat(Ops.VECTORIZE, self.dtype, src=self, **kwargs)
  return vectorized
|
||||
def contiguous(self, *args, **kwargs):
  """Return a `UPat` for `Ops.CONTIGUOUS` whose sources are this pattern followed by `args`.

  The new pattern reuses this pattern's dtype; keyword arguments are forwarded to `UPat`.
  """
  sources = (self, *args)
  return UPat(Ops.CONTIGUOUS, dtype=self.dtype, src=sources, **kwargs)
|
||||
def after(self, *src:UPat, **kwargs):
  """Return a `UPat` for `Ops.AFTER` whose sources are this pattern followed by `src`.

  The new pattern reuses this pattern's dtype; keyword arguments are forwarded to `UPat`.
  """
  sources = (self, *src)
  return UPat(Ops.AFTER, self.dtype, sources, **kwargs)
|
||||
|
||||
def const_like(self, b:ConstLike):
  """Return `UPat.const` built from this pattern's dtype and the value `b`."""
  # cast() only informs the type checker; b is passed through unchanged at runtime
  value = cast(ConstType, b)
  return UPat.const(self.dtype, value)
|
||||
def alu(self, op:Ops, *src:UPat):
|
||||
|
||||
@@ -121,10 +121,11 @@ program_spec = PatternMatcher([
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat())), lambda: True),
|
||||
|
||||
# LOAD(idx) / LOAD (idx, alt_value) / STORE(idx, val)
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, name="idx").or_casted(), )), validate_index),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, name="idx").or_casted(), UPat())), validate_index),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, name="idx").or_casted(), UPat())), validate_index),
|
||||
# LOAD (idx, alt_value) / STORE(if gated) / LOAD(idx) / STORE(idx, val)
|
||||
(UPat().index(UPat(), UPat(dtype=dtypes.bool, name="gate"), name="idx").or_casted().load(UPat()), validate_index),
|
||||
(UPat().index(UPat(), UPat(dtype=dtypes.bool, name="gate"), name="idx").or_casted().store(UPat()), validate_index),
|
||||
(UPat().index(UPat(), name="idx").or_casted().load(), validate_index),
|
||||
(UPat().index(UPat(), name="idx").or_casted().store(UPat()), validate_index),
|
||||
|
||||
# RANGE/SPECIAL define loops, END closes them
|
||||
(UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)),
|
||||
@@ -160,8 +161,8 @@ kernel_spec = PatternMatcher([
|
||||
(UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True),
|
||||
|
||||
# LOAD(idx) / STORE(idx, val) -- NOTE: we do this here to not run validate_index since z3 doesn't support Invalid
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX).or_casted(), )), lambda: True),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX).or_casted(), UPat())), lambda: True),
|
||||
(UPat(Ops.INDEX).or_casted().load(), lambda: True),
|
||||
(UPat(Ops.INDEX).or_casted().store(UPat()), lambda: True),
|
||||
|
||||
# UNROLL/CONTRACT is used here for WMMA
|
||||
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
|
||||
|
||||
@@ -61,19 +61,18 @@ def validate_index(idx:UOp, gate:UOp|None=None):
|
||||
if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := idx.src[0].ptrdtype.size) == -1: return True
|
||||
# We can use UOp min/max to do a faster check, but it can give false positives since it's not an exact bound and doesn't consider the mask
|
||||
if 0<=idx.src[1].vmin and idx.src[1].vmax<sz: return True
|
||||
mask = idx.src[2]&gate if len(idx.src)==3 else gate
|
||||
|
||||
# WEBGPU has a BITCAST in the index. TODO: fix
|
||||
if any(x.op is Ops.BITCAST for x in idx.toposort()): return True
|
||||
|
||||
if not z3_imported: raise ImportError("z3 >= 4.12.4 is required for bounds checking, try IGNORE_OOB=0 or \"pip install 'z3-solver>=4.12.4\"")
|
||||
solver = z3.Solver(ctx=z3.Context())
|
||||
z3_idx, z3_mask = uops_to_z3(solver, idx.src[1], mask)
|
||||
z3_idx, z3_mask = uops_to_z3(solver, idx.src[1], gate)
|
||||
solver.add(z3_mask)
|
||||
with cpu_profile("validate index with z3", "TINY"):
|
||||
if solver.check((z3_idx<0)|(sz<=z3_idx)) == z3.sat:
|
||||
print(f"idx={idx.src[1].render(simplify=False)}")
|
||||
print(f"mask & gate={mask.render(simplify=False)}")
|
||||
print(f"gate={gate.render(simplify=False)}")
|
||||
print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}")
|
||||
return False
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user