lower the invalid gate last (#12164)

* lowering invalid gate is part of lower_index_dtype

* update test

* remove import

* put that back

* reduce_collapse uses invalid

* fix that pattern to use invalid_pat

* valid creates the right dtype count

* separate rule for lowering invalid gate

* don't unvectorize Invalid gate

* image_fixup uses Invalid

* update tests

* cleanup

* update split_load_store

* add .scalar() there
This commit is contained in:
Sieds Lykles
2025-09-24 04:27:35 +02:00
committed by GitHub
parent ad7c8c21ea
commit 6146c64d81
7 changed files with 44 additions and 42 deletions

View File

@@ -452,10 +452,10 @@ class TestUOpGraph(unittest.TestCase):
def test_load_idx_becomes_int(self):
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0)
d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 1)
l0 = UOp(Ops.LOAD, dtypes.long, (d0.index(UOp.const(dtypes.int, 0)),))
l0 = UOp(Ops.LOAD, dtypes.long, (d0.index(UOp.const(dtypes.int, 0)),)).cast(dtypes.index)
idx = l0 * 600
valid = (l0<-1).ne(True)&(l0<3000)
l1 = UOp(Ops.LOAD, dtypes.long, (d1.index(idx, valid),))
l1 = UOp(Ops.LOAD, dtypes.long, (d1.index(idx.valid(valid)),))
uops = to_uops_list([l1])
for u in uops:
if u.op is Ops.INDEX: self.assertEqual(u.src[1].dtype, dtypes.int)

View File

@@ -8,13 +8,13 @@ from tinygrad.helpers import Context
def get_gated_load_uop(valid:UOp, idx:UOp):
return UOp(Ops.LOAD, dtypes.float, (
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0).index(idx, valid),
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0).index(idx.valid(valid)),
UOp.const(dtypes.float, 0.0)
))
def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
return UOp(Ops.LOAD, dtypes.float.vec(4), (
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0).index(UOp(Ops.VECTORIZE, dtypes.int.vec(2), idx), valid),
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0).index(UOp(Ops.VECTORIZE, dtypes.index.vec(2), idx).valid(valid)),
UOp(Ops.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
))
@@ -269,7 +269,6 @@ class TestImageSimplification(unittest.TestCase):
load = get_load_image_uop(shape, (gidx1<5), (gidx0, gidx1+5))
self.check(load, None, "gidx0", "(gidx1+5)")
@unittest.skip("this should be constructed with an invalid gate")
def test_valid_empty_set(self):
gidx0 = Special("gidx0", 32)
gidx1 = Special("gidx1", 32)

View File

@@ -95,7 +95,7 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([])
# lower the index dtype to a concrete int
ret.append(RewriteStep(pm_lower_index_dtype+load_store_indexing, lambda _: opts.device, name="lower all index dtypes"))
ret.append(RewriteStep(load_store_indexing+pm_lower_index_dtype, lambda _: opts.device, name="lower all index dtypes"))
# optional pre matcher
if opts.pre_matcher is not None: ret.append(RewriteStep(opts.pre_matcher, name="pre_matcher"))

View File

