From 2da02f1ae1b9dbcfd8652795d2bffd138617e2bb Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 30 Oct 2025 10:42:19 +0800 Subject: [PATCH] add loads at the end (#12988) * add loads at the end * simpler * late load * tests passing * fix matvec * spec test passes * fix where on load * fix abs2 * fix more tests --- docs/abstractions2.py | 4 +- test/test_dtype_alu.py | 1 + test/test_linearizer_dumb.py | 4 +- test/test_linearizer_failures.py | 4 +- test/test_profiler.py | 1 + test/test_renderer_failures.py | 2 +- test/test_schedule.py | 4 +- test/test_uop_graph.py | 84 ++++++++++++------------ test/test_uops.py | 17 +++-- test/test_uops_stats.py | 8 +-- test/unit/test_simplify_valid_idx.py | 4 +- test/unit/test_transcendental_helpers.py | 2 +- tinygrad/codegen/__init__.py | 7 +- tinygrad/codegen/gpudims.py | 2 +- tinygrad/codegen/late/devectorizer.py | 31 ++++++--- tinygrad/codegen/opt/heuristic.py | 4 +- tinygrad/codegen/simplify.py | 4 +- tinygrad/schedule/rangeify.py | 6 +- tinygrad/uop/ops.py | 18 +++-- tinygrad/uop/spec.py | 1 + tinygrad/uop/symbolic.py | 13 ++-- 21 files changed, 120 insertions(+), 101 deletions(-) diff --git a/docs/abstractions2.py b/docs/abstractions2.py index 708933118c..c1d13a86cf 100644 --- a/docs/abstractions2.py +++ b/docs/abstractions2.py @@ -53,9 +53,7 @@ b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struc idx = UOp.const(dtypes.index, 0) buf_1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1) buf_2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 2) -ld_1 = UOp(Ops.LOAD, dtypes.int32, (buf_1.index(idx),)) -ld_2 = UOp(Ops.LOAD, dtypes.int32, (buf_2.index(idx),)) -alu = ld_1 + ld_2 +alu = buf_1.index(idx) + buf_2.index(idx) output_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0) st_0 = UOp(Ops.STORE, dtypes.void, (output_buf.index(idx), alu)) s = UOp(Ops.SINK, dtypes.void, (st_0,)) diff --git a/test/test_dtype_alu.py b/test/test_dtype_alu.py index 3f51c28c3c..4584dad4b3 100644 --- a/test/test_dtype_alu.py +++ b/test/test_dtype_alu.py @@ -194,6 +194,7 @@ class TestDTypeALU(unittest.TestCase): strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32, ht.int32, strat.sampled_from(binary_operations), strat.sampled_from(integer_binary_operations)) @unittest.skipIf(Device.DEFAULT == "PYTHON", "TODO: fix cast inf to int32 in PYTHON") + @unittest.skip("broken on Mac") def test_float_midcast_int32(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.float32, dtypes.int32) @unittest.skip("broken. TODO: fix it") diff --git a/test/test_linearizer_dumb.py b/test/test_linearizer_dumb.py index d14d3a6ae3..bc550ce812 100644 --- a/test/test_linearizer_dumb.py +++ b/test/test_linearizer_dumb.py @@ -16,12 +16,12 @@ class TestLinearizerFailure(unittest.TestCase): c2 = UOp.range(UOp.const(dtypes.index, 784), 1, AxisType.GLOBAL) c3 = UOp.range(UOp.const(dtypes.index, 10), 3, AxisType.GLOBAL) c4 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(512), arg=1, src=()) - c5 = c4.index(c1.valid(UOp.const(dtypes.bool, True))).load() + c5 = c4.index(c1.valid(UOp.const(dtypes.bool, True))) c6 = UOp.range(UOp.const(dtypes.index, 6000), 1004, AxisType.REDUCE) c7 = UOp.range(UOp.const(dtypes.index, 3750), 2006, AxisType.REDUCE) c8 = UOp.range(UOp.const(dtypes.index, 16), 2007, AxisType.GROUP_REDUCE) c9 = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(47040000), arg=2, src=()) - c10 = c9.index((((c3*UOp.const(dtypes.index, 4704000))+c2)+(c6*UOp.const(dtypes.index, 784))).valid(UOp.const(dtypes.bool, True))).load() + c10 = c9.index((((c3*UOp.const(dtypes.index, 4704000))+c2)+(c6*UOp.const(dtypes.index, 784))).valid(UOp.const(dtypes.bool, True))) c11 = c5.alu(Ops.CMPNE, ((((c3*UOp.const(dtypes.index, 6000))+c6)+((c7*UOp.const(dtypes.index, 16))+c8)).alu(Ops.CMPLT, UOp.const(dtypes.index, 59999)).where(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)).reduce(c7, c8, arg=Ops.ADD)+UOp.const(dtypes.int, -1))).where(UOp.const(dtypes.uchar, 0), c10).reduce(c6, arg=Ops.ADD) c12 = c0.index((((c1*UOp.const(dtypes.index, 7840))+(c2*UOp.const(dtypes.index, 10)))+c3).valid(UOp.const(dtypes.bool, True))).store(c11).end(c1, c2, c3) ast = c12.sink(arg=KernelInfo(name='test', axis_types=(), dont_use_locals=False, applied_opts=(Opt(op=OptOps.GROUP, axis=1, arg=16),), opts_to_apply=None)) diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index 7917fa04d5..e5c1521b91 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -12,9 +12,9 @@ class TestLinearizerFailures(unittest.TestCase): c3 = ((c1*UOp.const(dtypes.index, 32))+c2) c4 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(163840), arg=1, src=()) c5 = UOp.range(UOp.const(dtypes.index, 2560), 0, AxisType.REDUCE) - c6 = c4.index(((((((c5//UOp.const(dtypes.index, 8))%UOp.const(dtypes.index, 8))*UOp.const(dtypes.index, 8))+(c5%UOp.const(dtypes.index, 8)))+(((c2*UOp.const(dtypes.index, 40))+(c5//UOp.const(dtypes.index, 64)))*UOp.const(dtypes.index, 64)))+(c1*UOp.const(dtypes.index, 81920)))).load() + c6 = c4.index(((((((c5//UOp.const(dtypes.index, 8))%UOp.const(dtypes.index, 8))*UOp.const(dtypes.index, 8))+(c5%UOp.const(dtypes.index, 8)))+(((c2*UOp.const(dtypes.index, 40))+(c5//UOp.const(dtypes.index, 64)))*UOp.const(dtypes.index, 64)))+(c1*UOp.const(dtypes.index, 81920)))) c7 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=2, src=()) - c8 = c7.index(c3).load() + c8 = c7.index(c3) c9 = ((((c6+(c8*UOp.const(dtypes.float, -1.0)))*(c6+(c8*UOp.const(dtypes.float, -1.0)))).reduce(c5, arg=Ops.ADD)*UOp.const(dtypes.float, 0.000390625))+UOp.const(dtypes.float, 1e-05)).sqrt().reciprocal() c10 = c0.index(c3).store(c9).end(c1, c2) ast = c10.sink() diff --git a/test/test_profiler.py b/test/test_profiler.py index 83a547aa40..2836aa4432 100644 --- a/test/test_profiler.py +++ b/test/test_profiler.py @@ -199,6 +199,7 @@ class TestProfiler(unittest.TestCase): #self.assertLess(e1.st, e2.st) #self.assertGreater(e1.en-e1.st, e2.en-e2.st) + @unittest.skipIf(not CI, "this test is flaky locally") @unittest.skipUnless(Device[Device.DEFAULT].graph is not None, "graph support required") def test_graph(self): from test.test_graph import helper_alloc_rawbuffer, helper_exec_op, helper_test_graphs diff --git a/test/test_renderer_failures.py b/test/test_renderer_failures.py index 8efc7006dd..4baebae6b7 100644 --- a/test/test_renderer_failures.py +++ b/test/test_renderer_failures.py @@ -34,7 +34,7 @@ def _setup_and_test_alu(alu_op:Ops, input_val:ConstType, *alu_src_uops:UOp): a = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 0) b = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 1) idx = UOp.const(dtypes.int, 0) - ld = UOp(Ops.LOAD, dtype, (b.index(idx),)) + ld = b.index(idx) alu = ld.alu(alu_op, *alu_src_uops) store = UOp.store(a.index(idx), alu) sink = UOp(Ops.SINK, dtypes.void, (store,)) diff --git a/test/test_schedule.py b/test/test_schedule.py index 32a90db6e2..856a801062 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -2163,8 +2163,8 @@ class TestCopyFolding(unittest.TestCase): self.assertListEqual(b.tolist(), [[0, 2], [1, 3]]) def test_permute_on_disk_contiguous(self): - with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer()) - a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}") + with open(temp('dt_arange_4_permute_contig'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer()) + a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute_contig')}") b = a.reshape(2, 2).permute(1, 0).contiguous().to("CPU") b.realize() self.assertListEqual(b.tolist(), [[0, 2], [1, 3]]) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 704f17c40e..16aba5c44a 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -376,7 +376,7 @@ class TestUOpGraph(unittest.TestCase): d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0) d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1) idx = UOp.const(dtypes.int, 0) - ld = UOp(Ops.LOAD, dtypes.int, (d1.index(idx),)) + ld = d1.index(idx) alu = (ld<1).cast(dtypes.bool) out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu)) uops = to_uops_list([out]) @@ -386,7 +386,7 @@ class TestUOpGraph(unittest.TestCase): d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0) d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1) idx = UOp.const(dtypes.int, 0) - ld = UOp(Ops.LOAD, dtypes.int, (d1.index(idx),)) + ld = d1.index(idx) alu = ld.cast(dtypes.float).cast(dtypes.float) out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu)) uops = to_uops_list([out]) @@ -408,7 +408,7 @@ class TestUOpGraph(unittest.TestCase): def test_bitcast_to_same_dtype_fold(self): for dt in dtypes.ints + dtypes.floats + (dtypes.bool,): d0 = UOp(Ops.DEFINE_GLOBAL, dt.ptr(), arg=0) - v = UOp(Ops.LOAD, dt, (d0.index(UOp.const(dtypes.int, 0)),)) + v = d0.index(UOp.const(dtypes.int, 0)) uops = to_uops_list([v.bitcast(dt)]) self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST]), 0, f"dtype = {dt}") @@ -420,7 +420,7 @@ class TestUOpGraph(unittest.TestCase): def test_where_on_gated_load_fold(self): ridx0 = UOp.range(100, 0) d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0) - ld = d0.index(ridx0.valid(ridx0<50)).load() + ld = d0.index(ridx0.valid(ridx0<50)) w = (ridx0<50).where(ld, 5) uops = to_uops_list([w]) for u in uops: @@ -430,7 +430,7 @@ class TestUOpGraph(unittest.TestCase): def test_where_on_gated_load_folds_swapped_branches(self): ridx0 = UOp.range(100, 0) d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0) - ld = d0.index(ridx0.valid((ridx0<50).logical_not())).load() + ld = d0.index(ridx0.valid((ridx0<50).logical_not())) w = (ridx0<50).where(5, ld) uops = to_uops_list([w]) for u in uops: @@ -441,7 +441,7 @@ class TestUOpGraph(unittest.TestCase): ridx0 = UOp.range(100, 0) d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) gate_idx = ridx0.valid((ridx0<50)) - ld = d0.index(gate_idx).load().cast(dtypes.float) + ld = d0.index(gate_idx).cast(dtypes.float) w = (ridx0<50).where(ld, 5.0) uops = to_uops_list([w]) for u in uops: @@ -467,11 +467,11 @@ class TestUOpGraph(unittest.TestCase): c1 = UOp.range(UOp.const(dtypes.index, 512), 1, AxisType.LOOP) c2 = UOp.range(UOp.const(dtypes.index, 250), 2, AxisType.LOOP) c3 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(512), arg=1, src=()) - c4 = c3.index(c1).load() + c4 = c3.index(c1) c5 = UOp.range(UOp.const(dtypes.index, 240), 0, AxisType.REDUCE) c6 = ((c2*UOp.const(dtypes.index, 240))+c5) c7 = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(60000), arg=2, src=()) - c8 = c7.index(c6).load() + c8 = c7.index(c6) c9 = ((c4<0).where((c4+60000), c4)!=c6.cast(dtypes.int)).where(0, c8.cast(dtypes.uint).cast(dtypes.uchar)).reduce(c5, arg=Ops.ADD) c10 = c0.index(((c1*UOp.const(dtypes.index, 250))+c2)).store(c9).end(c1, c2) uops = to_uops_list([c10]) @@ -481,25 +481,25 @@ class TestUOpGraph(unittest.TestCase): def test_in_out_of_bounds_access(self): with Context(IGNORE_OOB=0): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 0)),)) + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 0), ptr=True),)) to_uops_list([ld0]) - ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 15)),)) + ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 15), ptr=True),)) to_uops_list([ld1]) - ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 7)),)) + ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 7), ptr=True),)) to_uops_list([ld1]) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 42)),)) + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 42), ptr=True),)) with self.assertRaises(RuntimeError): to_uops_list([ld0]) def test_in_out_of_bounds_access_symbolic(self): with Context(IGNORE_OOB=0): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 1, 10)),)) + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 1, 10), ptr=True),)) to_uops_list([ld0]) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 15)),)) + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 15), ptr=True),)) to_uops_list([ld0]) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 20)),)) + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 20), ptr=True),)) with self.assertRaises(RuntimeError): to_uops_list([ld0]) def test_in_out_of_bounds_access_gated_store(self): @@ -531,7 +531,7 @@ class TestUOpGraph(unittest.TestCase): if_barrier = UOp(Ops.IF, dtypes.void, (gate, barrier)) # Load from local memory (after the IF/barrier) - local_load = UOp(Ops.LOAD, dtypes.uint, (sbuf.index(lidx), if_barrier)) + local_load = UOp(Ops.LOAD, dtypes.uint, (sbuf.index(lidx, ptr=True), if_barrier)) # Store to global memory global_store = UOp(Ops.STORE, dtypes.void, (gbuf.index(gidx), local_load)) @@ -542,18 +542,18 @@ class TestUOpGraph(unittest.TestCase): ridx = UOp.range(20, 0) glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid((0<=i)&(i<16))),)) + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid((0<=i)&(i<16)), ptr=True),)) to_uops_list([ld0]) glblfloat = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(20), (), 0) ldfloat = UOp(Ops.LOAD, dtypes.float, (glblfloat.index(ridx),)) i = (ldfloat+3.14).cast(dtypes.int) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16))),)) + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16)), ptr=True),)) def test_load_cast_to_bool(self): with Context(IGNORE_OOB=0): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0) ridx = UOp.range(20, 0) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx.valid(ridx.cast(dtypes.bool).logical_not())),)) + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx.valid(ridx.cast(dtypes.bool).logical_not()), ptr=True),)) to_uops_list([ld0]) @unittest.skip("Bool load is not supported yet") @@ -562,36 +562,36 @@ class TestUOpGraph(unittest.TestCase): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0) ridx = UOp.range(20, 0) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask),))) + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask), ptr=True))) to_uops_list([ld0]) def test_out_of_bounds_off_by_one_access(self): with Context(IGNORE_OOB=0): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 16)),)) + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 16), ptr=True),)) with self.assertRaises(RuntimeError): to_uops_list([ld0]) def test_in_out_bounds_access_with_mask(self): with Context(IGNORE_OOB=0): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) gidx0 = UOp.range(42, 0, AxisType.GLOBAL) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid((5=0)&(ld0<32))),)) + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<8), ptr=True),)).cast(dtypes.index) + ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<32)), ptr=True),)) to_uops_list([ld1]) - ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<64))),)) + ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<64)), ptr=True),)) with self.assertRaises(RuntimeError): to_uops_list([ld1]) def test_bounds_with_loaded_bool(self): @@ -611,8 +611,8 @@ class TestUOpGraph(unittest.TestCase): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0) glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(8), (), 0) gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 16),), "gidx0") - ld0 = glbl0.index(gidx0).load() - ld1 = glbl1.index(gidx0.valid(ld0)).load() + ld0 = glbl0.index(gidx0, ptr=True).load() + ld1 = glbl1.index(gidx0.valid(ld0), ptr=True).load() with self.assertRaises(RuntimeError): to_uops_list([ld1]) def test_fold_gated_load(self): @@ -620,38 +620,38 @@ class TestUOpGraph(unittest.TestCase): glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1) glbl2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2) idx = UOp.const(dtypes.int, 0) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(UOp.invalid()),)) - ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2.index(idx.valid(UOp.const(dtypes.bool, True))),)) + ld0 = glbl1.index(UOp.invalid()) + ld1 = glbl2.index(idx.valid(UOp.const(dtypes.bool, True))) uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))]) ld0 = uops[-1].src[-1] # the gate and invalid value are deleted from ld1 - self.assertEqual(ld0, UOp.load(glbl2.index(idx), dtype=dtypes.int)) + self.assertEqual(ld0, UOp.load(glbl2.index(idx, ptr=True), dtype=dtypes.int)) def test_fold_gated_load_local(self): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=18, addrspace=AddrSpace.LOCAL), (), "temp") lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0") - st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int))) + st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx, ptr=True), glbl0.index(lidx, ptr=True).load())) barrier = UOp(Ops.BARRIER, dtypes.void, (st, )) - ld0 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(UOp.invalid()),)) - ld1 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index((lidx+2).valid(UOp.const(dtypes.bool, True))),)) + ld0 = smem.after(barrier).index(UOp.invalid()) + ld1 = smem.after(barrier).index((lidx+2).valid(UOp.const(dtypes.bool, True))) uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))]) ld0 = uops[-1].src[-1] # the gate and invalid value are deleted from ld1 - self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2)) + self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2, ptr=True)) def test_fold_gated_store(self): glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) idx0 = UOp.const(dtypes.int, 0) idx1 = UOp.const(dtypes.int, 0) val = UOp.const(dtypes.int, 42) - st0 = glbl.index(UOp.invalid()).store(val) - st1 = glbl.index(idx0.valid(UOp.const(dtypes.bool, True))).store(val) + st0 = glbl.index(UOp.invalid(), ptr=True).store(val) + st1 = glbl.index(idx0.valid(UOp.const(dtypes.bool, True)), ptr=True).store(val) uops = to_uops_list([st0, st1]) # only the second store happens self.assertEqual(len(uops), 5) - self.assertEqual(uops[-1], glbl.index(idx1).store(val)) + self.assertEqual(uops[-1], glbl.index(idx1, ptr=True).store(val)) @unittest.skip("this is a uop type error") def test_asserts_bad_gate(self): diff --git a/test/test_uops.py b/test/test_uops.py index 375c166124..ab57157da0 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -39,9 +39,9 @@ def _test_single_value(vals, op, dts): output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1] buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0) buf_loads = [uop(uops, Ops.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)] - loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0))]) for i, dtype in enumerate(dts)) + loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True)]) for i, dtype in enumerate(dts)) alu = uop(uops, op, output_dtype, loads) - out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), alu)) + out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True), alu)) buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate() buf2 = [Buffer(Device.DEFAULT, 1, dtype).allocate().copyin(np.array([a], dtype=_to_np_dtype(dtype)).data) for a,dtype in zip(vals, dts)] prg = _uops_to_prg([out]) @@ -338,7 +338,7 @@ class TestLocalAccess(unittest.TestCase): smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.float32.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem') st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.float32, (), 42.0))) barr = uop(uops, Ops.BARRIER, dtypes.void, (st,)) - sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0)),)) + sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True),)) self.assertEqual(_test_uops_result(dtypes.float32, uops, sres), 42) # NOTE: webgpu specific, since only webgpu performs bitpacking @@ -348,7 +348,7 @@ class TestLocalAccess(unittest.TestCase): smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem') st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.uint8, (), 42))) barr = uop(uops, Ops.BARRIER, dtypes.void, (st,)) - sres = uop(uops, Ops.LOAD, dtypes.uint8, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0)),)) + sres = smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0)) self.assertEqual(_test_uops_result(dtypes.uint8, uops, sres), 42) # NOTE: webgpu specific, since only webgpu performs bitpacking @@ -382,7 +382,7 @@ class TestAssembly(unittest.TestCase): g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0) c1 = UOp(Ops.CONST, dtypes.int, (), 2) c2 = UOp(Ops.CONST, dtypes.int, (), 3) - l1 = UOp(Ops.LOAD, dtypes.int, (g1.index(c1),)) + l1 = g1.index(c1) a1 = UOp(Ops.MUL, dtypes.int, (l1, c1)) a2 = UOp(Ops.MUL, dtypes.int, (l1, c2)) uops = to_uops_list([a1,a2], ren=Device[Device.DEFAULT].renderer) @@ -395,7 +395,7 @@ class TestAssembly(unittest.TestCase): for dt in (dtypes.int32, dtypes.uint32): g = UOp(Ops.DEFINE_GLOBAL, dt.ptr(), (), 0) c = UOp(Ops.CONST, dt, (), 2) - l = UOp(Ops.LOAD, dt, (g.index(c),)) + l = g.index(c) a = UOp(Ops.IDIV, dt, (l, c)) uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer) Device[Device.DEFAULT].renderer.render(uops) @@ -406,7 +406,7 @@ class TestAssembly(unittest.TestCase): def test_fast_idiv_and_mod(self): g = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(), (), 0) c = UOp(Ops.CONST, dtypes.uint, (), 3) - l = UOp(Ops.LOAD, dtypes.uint, (g.index(c),)) + l = g.index(c) a = UOp(Ops.IDIV, dtypes.uint, (l, c)) uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer) Device[Device.DEFAULT].renderer.render(uops) @@ -458,8 +458,7 @@ class TestAssembly(unittest.TestCase): def test_use_cmpeq(self): g = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(), (), 0) c = UOp(Ops.CONST, dtypes.uint, (), 7) - l = UOp(Ops.LOAD, dtypes.uint, (g.index(c),)) - comp = l.ne(c).ne(True) + comp = g.index(c).ne(c).ne(True) uops = to_uops_list([comp], ren=Device[Device.DEFAULT].renderer) Device[Device.DEFAULT].renderer.render(uops) ops = [x.op for x in uops] diff --git a/test/test_uops_stats.py b/test/test_uops_stats.py index 845ab8b325..1aff484ab6 100644 --- a/test/test_uops_stats.py +++ b/test/test_uops_stats.py @@ -141,8 +141,8 @@ class TestUOpsStats(unittest.TestCase): globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple()) o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1) o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2) - u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),)) - u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),)) + u1 = globl.index(o1) + u2 = globl.index(o2) u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3) u4 = UOp(Ops.MUL, dtypes.int, (u1,u2)) u5 = UOp(Ops.ADD, dtypes.int, (u4,u3)) @@ -151,8 +151,8 @@ class TestUOpsStats(unittest.TestCase): globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple()) o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1) o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2) - u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),)) - u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),)) + u1 = globl.index(o1) + u2 = globl.index(o2) u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3) u4 = UOp(Ops.MULACC, dtypes.int, (u1,u2,u3)) uops_fma = full_rewrite(u4.sink()) diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py index ccea0deb21..c4c64cd669 100644 --- a/test/unit/test_simplify_valid_idx.py +++ b/test/unit/test_simplify_valid_idx.py @@ -9,13 +9,13 @@ from test.unit.test_uop_symbolic import check_uop_against_string def get_gated_load_uop(valid:UOp, idx:UOp): return UOp(Ops.LOAD, dtypes.float, ( - UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0).index(idx.valid(valid)), + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0).index(idx.valid(valid), ptr=True), UOp.const(dtypes.float, 0.0) )) def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]): return UOp(Ops.LOAD, dtypes.float.vec(4), ( - UOp(Ops.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0).index(UOp(Ops.VECTORIZE, dtypes.index.vec(2), idx).valid(valid)), + UOp(Ops.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0).index(UOp(Ops.VECTORIZE, dtypes.index.vec(2), idx).valid(valid), ptr=True), UOp(Ops.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4) )) diff --git a/test/unit/test_transcendental_helpers.py b/test/unit/test_transcendental_helpers.py index 4c697903a1..6f3b9a324c 100644 --- a/test/unit/test_transcendental_helpers.py +++ b/test/unit/test_transcendental_helpers.py @@ -11,7 +11,7 @@ class TestTranscendentalFunctions(unittest.TestCase): # TODO: Test constant input when constant folding is fixed (or maybe test both variants) # Load input value from a buffer to prevent constant folding input_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.double.ptr(), arg=1, src=()) - loaded_value = UOp.load(input_buf.index(UOp.const(dtypes.int, 0)), dtype=dtypes.double) + loaded_value = input_buf.index(UOp.const(dtypes.int, 0)) def eval_payne_hanek_reduction(v:float) -> tuple[float, int]: return tuple(eval_uop(u, [(dtypes.float64, [v])]) for u in payne_hanek_reduction(loaded_value)) diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index 1ebd15f7a5..dbd017716b 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -12,7 +12,7 @@ from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, p from tinygrad.uop.decompositions import get_late_rewrite_patterns from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \ - ReduceContext, correct_load_store, pm_render + ReduceContext, correct_load_store, pm_render, pm_add_loads from tinygrad.codegen.opt.postrange import apply_opts from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse, pm_split_store from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen @@ -59,6 +59,11 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) - # add gpu dims (late). this works after devectorize, but it's faster here sink = graph_rewrite(sink, pm_add_gpudims, ctx=ren, name="add gpudims") + # **** optimizations are done, now we lower to actual code **** + + # add loads + sink = graph_rewrite(sink, pm_add_loads, name="** add loads (code)") + # devectorize (TODO: does this need opts?) if DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing elif DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing diff --git a/tinygrad/codegen/gpudims.py b/tinygrad/codegen/gpudims.py index 07c758499a..e661b45650 100644 --- a/tinygrad/codegen/gpudims.py +++ b/tinygrad/codegen/gpudims.py @@ -80,7 +80,7 @@ def add_gpudims(ctx:Renderer, s:UOp): subs = {} for r in s_topo: # look for local INDEXes that are not used in the GLOBAL store, then add them as an INVALID - if r.op is Ops.STORE and r.src[0].ptrdtype.addrspace == AddrSpace.GLOBAL: + if r.op is Ops.STORE and r.src[0].src[0].ptrdtype.addrspace == AddrSpace.GLOBAL: idx = r.src[0] missing_locals = [all_ranges[rng] for rng in local_dims if all_ranges[rng] not in idx.ranges] if len(missing_locals): diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index b7d9c81bdd..2ee97e37bd 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -2,7 +2,7 @@ from typing import Any, cast import functools, operator, itertools from collections import defaultdict from dataclasses import dataclass -from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid +from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid, PtrDType from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element from tinygrad.uop.symbolic import uop_given_valid, parse_valid, sym, symbolic_flat, invalid_gate from tinygrad.helpers import getenv, flatten, AMX, prod @@ -12,7 +12,7 @@ from tinygrad.renderer import Renderer def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None: idx = uop_given_valid(valid, start_idx) - if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx.valid(valid)) + if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx.valid(valid), ptr=True) # wait for it to be image indexed before running simplification if start_idx.dtype.count != 2: return None @@ -43,7 +43,7 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None: if not drop_stmt and idx is start_idx: return None new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in valid.split_uop(Ops.AND) if s not in drop_stmt]) else None - return buf.index(idx.valid(new_valid) if new_valid is not None else idx) + return buf.index(idx.valid(new_valid) if new_valid is not None else idx, ptr=True) load_store_indexing = PatternMatcher([ @@ -52,7 +52,7 @@ load_store_indexing = PatternMatcher([ # simplify away long after index has been lowered (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x", dtypes.long), UPat.var("c", dtypes.bool))), lambda buf,x,c: simplify_valid_load(buf, x, c)), # drop true gate - (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x)), + (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x, ptr=True)), ]) # ***** load/store grouping ***** @@ -60,7 +60,7 @@ load_store_indexing = PatternMatcher([ def expand_index(buf:UOp, vec:UOp): if getenv("UNSAFE_DISABLE_MASK", 0): vec = vec.get_idx() # generate the individual indexes - midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i)) for i in range(vec.dtype.count)]), + midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), ptr=True) for i in range(vec.dtype.count)]), symbolic_flat+load_store_indexing, name=f"index_buf_{buf.arg}") # extract all the relevant offsets offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict) @@ -163,7 +163,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp): # with 1 at the end of the lengths list, this will always hit for fold_length in lengths: if global_offset+fold_length > sz: continue - lidx = buf.index((offset + global_offset).valid(mask)) + lidx = buf.index((offset + global_offset).valid(mask), ptr=True) if fold_length > 1: lidx = lidx.cast(buf.ptrdtype.base.vec(fold_length).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace)) if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:])) else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length))) @@ -229,7 +229,7 @@ def no_vectorized_buf(buf:UOp): def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp): cnt = cast.dtype.count assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}" - return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.index.vec(cnt), tuple(range(cnt)))) + return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.index.vec(cnt), tuple(range(cnt))), ptr=True) def no_vectorized_index_broadcast(buf:UOp, cast:UOp, bcast:UOp, idx:UOp): cnt = cast.dtype.count @@ -237,7 +237,7 @@ def no_vectorized_index_broadcast(buf:UOp, cast:UOp, bcast:UOp, idx:UOp): input_gep = bcast.arg if bcast.op is Ops.GEP else ([0]*precnt) gep_arg = tuple(flatten([range(precnt) for _ in range(cnt)])) sum_arg = tuple(flatten([[i+y for y in input_gep] for i in range(cnt)])) - return buf.broadcast(cnt*precnt).index(idx.gep(gep_arg)*cnt+UOp.const(dtypes.index.vec(cnt*precnt), sum_arg)) + return buf.broadcast(cnt*precnt).index(idx.gep(gep_arg)*cnt+UOp.const(dtypes.index.vec(cnt*precnt), sum_arg), ptr=True) devectorize_buf_and_index = PatternMatcher([ (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf), @@ -302,11 +302,11 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp): acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,)) acc_init = acc.after(*input_ranges).index(UOp.const(dtypes.int, 0)).store(identity) if len(input_ranges) else \ acc.index(UOp.const(dtypes.int, 0)).store(identity) - lst = [acc.after(acc_init, *reduce_range).index(UOp.const(dtypes.int, 0)).load()] + lst # put acc as the first element + lst = [acc.after(acc_init, *reduce_range).index(UOp.const(dtypes.int, 0))] + lst # put acc as the first element ctx.acc_num += 1 ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst) if len(reduce_range) == 0: return ret - return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range)).index(UOp.const(dtypes.int, 0)).load() + return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range)).index(UOp.const(dtypes.int, 0)) pm_reduce = PatternMatcher([ # REDUCE -> DEFINE_ACC+ASSIGN @@ -315,3 +315,14 @@ pm_reduce = PatternMatcher([ (UPat(Ops.WMMA, name="wmma") + UPat.var("add"), lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)), ])+sym + +# add loads + +pm_add_loads = PatternMatcher([ + # add loads to non ptr index + (UPat(Ops.INDEX, name="idx"), lambda idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else + idx.replace(dtype=idx.src[0].dtype).load(dtype=idx.dtype.base)), + # remove loads from stores + (UPat(Ops.STORE, src=(UPat(Ops.LOAD),), allow_any_len=True, name="s"), lambda s: s.replace(src=(s.src[0].src[0],)+s.src[1:])), +]) + diff --git a/tinygrad/codegen/opt/heuristic.py b/tinygrad/codegen/opt/heuristic.py index 12a1248823..639b089210 100644 --- a/tinygrad/codegen/opt/heuristic.py +++ b/tinygrad/codegen/opt/heuristic.py @@ -64,8 +64,8 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler: MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4) if k.ren.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \ k.reduceop is not None and k.reduceop.arg[0] is Ops.ADD and len(k.full_shape) >= 2 and k.ren.has_shared and \ - (mulop:=k.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD: - idx0, idx1 = mulop.src[0].src[0].src[1].get_idx(), mulop.src[1].src[0].src[1].get_idx() + (mulop:=k.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.INDEX and mulop.src[1].op is Ops.INDEX: + idx0, idx1 = mulop.src[0].src[1].get_idx(), mulop.src[1].src[1].get_idx() if k.ranges_of(AxisType.REDUCE): first_reduce_rng = k.ranges_of(AxisType.REDUCE)[0] if any(u is first_reduce_rng for u in idx0.split_uop(Ops.ADD)) and all(r in idx1.ranges for r in idx0.ranges): diff --git a/tinygrad/codegen/simplify.py b/tinygrad/codegen/simplify.py index 5562a32603..13b67606d1 100644 --- a/tinygrad/codegen/simplify.py +++ b/tinygrad/codegen/simplify.py @@ -55,7 +55,7 @@ def do_substitute(ctx, x: UOp): return ret def dont_sub_ranges_for_image(ctx, x:UOp): - if isinstance(x.src[0].dtype, ImageDType): + if isinstance(x.src[0].src[0].dtype, ImageDType): for s in x.src[0].ranges: ctx[s] = None pm_split_ranges = PatternMatcher([ @@ -129,7 +129,7 @@ def reduce_load_collapse(red:UOp): return reduce_collapse(red, pm=pm_reduce_load # remove REDUCE without loads (generic arange opt / indexing). TODO: support multi range pm_reduce_simplify = pm_reduce_unparented + PatternMatcher([(UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_collapse),]) # remove REDUCE on load, comes from indexing a tensor with another tensor -def no_load(u:UOp) -> bool: return not any(x.op is Ops.LOAD for x in u.backward_slice_with_self) +def no_load(u:UOp) -> bool: return not any(x.op is Ops.INDEX for x in u.backward_slice_with_self) pm_load_collapse = PatternMatcher([ (UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_load_collapse), # we want to make sure we dont do math on a loaded index since that can cause overflow, this undoes the rule in pm_reduce_load_collapse diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 86dd580dd7..da23459778 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -435,9 +435,9 @@ rangeify_codegen = PatternMatcher([ # add loads to non ptr indexes # TODO: this can be moved into codegen? - (UPat.any(UPat(Ops.DEFINE_GLOBAL, name="dg"), UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True, name="dg")) - .f(Ops.INDEX, name="idx", allow_any_len=True), - lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()), + #(UPat.any(UPat(Ops.DEFINE_GLOBAL, name="dg"), UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True, name="dg")) + # .f(Ops.INDEX, name="idx", allow_any_len=True), + # lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()), # fix broadcast dtype (UPat(Ops.AFTER, name="a").broadcast(name="b"), lambda a,b: a.broadcast(len(b.src))), diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index ec6bc8de96..1a16ad6c2f 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -339,8 +339,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0] return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None])) def detach(self): return UOp(Ops.DETACH, self.dtype, (self,)) - def index(self, *srcs:UOp|None, **kwargs): - return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype), (self,)+tuple([x for x in srcs if x is not None]), **kwargs) + def index(self, *srcs:UOp|None, ptr=False, **kwargs): + return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs) def __getitem__(self, *idx): return self.index(*idx) def const_like(self, b:ConstLike): # constants can optionally have a DEVICE source @@ -743,6 +743,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass): ret = graph_rewrite(self.simplify() if simplify else self, renderer if pm is None else pm) return ret.arg if ret.op is Ops.NOOP else str(ret) + def pyrender(self): return pyrender(self) + @dataclass(frozen=True) class KernelInfo: name: str = "test" # name of the kernel @@ -1199,10 +1201,11 @@ pm_lower_index_dtype = PatternMatcher([ (UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.index), UPat.cvar("val").cast(dtypes.index))), lambda var,val: var.bind(val).cast(dtypes.index)), (UPat(Ops.CAST, src=(UPat(name="x").cast(dtypes.index),), name="c"), lambda x,c: x.cast(c.dtype)), # lower Invalid - (UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond)), + (UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond, ptr=True)), # remove hanging casts - (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx)), - (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)), + (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)), + (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))), + lambda buf,idx,valid: buf.index(idx, valid, ptr=True)), (UPat((Ops.STORE, Ops.LOAD), src=(UPat(), UPat(), UPat().cast(dtypes.index)), allow_any_len=True, name="s"), lambda s: s.replace(src=s.src[:2]+tuple(u.src[0] for u in s.src[2:]))), (UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"), @@ -1279,8 +1282,9 @@ pm_pyrender_extra = PatternMatcher([ "UOp.range("+', '.join([str(c.arg)] + [str(y) for y in x.arg])+ (f', src={srcs(ctx, x.src[1:])}' if len(x.src) > 1 else '')+(', dtype='+str(x.dtype) if x.dtype is not dtypes.index else '')+")"), # TODO: index shouldn't mismatch dtype - (UPat(Ops.INDEX, src=(UPat(), UPat()), name="x"), lambda ctx,x: - f"{ctx[x.src[0]]}.index({ctx[x.src[1]]}, dtype={x.dtype})" if x.src[0].dtype != x.dtype else None), + (UPat(Ops.INDEX, src=(UPat(), UPat()), allow_any_len=True, name="x"), lambda ctx,x: + f"{ctx[x.src[0]]}.index({ctx[x.src[1]]}, "+(f"{ctx[x.src[2]]}, " if len(x.src) > 2 else "")+ + (f"dtype={x.dtype})" if x.src[0].dtype != x.dtype else "ptr=True)") if x.src[0].dtype.base != x.dtype else None), # TODO: fix forced_reshape (UPat(Ops.RESHAPE, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.forced_reshape({render_marg(ctx,x)})" if x.src[0].shape == x.shape else None), (UPat(GroupOp.Movement, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({render_marg(ctx,x)})"), diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index d882f4206c..9b955113e5 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -133,6 +133,7 @@ shared_codegen_spec = PatternMatcher([ (UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()), # LOAD(idx) / STORE(idx, val) / LOAD with alt value only exists in program_spec + # TODO: move LOAD to the program_spec (UPat().index(UPat()).or_casted().load(), lambda: True), (UPat(Ops.INDEX).or_casted().store(UPat()), lambda: True), diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index d8ec88566a..a89536f10b 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -413,7 +413,7 @@ def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp: bounds[expr][int(is_upper)] = c # don't simplify any other gates, can lead to OOB, we substitute them back later - uop = uop.substitute((load_subs:={u: UOp(Ops.NOOP, arg=u) for u in uop.toposort() if u.op is Ops.INDEX})) + uop = uop.substitute((load_subs:={u: UOp(Ops.NOOP, dtype=u.dtype, arg=u) for u in uop.toposort() if u.op is Ops.INDEX})) # simplify uop given that valid is True all_candidates = [] @@ -479,21 +479,20 @@ def drop_and_clauses(cond:UOp, x:UOp, i:UOp) -> UOp|None: return UOp.const(dtypes.bool, True).prod(*[c for c in cond.split_uop(Ops.AND) if c not in dropped_clauses]).where(x, i) pm_drop_and_clauses = PatternMatcher([(UPat.var("cond").where(UPat.var("x", dtype=dtypes.index), invalid_pat), drop_and_clauses)]) -def where_on_load(l, c1, buf, x): +def where_on_load(c1, buf, x): c2 = x.get_valid() duplicate_clauses = [c for c in c1.split_uop(Ops.AND) if c in c2.split_uop(Ops.AND)] # we move the condition from the where to the load _as long as_ the condtition doesn't have some range that would place it inside of a new range # also no data dependent loads! moved_clauses = [c for c in c1.split_uop(Ops.AND) if c not in duplicate_clauses and all(r in x.ranges for r in c.ranges) - and all(u in x.backward_slice_with_self for u in c.backward_slice_with_self if u.op is Ops.LOAD)] + and all(u in x.backward_slice_with_self for u in c.backward_slice_with_self if u.op is Ops.INDEX)] if not (removed:=moved_clauses+duplicate_clauses): return None # aditionally we can drop the clause on the where if it already exists in the load remaining_clause = UOp.const(dtypes.bool, True).prod(*[c for c in c1.split_uop(Ops.AND) if c not in removed]) - return remaining_clause.where(UOp.load(buf.index(x.get_idx().valid(functools.reduce(operator.and_, moved_clauses, c2)), *l.src[1:])), 0) + return remaining_clause.where(buf.index(x.get_idx().valid(functools.reduce(operator.and_, moved_clauses, c2))), 0) pm_move_where_on_load = PatternMatcher([ - (UPat.var("c1").where(UPat(Ops.LOAD, src=(UPat.var("buf").index(UPat.var("x")),), name="l"), 0), where_on_load), - (UPat.var("c1").where(0, UPat(Ops.LOAD, src=(UPat.var("buf").index(UPat.var("x")),), name="l")), - lambda l,c1,buf,x: where_on_load(l,c1.logical_not(),buf,x)), + (UPat.var("c1").where(UPat.var("buf").index(UPat.var("x")), 0), where_on_load), + (UPat.var("c1").where(0, UPat.var("buf").index(UPat.var("x"))), lambda c1,buf,x: where_on_load(c1.logical_not(),buf,x)), ]) pm_simplify_valid = PatternMatcher([