From e14b4fefa5ea73856ddeed91092f5bb5bbbef9ec Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 22 Jul 2025 21:00:50 -0700 Subject: [PATCH] ranges on store (#11334) * ranges on store * fix store spec * fix that * fix gates * fix tests * fix ptx --- test/test_linearizer.py | 24 ++++++++++++------------ test/unit/test_uop_symbolic.py | 4 ++-- tinygrad/codegen/devectorizer.py | 10 +++++----- tinygrad/codegen/lowerer.py | 2 +- tinygrad/renderer/ptx.py | 2 +- tinygrad/uop/symbolic.py | 5 +++-- 6 files changed, 24 insertions(+), 23 deletions(-) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index c3bc5c573e..4774df5fac 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -267,7 +267,7 @@ class TestLinearizer(unittest.TestCase): stores = [u for u in uops if u.op is Ops.STORE] assert len(accs) == 0 # it's removed now assert len(stores) == 1 - assert stores[0].src[-1].dtype == dtypes.float.vec(4) + assert stores[0].src[1].dtype == dtypes.float.vec(4) # NOTE: can reenable, it does work. it just makes BEAM slow @unittest.expectedFailure @@ -294,10 +294,10 @@ class TestLinearizer(unittest.TestCase): stores = [u for u in program.uops if u.op is Ops.STORE and u.dtype.addrspace != AddrSpace.REG] # the first store is to lds and can be upcasted - assert stores[0].src[-1].dtype == dtypes.float.vec(4) + assert stores[0].src[1].dtype == dtypes.float.vec(4) assert any(x.op is Ops.DEFINE_LOCAL for x in stores[0].toposort()) # the second store is to gds with no upcasts - assert stores[1].src[-1].dtype == dtypes.float + assert stores[1].src[1].dtype == dtypes.float assert any(x.op is Ops.DEFINE_GLOBAL for x in stores[1].toposort()) def test_zero_fold(self): @@ -648,7 +648,7 @@ class TestLinearizer(unittest.TestCase): k = helper_linearizer_opt(out)[-1] uops = get_program(k.get_optimized_ast(), k.opts).uops # check that the float4 cast collapses - store_vals = [u.src[-1] for u in uops if u.op is Ops.STORE and u.dtype.addrspace != AddrSpace.REG] + store_vals = [u.src[1] for u in uops if u.op is Ops.STORE and u.dtype.addrspace != AddrSpace.REG] for val in store_vals: assert val.dtype == dtypes.float.vec(4) # and val.op is not Ops.VECTORIZE @@ -671,7 +671,7 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn((4,3,6,6)).realize() out = x.flip((0,1)).contiguous() k = helper_linearizer_opt(out)[-1] - store_val = [u.src[-1] for u in get_program(k.get_optimized_ast(), k.opts).uops if u.op is Ops.STORE][0] + store_val = [u.src[1] for u in get_program(k.get_optimized_ast(), k.opts).uops if u.op is Ops.STORE][0] assert store_val.dtype == dtypes.float.vec(4) and store_val.op is not Ops.VECTORIZE @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @@ -690,7 +690,7 @@ class TestLinearizer(unittest.TestCase): barrier = [u for u in uops if u.op is Ops.BARRIER][0] # check that the float4 cast collapses for all stores for store in local_stores+global_stores: - assert store.src[-1].dtype.count > 1 # and store.src[2].op is not Ops.VECTORIZE + assert store.src[1].dtype.count > 1 # and store.src[2].op is not Ops.VECTORIZE # # check the children's vins # TODO: src ALU are not the same, should it? # assert barrier.src == tuple(local_stores) @@ -707,11 +707,11 @@ class TestLinearizer(unittest.TestCase): stores = [u for u in uops if u.op is Ops.STORE and u.dtype.addrspace != AddrSpace.REG] # the float4 value stores directly in lds and we skip upcast - self.assertEqual(stores[0].src[-1].dtype, dtypes.float.vec(4)) + self.assertEqual(stores[0].src[1].dtype, dtypes.float.vec(4)) #assert stores[0].src[-1].op is not Ops.VECTORIZE # the global store doesn't change - assert stores[1].src[-1].dtype == dtypes.float + assert stores[1].src[1].dtype == dtypes.float @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") @@ -730,7 +730,7 @@ class TestLinearizer(unittest.TestCase): ] k = helper_linearizer_ast(ast, [Tensor.randn(240*40).realize()], opts=[opt])[-1] out = [u for u in get_program(k.get_optimized_ast(), k.opts).uops if u.op is Ops.STORE][0] - assert out.src[-1].op is Ops.VECTORIZE and out.src[-1].dtype == dtypes.float.vec(4) + assert out.src[1].op is Ops.VECTORIZE and out.src[1].dtype == dtypes.float.vec(4) @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") @@ -748,18 +748,18 @@ class TestLinearizer(unittest.TestCase): Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=2)] k = helper_linearizer_ast(ast, [Tensor.randn(8*32).realize()], opts=[opt])[-1] out = [u for u in get_program(k.get_optimized_ast(), k.opts).uops if u.op is Ops.STORE][0] - assert out.src[-1].op is Ops.VECTORIZE and out.src[-1].dtype.count != 1 + assert out.src[1].op is Ops.VECTORIZE and out.src[1].dtype.count != 1 @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need backends that support float4") class TestFloat4(unittest.TestCase): @staticmethod def count_float4(uops: list[UOp], n=4): return (len([uop for uop in uops if uop.op is Ops.LOAD and uop.dtype == dtypes.float.vec(n)]), - len([uop for uop in uops if uop.op is Ops.STORE and uop.src[-1].dtype == dtypes.float.vec(n)])) + len([uop for uop in uops if uop.op is Ops.STORE and uop.src[1].dtype == dtypes.float.vec(n)])) @staticmethod def count_half4(uops: list[UOp]): return (len([uop for uop in uops if uop.op is Ops.LOAD and uop.dtype == dtypes.half.vec(4)]), - len([uop for uop in uops if uop.op is Ops.STORE and uop.src[-1].dtype == dtypes.half.vec(4)])) + len([uop for uop in uops if uop.op is Ops.STORE and uop.src[1].dtype == dtypes.half.vec(4)])) def test_float4_basic(self): a = Tensor.empty(2, 8).realize() diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 431c6305b4..8568d157f6 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -14,7 +14,7 @@ def render(self) -> tuple[str, ConstType, ConstType]: # NOTE: we need STORE so the ALU op has children glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0) uops = full_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), self)).sink()) - rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[-1] + rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[1] return rewritten_uop.render(simplify=False), rewritten_uop.vmin, rewritten_uop.vmax def uconst(val): return UOp.const(dtypes.int, val) @@ -642,7 +642,7 @@ class TestSymbolic(unittest.TestCase): # TODO: copied from render, render does not support cast glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0) uops = full_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink()) - rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[-1] + rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[1] self.assertEqual(rewritten_uop, cond.where(a.cast(dtypes.half), b.cast(dtypes.half))) diff --git a/tinygrad/codegen/devectorizer.py b/tinygrad/codegen/devectorizer.py index 829616fe54..5750b82b6b 100644 --- a/tinygrad/codegen/devectorizer.py +++ b/tinygrad/codegen/devectorizer.py @@ -47,10 +47,10 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None: new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, Ops.AND) if s not in drop_stmt]) else None return buf.index(idx, new_valid) -def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None: +def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None: if store_gate not in [gate.src[0] for gate in val.toposort() if gate.op is Ops.IF]: return None # remove the gate from the index - return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val) + return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val, *store.src[2:]) load_store_indexing = PatternMatcher([ # simplify valid @@ -61,7 +61,7 @@ load_store_indexing = PatternMatcher([ (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat(Ops.CONST, arg=True))), lambda buf,start_idx: buf.index(start_idx)), # delete_redundant_gates (after expand) (UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")), - UPat.var("val"))), delete_redundant_gates), + UPat.var("val")), name="store", allow_any_len=True), delete_redundant_gates), ]) # ***** load/store grouping ***** @@ -311,7 +311,7 @@ pm_render = PatternMatcher([ lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:]) if len(x.src) == 1 or x.src[1].op is Ops.CUSTOM else None), # gate any stores that aren't gated with ifs (UPat(Ops.STORE, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store", allow_any_len=True), - lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src+(UOp(Ops.IF, src=(idx.src[2],)),)) if \ + lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src[:2]+(UOp(Ops.IF, src=(idx.src[2],)),)+store.src[2:]) if \ len(store.src) <= 2 or store.src[2].op != Ops.IF else None), ]) @@ -340,7 +340,7 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp): lst = [acc.load()] + lst # put acc as the first element ctx.acc_num += 1 ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst) - return acc.store(ret).load() if len(reduce_range) != 0 else ret + return acc.store(ret, *reduce_range).load() if len(reduce_range) != 0 else ret def no_vectorized_reduce(inp:UOp, red:UOp): if inp.dtype != red.dtype: diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index 5e9280a7e1..a9e64f3cc7 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -57,7 +57,7 @@ def lower_store(ctx: IndexContext, x: UOp, buf: UOp): # NOTE: only store the local reduceop in the threads that are actually doing the reduce for oidx, ridx in zip(ctx.idxs, ctx.ridxs): if oidx is not ridx: valid = valid * oidx.eq(0) - return buf.index(idx, valid).store(x.src[1]) + return buf.index(idx, valid).store(x.src[1], *[x for x in UOp.sink(idx, valid).toposort() if x.op is Ops.RANGE]) def lower_const(ctx:IndexContext, view:UOp, c:UOp): if all(x.mask is None for x in view.arg.views): return c diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index e5f44bc5bf..913035796c 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -54,7 +54,7 @@ ptx_matcher = PatternMatcher([ # move mask from INDEX to the load/store to enable pointer arithmetic (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat.var("gate"))), UPat.var("alt"))), lambda buf,idx,gate,alt: UOp(Ops.LOAD, alt.dtype, (buf.index(idx), alt, gate))), - (UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat())), UPat.var("val"), UPat.var("gate"))), + (UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat())), UPat.var("val"), UPat.var("gate")), allow_any_len=True), lambda buf,idx,val,gate: UOp.store(buf.index(idx), val, gate)), # ptx shr and shl instructions require y to be uint (UPat.var("x") << UPat.var("y"), lambda x,y: UOp(Ops.SHL, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None), diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 4f9cde850a..4d5aca5117 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -451,8 +451,9 @@ sym = symbolic_flat+PatternMatcher([ (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)), # ** load/store folding ** (UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)), - (UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat(Ops.INDEX, name="index")))), - lambda index, gate, alt: UOp.store(index.src[0].index(index.src[1], gate), alt)), + (UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"), + UPat.load(UPat(Ops.INDEX, name="index"))), allow_any_len=True, name="store"), + lambda index, gate, alt, store: UOp.store(index.src[0].index(index.src[1], gate), alt, *store.src[2:])), # fold gated LOAD/STORE (UPat().index(UPat(), UPat.const(dtypes.bool, True)).named("idx"), lambda idx: idx.replace(src=idx.src[0:2])), # remove True (UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat(), UPat.const(dtypes.bool, False)).or_casted(),), allow_any_len=True, name="x"),