@@ -12,7 +12,7 @@ from tinygrad.renderer import Renderer
def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.index(UOp.invalid())
if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx, valid)
if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx.valid(valid))
# wait for it to be image indexed before running simplification
if start_idx.dtype.count != 2: return None
@@ -43,7 +43,7 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
if not drop_stmt and idx is start_idx: return None
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in valid.split_uop(Ops.AND) if s not in drop_stmt]) else None
return buf.index(idx, new_valid)
return buf.index(idx.valid(new_valid) if new_valid is not None else idx)
def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
if store_gate not in [gate.src[0] for gate in val.toposort() if gate.op is Ops.IF]: return None
@@ -52,14 +52,9 @@ def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, val:UOp, store_gate:UOp,
load_store_indexing = PatternMatcher([
# image load valid idx simplification
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
# lower turn the invalid into a gate, must come before index dtype lowering
(UPat(Ops.INDEX, src=(UPat.var("buf"), invalid_gate,),), lambda buf,x,cond,i: buf.index(x, cond)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), invalid_gate)), lambda buf,x,i,cond: simplify_valid_load(buf, x, cond)),
# drop true gate
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x)),
# remove hanging cast
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast()),), lambda buf,idx: buf.index(idx)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)),
# delete_redundant_gates (after expand)
(UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
UPat.var("val")), name="store", allow_any_len=True), delete_redundant_gates),
@@ -67,21 +62,21 @@ load_store_indexing = PatternMatcher([
# ***** load/store grouping *****
def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
if getenv("UNSAFE_DISABLE_MASK", 0): mask = None
def expand_index(buf:UOp, vec:UOp):
if getenv("UNSAFE_DISABLE_MASK", 0): vec = vec.get_idx()
# generate the individual indexes
midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), mask.gep(i) if mask is not None else None) for i in range(vec.dtype.count)]),
midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i)) for i in range(vec.dtype.count)]),
symbolic_flat+load_store_indexing, name=f"index_buf_{buf.arg}")
# extract all the relevant offsets
offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
for i in range(vec.dtype.count):
idx: Any = midx.src[i].src[1]
idx: Any = midx.src[i].src[1].get_idx()
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
elif idx.op is Ops.CONST and idx.arg is Invalid: root_src, arg = "INVALID", 0
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
else: root_src, arg = idx, 0
if len(midx.src[i].src) == 3: root_src = (midx.src[i].src[2], root_src)
root_src = (midx.src[i].src[1].get_valid(), root_src)
offsets_rootsrc[root_src].setdefault(arg, []).append(i)
# then rewrite everything we can into groups
@@ -124,8 +119,6 @@ def gep_on_store(gep:UOp, st:UOp, sto:UOp):
load_store_folding = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines, name="buf")), UPat.var("vec"))), expand_index),
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines, name="buf")), UPat.var("vec"),
UPat.var("mask"))), expand_index),
# GEP after LOAD
(UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True),
lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
@@ -165,7 +158,8 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
lengths.append(1) # worst case, it's not folded
# filter fold lengths that don't divide
if must_divide: lengths = [x for x in lengths if idx.src[1].divides(x) is not None]
offset, mask = idx.src[1].get_idx(), idx.src[1].get_valid()
if must_divide: lengths = [x for x in lengths if offset.divides(x) is not None]
# split based on the fold lengths
global_offset = 0
@@ -174,7 +168,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
# with 1 at the end of the lengths list, this will always hit
for fold_length in lengths:
if global_offset+fold_length > sz: continue
lidx = buf.index(idx.src[1] + global_offset, idx.src[2] if len(idx.src) > 2 else None)
lidx = buf.index((offset + global_offset).valid(mask))
if fold_length > 1: lidx = lidx.cast(buf.ptrdtype.base.vec(fold_length).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace))
if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length)))
@@ -190,19 +184,20 @@ def image_fixup(ls:UOp):
if ls.src[0].op is Ops.CAST and isinstance(image_dtype:=ls.src[0].src[0].dtype, ImageDType):
assert ls.src[0].dtype.count == 4, "image must be casted to 4"
idx = ls.src[0].src[0]
oidx = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1]))))
idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:])
x, valid = idx.src[1].get_idx(), idx.src[1].get_valid()
oidx = UOp(Ops.VECTORIZE, dtypes.index.vec(2), ((x // 4) % image_dtype.shape[1], (x // (4*image_dtype.shape[1]))))
idx = idx.replace(src=(idx.src[0], oidx.valid(valid)))
return ls.replace(src=(idx,)+ls.src[1:])
# this is an unprocessed image without a cast, aka unfoldable image load. this doesn't work for stores
if isinstance(image_dtype:=ls.src[0].dtype, ImageDType) and ls.src[0].src[1].dtype != dtypes.int.vec(2):
if isinstance(image_dtype:=ls.src[0].dtype, ImageDType) and ls.src[0].src[1].get_idx().dtype != dtypes.index.vec(2):
assert ls.op is Ops.LOAD, "if an image store isn't upcasted to 4, we can't store it"
idx = ls.src[0]
id4 = idx.src[1] % 4
oidx = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1]))))
idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:])
x, valid = idx.src[1].get_idx(), idx.src[1].get_valid()
oidx = UOp(Ops.VECTORIZE, dtypes.index.vec(2), ((x // 4) % image_dtype.shape[1], (x // (4*image_dtype.shape[1]))))
idx = idx.replace(src=(idx.src[0], oidx.valid(valid)))
vec_load = ls.replace(dtype=ls.dtype.vec(4), src=(idx,)+ls.src[1:])
return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), ls.const_like(float('nan')))
return functools.reduce(lambda ret, i: (x % 4).ne(i).where(ret, vec_load.gep(i)), range(4), ls.const_like(float('nan')))
return None
@@ -229,6 +224,7 @@ def no_vectorized_wmma(wmma:UOp):
def no_vectorized_alu(alu:UOp):
if alu.dtype.vcount == 1: return None
if alu.op is Ops.WHERE and alu.src[2].arg is Invalid: return None # image load/store has cond.where(idx.vec(2), Invalid) as the index
alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
return UOp(Ops.VECTORIZE, alu.dtype, alus)
@@ -238,7 +234,7 @@ def no_vectorized_buf(buf:UOp):
def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
cnt = cast.dtype.count
assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}"
return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.int.vec(cnt), tuple(range(cnt))))
return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.index.vec(cnt), tuple(range(cnt))))
devectorize = PatternMatcher([
# no ALU on vectorized dtypes

View File

@@ -1,5 +1,5 @@
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, graph_rewrite, _substitute, range_start
from tinygrad.uop.symbolic import symbolic_flat, sym
from tinygrad.uop.symbolic import symbolic_flat, sym, invalid_pat
from tinygrad.helpers import partition
from tinygrad.dtype import dtypes
@@ -74,12 +74,12 @@ pm_reduce_collapse = PatternMatcher([
lambda x,gate,b=None: gate.broadcast(x.dtype.count).where(x, 0) if b is not None else gate.where(x, 0)),
# WHERE on LOAD (works on max too)
(UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD, allow_any_len=True),
lambda buf,idx,gate: buf.index(idx, gate).load()),
lambda buf,idx,gate: buf.index(idx.valid(gate)).load()),
(UPat.var("gate").where(0, UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD, allow_any_len=True),
lambda buf,idx,gate: buf.index(idx, gate.logical_not()).load()),
lambda buf,idx,gate: buf.index(idx.valid(gate.logical_not())).load()),
# INDEX on RANGE / gated RANGE
(UPat.var("buf").index(UPat.var("expr"), UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted())),
lambda buf,r,idx,expr: buf.index(expr.substitute({r:idx.cast(r.dtype)}), (idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0]))),
(UPat.var("buf").index(UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted()).where(UPat.var("expr"), invalid_pat)),
lambda buf,r,idx,expr,i: buf.index(expr.substitute({r:idx.cast(r.dtype)}).valid((idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0])))),
# AND on WHERE
((UPat.any(UPat(Ops.DEFINE_VAR, name="x"), UPat(Ops.DEFINE_VAR).gep(name="x")) & UPat.var("y")) \
.where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),

