cleanups from if range branch (#12973)

This commit is contained in:
George Hotz
2025-10-28 20:58:47 +08:00
committed by GitHub
parent 901d27b3ba
commit e936aa7974
5 changed files with 14 additions and 10 deletions

View File

@@ -294,6 +294,7 @@ jobs:
spec:
strategy:
fail-fast: false
matrix:
group: [1, 2]
name: SPEC=2 (${{ matrix.group }})

View File

@@ -85,7 +85,9 @@ def word_wrap(x, wrap=80):
while len(ansistrip(x[:i])) < wrap and i < len(x): i += 1
return x[:i] + "\n" + word_wrap(x[i:], wrap)
def pad_bytes(b:bytes, align:int) -> bytes:
  """Return `b` followed by zero bytes so that its length is a multiple of `align` (unchanged if already aligned)."""
  # -len(b) % align is the distance from len(b) up to the next multiple of align
  padding = -len(b) % align
  return b + b'\x00' * padding
def panic(e:Exception): raise e
def panic(e:Exception|None=None):
if e is None: raise RuntimeError("PANIC!")
raise e
@functools.cache
def canonicalize_strides(shape:tuple[T, ...], strides:tuple[T, ...]) -> tuple[T, ...]:

View File

@@ -870,6 +870,7 @@ class UPat(MathTrait):
def fuse(self):
  # shorthand for self.alu(Ops.FUSE) with no extra sources
  return self.alu(Ops.FUSE)
def broadcast(self, **kwargs):
  # VECTORIZE pattern that keeps this pattern's dtype and passes this pattern itself as src;
  # any keyword arguments are forwarded unchanged to the UPat constructor
  return UPat(Ops.VECTORIZE, self.dtype, src=self, **kwargs)
def contiguous(self, *args, **kwargs):
  # CONTIGUOUS pattern with this pattern's dtype; sources are this pattern followed by *args,
  # and keyword arguments are forwarded unchanged to the UPat constructor
  return UPat(Ops.CONTIGUOUS, dtype=self.dtype, src=(self, *args), **kwargs)
def after(self, *src:UPat, **kwargs):
  # AFTER pattern with this pattern's dtype; sources are this pattern followed by *src,
  # and keyword arguments are forwarded unchanged to the UPat constructor
  return UPat(Ops.AFTER, self.dtype, (self, *src), **kwargs)
def const_like(self, b:ConstLike):
  # constant pattern built by UPat.const from this pattern's dtype and b (passed through cast(ConstType, ...))
  dtype = self.dtype
  return UPat.const(dtype, cast(ConstType, b))
def alu(self, op:Ops, *src:UPat):

View File

@@ -121,10 +121,11 @@ program_spec = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat())), lambda: True),
# LOAD(idx) / LOAD (idx, alt_value) / STORE(idx, val)
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, name="idx").or_casted(), )), validate_index),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, name="idx").or_casted(), UPat())), validate_index),
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, name="idx").or_casted(), UPat())), validate_index),
# LOAD (idx, alt_value) / STORE(if gated) / LOAD(idx) / STORE(idx, val)
(UPat().index(UPat(), UPat(dtype=dtypes.bool, name="gate"), name="idx").or_casted().load(UPat()), validate_index),
(UPat().index(UPat(), UPat(dtype=dtypes.bool, name="gate"), name="idx").or_casted().store(UPat()), validate_index),
(UPat().index(UPat(), name="idx").or_casted().load(), validate_index),
(UPat().index(UPat(), name="idx").or_casted().store(UPat()), validate_index),
# RANGE/SPECIAL define loops, END closes them
(UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)),
@@ -160,8 +161,8 @@ kernel_spec = PatternMatcher([
(UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True),
# LOAD(idx) / STORE(idx, val) -- NOTE: we do this here to not run validate_index since z3 doesn't support Invalid
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX).or_casted(), )), lambda: True),
(UPat(Ops.STORE, src=(UPat(Ops.INDEX).or_casted(), UPat())), lambda: True),
(UPat(Ops.INDEX).or_casted().load(), lambda: True),
(UPat(Ops.INDEX).or_casted().store(UPat()), lambda: True),
# UNROLL/CONTRACT is used here for WMMA
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),

View File

@@ -61,19 +61,18 @@ def validate_index(idx:UOp, gate:UOp|None=None):
if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := idx.src[0].ptrdtype.size) == -1: return True
# We can use UOp min/max to do a faster check, but it can give false positives since it's not an exact bound and doesn't consider the mask
if 0<=idx.src[1].vmin and idx.src[1].vmax<sz: return True
mask = idx.src[2]&gate if len(idx.src)==3 else gate
# WEBGPU has a BITCAST in the index. TODO: fix
if any(x.op is Ops.BITCAST for x in idx.toposort()): return True
if not z3_imported: raise ImportError("z3 >= 4.12.4 is required for bounds checking, try IGNORE_OOB=0 or \"pip install 'z3-solver>=4.12.4\"")
solver = z3.Solver(ctx=z3.Context())
z3_idx, z3_mask = uops_to_z3(solver, idx.src[1], mask)
z3_idx, z3_mask = uops_to_z3(solver, idx.src[1], gate)
solver.add(z3_mask)
with cpu_profile("validate index with z3", "TINY"):
if solver.check((z3_idx<0)|(sz<=z3_idx)) == z3.sat:
print(f"idx={idx.src[1].render(simplify=False)}")
print(f"mask & gate={mask.render(simplify=False)}")
print(f"gate={gate.render(simplify=False)}")
print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}")
return False
return True