View File

@@ -328,12 +328,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
ret = UOp(Ops.REDUCE_AXIS, self.dtype, (ret,), (op, new_axis))
return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)]))
@staticmethod
def invalid(): return UOp(Ops.CONST, dtypes.index, src=(), arg=Invalid)
def invalid(count=1): return UOp(Ops.CONST, dtypes.index.vec(count), src=(), arg=Invalid)
def valid(self, cond): return cond.where(self, UOp.invalid(self.dtype.count))
def get_idx(self) -> UOp:
assert self.dtype is dtypes.index, "Can only call get_idx on index dtype"
assert self.dtype.scalar() is dtypes.index, "Can only call get_idx on index dtype"
return self.src[1] if self.op is Ops.WHERE and self.src[2].arg is Invalid else self
def get_valid(self) -> UOp:
assert self.dtype is dtypes.index, "Can only call get_valid on index dtype"
assert self.dtype.scalar() is dtypes.index, "Can only call get_valid on index dtype"
return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid)
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
def contiguous(self, *args, **kwargs): return UOp(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)
@@ -1067,7 +1068,8 @@ pm_lower_index_dtype = PatternMatcher([
# comparison ops might now have different dtypes in their sources
(UPat(GroupOp.Comparison, name="u", src=(UPat.var("x",dtypes.ints), UPat.var("y", dtypes.ints))), lambda u,x,y:
x.cast(dt:=least_upper_dtype(x.dtype, y.dtype)).alu(u.op, y.cast(dt)) if x.dtype!=y.dtype else None),
(UPat(Ops.WHERE, dtype=dtypes.index, src=(UPat.var("cond"), UPat.var("x"), UPat.var("y")), name="u"), lambda cond,u,x,y:
(UPat(Ops.WHERE, dtypes.index, src=(UPat(), UPat.var("x"), UPat(Ops.CONST, arg=Invalid)), name="u"), lambda u,x: u.replace(dtype=x.dtype)),
(UPat(Ops.WHERE, dtypes.index, src=(UPat.var("cond"), UPat.var("x"), UPat.var("y"))), lambda cond,x,y:
cond.where(x.cast(dt:=least_upper_dtype(x.dtype, y.dtype)), y.cast(dt))),
(UPat((Ops.CONST, Ops.VCONST), dtype=dtypes.index, name="u"), lambda u: u.replace(dtype=select_dtype(u))),
(UPat((Ops.RANGE,), dtype=dtypes.index, src=(UPat.var("end")), name="r"), lambda ctx,r,end:
@@ -1079,6 +1081,11 @@ pm_lower_index_dtype = PatternMatcher([
else dtypes.long)).vec(u.dtype.count),src=tuple(x.cast(dt) for x in u.src))),
(UPat((Ops.SPECIAL,Ops.DEFINE_VAR), dtypes.index, name="u"), lambda u: u.replace(dtype=dtypes.int)),
(UPat((Ops.BIND), dtypes.index, name="u"), lambda u: u.replace(dtype=u.src[0].dtype)),
# lower Invalid
(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond)),
# remove hanging cast
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast()),), lambda buf,idx: buf.index(idx)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)),
])
def index_to_concrete_int(u:UOp): return graph_rewrite(u, pm_lower_index_dtype)

View File

@@ -22,8 +22,8 @@ def fold_bitcast(root:UOp, c:UOp) -> UOp|None:
def convert(v:ConstType): return struct.unpack(to_fmt, struct.pack(from_fmt, v))[0]
return root.const_like(convert(c.arg) if root.dtype.count == 1 else tuple(map(convert, c.arg)))
invalid_pat = UPat.const(dtypes.index, Invalid).named("i")
invalid_gate = UPat.var("cond").where(UPat.var("x",dtype=dtypes.index), invalid_pat)
invalid_pat = UPat(Ops.CONST, arg=Invalid, name="i")
invalid_gate = UPat.var("cond").where(UPat.var("x"), invalid_pat)
propagate_invalid = PatternMatcher([
# this needs to be before symbolic so that 0*something_that_might_be_invalid doesnt become 0