diff --git a/test/helpers.py b/test/helpers.py index 706297bc10..c40a0618b0 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -79,7 +79,7 @@ def timeit(fxn:Callable[..., T], *args, **kwargs) -> Tuple[T, float]: def eval_uop(uop:UOp): g = UOp(Ops.DEFINE_GLOBAL, uop.dtype.ptr(), arg=0, src=()) - rw = full_graph_rewrite(UOp.store(g, UOp.const(dtypes.int, 0), uop).sink(), PythonRenderer) + rw = full_graph_rewrite(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(), PythonRenderer) prog = PythonProgram("run", PythonCompiler().compile(PythonRenderer().render("run", linearize_uop(rw)))) buf = PythonAllocator().alloc(uop.dtype.itemsize) prog(buf) diff --git a/test/test_renderer_failures.py b/test/test_renderer_failures.py index 77975d6665..d997cb3f37 100644 --- a/test/test_renderer_failures.py +++ b/test/test_renderer_failures.py @@ -33,9 +33,9 @@ class TestCStyleFailures(unittest.TestCase): a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) b = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1) idx = UOp.const(dtypes.int, 0) - ld = UOp(Ops.LOAD, dtypes.int, (b, idx)) + ld = UOp(Ops.LOAD, dtypes.int, (b.index(idx),)) alu = ld.alu(BinaryOps.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1)) - store = UOp.store(a, idx, alu) + store = UOp.store(a.index(idx), alu) sink = UOp(Ops.SINK, dtypes.void, (store,)) uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer)) # CLANG doesn't use the max function @@ -47,7 +47,7 @@ class TestPTXFailures(unittest.TestCase): def test_gated_store_with_alu(self): a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0) - gated_alu_store = UOp(Ops.STORE, dtypes.void, (a, lidx0, UOp.const(dtypes.int, 1), gate_alu)) + gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, gate_alu), UOp.const(dtypes.int, 1))) sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,)) uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer)) ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0] @@ -58,7 +58,7 @@ class TestPTXFailures(unittest.TestCase): gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0) val = UOp.const(dtypes.int, 1) if_uop = UOp(Ops.IF, dtypes.void, (gate_alu,)) - gated_alu_store = UOp(Ops.STORE, dtypes.void, (a, lidx0, val, if_uop)) + gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, if_uop), val)) sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,)) uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer)) ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0] diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 8d56eb2d98..36455e9804 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -252,7 +252,7 @@ class TestUOpGraph(unittest.TestCase): idx = UOp.const(dtypes.int, 0) def _test_vec(geps, count=4): vec = UOp(Ops.VECTORIZE, dtypes.float.vec(count), geps) - out = UOp(Ops.STORE, dtypes.void, (d0, idx, vec)) + out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), vec)) uops = to_uops_list([out]) if DEBUG >= 4: from tinygrad import Device @@ -260,26 +260,26 @@ class TestUOpGraph(unittest.TestCase): return uops[-1].src[-1] # possible - val = UOp(Ops.LOAD, dtypes.float.vec(4), (d1, idx)) + val = UOp(Ops.LOAD, dtypes.float.vec(4), (d1.index(idx),)) xyzw = tuple(UOp(Ops.GEP, dtypes.float, (val,), (i,)) for i in range(4)) self.assertIs(_test_vec(xyzw).op, Ops.LOAD) # unaligned - val = UOp(Ops.LOAD, dtypes.float.vec(4), (d1, idx)) + val = UOp(Ops.LOAD, dtypes.float.vec(4), (d1.index(idx),)) wzyx = tuple(UOp(Ops.GEP, dtypes.float, (val,), (i,)) for i in reversed(range(4))) self.assertIs(_test_vec(wzyx).op, Ops.VECTORIZE) # different_size - val = UOp(Ops.LOAD, dtypes.float.vec(2), (d1, idx)) + val = UOp(Ops.LOAD, dtypes.float.vec(2), (d1.index(idx),)) xy = tuple(UOp(Ops.GEP, dtypes.float, (val, ), (i,)) for i in range(2)) self.assertIs(_test_vec(xy+xy).op, Ops.VECTORIZE) - val = UOp(Ops.LOAD, dtypes.float.vec(4), (d1, idx)) + val = UOp(Ops.LOAD, dtypes.float.vec(4), (d1.index(idx),)) xy = tuple(UOp(Ops.GEP, dtypes.float, (val, ), (i,)) for i in range(2)) self.assertIs(_test_vec(xy, count=2).op, Ops.VECTORIZE) # different vals - val1 = UOp(Ops.LOAD, dtypes.float.vec(2), (d1, idx)) - val2 = UOp(Ops.LOAD, dtypes.float.vec(2), (d2, idx)) + val1 = UOp(Ops.LOAD, dtypes.float.vec(2), (d1.index(idx),)) + val2 = UOp(Ops.LOAD, dtypes.float.vec(2), (d2.index(idx),)) xy1 = tuple(UOp(Ops.GEP, dtypes.float, (val1, ), (i,)) for i in range(2)) xy2 = tuple(UOp(Ops.GEP, dtypes.float, (val2, ), (i,)) for i in range(2)) self.assertIs(_test_vec(xy1+xy2).op, Ops.VECTORIZE) @@ -355,9 +355,9 @@ class TestUOpGraph(unittest.TestCase): d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0) d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1) idx = UOp.const(dtypes.int, 0) - ld = UOp(Ops.LOAD, dtypes.int, (d1, idx)) + ld = UOp(Ops.LOAD, dtypes.int, (d1.index(idx),)) alu = ld.lt(1).cast(dtypes.bool) - out = UOp(Ops.STORE, dtypes.void, (d0, idx, alu)) + out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu)) uops = to_uops_list([out]) self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 0) @@ -365,9 +365,9 @@ class TestUOpGraph(unittest.TestCase): d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0) d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1) idx = UOp.const(dtypes.int, 0) - ld = UOp(Ops.LOAD, dtypes.int, (d1, idx)) + ld = UOp(Ops.LOAD, dtypes.int, (d1.index(idx),)) alu = ld.cast(dtypes.float).cast(dtypes.float) - out = UOp(Ops.STORE, dtypes.void, (d0, idx, alu)) + out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu)) uops = to_uops_list([out]) self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 1) @@ -390,9 +390,9 @@ class TestUOpGraph(unittest.TestCase): glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1) glbl2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2) idx = UOp.const(dtypes.int, 0) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, False))) - ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, True))) - uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0, idx, ld1+ld0))]) + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(idx, UOp.const(dtypes.bool, False)),)) + ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2.index(idx, UOp.const(dtypes.bool, True)),)) + uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))]) ld0 = uops[-1].src[-1] # the gate and invalid value are deleted from ld1 self.assertEqual(ld0, UOp.load(glbl2.index(idx), dtype=dtypes.int)) @@ -401,11 +401,12 @@ class TestUOpGraph(unittest.TestCase): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(local=True), (), ("temp", 1)) lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 16)) - st = UOp(Ops.STORE, dtypes.void, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int))) + st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int))) barrier = UOp(Ops.BARRIER, dtypes.void, (st, )) - ld0 = UOp(Ops.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, False), barrier)) - ld1 = UOp(Ops.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, True), barrier)) - uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0, lidx, ld1+ld0))]) + ld0 = UOp(Ops.LOAD, dtypes.int, (smem.index(lidx+1, UOp.const(dtypes.bool, False)), barrier)) + ld1 = UOp(Ops.LOAD, dtypes.int, (smem.index(lidx+2, UOp.const(dtypes.bool, True)), barrier)) + uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))]) + ld0 = uops[-1].src[-1] # the gate and invalid value are deleted from ld1 self.assertEqual(ld0.src[0], smem.index(lidx+2)) @@ -415,8 +416,8 @@ class TestUOpGraph(unittest.TestCase): idx0 = UOp.const(dtypes.int, 0) idx1 = UOp.const(dtypes.int, 0) val = UOp.const(dtypes.int, 42) - st0 = UOp(Ops.STORE, dtypes.void, (glbl, idx0, val, UOp.const(dtypes.bool, False))) - st1 = UOp(Ops.STORE, dtypes.void, (glbl, idx1, val, UOp.const(dtypes.bool, True))) + st0 = UOp(Ops.STORE, dtypes.void, (glbl.index(idx0, UOp.const(dtypes.bool, False)), val)) + st1 = UOp(Ops.STORE, dtypes.void, (glbl.index(idx1, UOp.const(dtypes.bool, True)), val)) uops = to_uops_list([st0, st1]) # only the second store happens self.assertEqual(len(uops), 5) @@ -437,7 +438,7 @@ class TestUOpGraph(unittest.TestCase): r1 = UOp(Ops.RANGE, dtypes.int, (c0, c2), (1, 0, False)) r2 = UOp(Ops.RANGE, dtypes.int, (c0, c2), (1, 1, False)) alu = UOp(Ops.ALU, dtypes.int, (r2, r1), BinaryOps.MUL) - store = UOp(Ops.STORE, dtypes.void, (glbl, alu, cf)) + store = UOp(Ops.STORE, dtypes.void, (glbl.index(alu), cf)) uops = to_uops_list([store]) ranges = [x for x in uops if x.op is Ops.RANGE] endranges = [x for x in uops if x.op is Ops.ENDRANGE] @@ -597,14 +598,15 @@ class TestExpander(unittest.TestCase): class TestLoadStoreFolder(unittest.TestCase): def test_simple_load_fold(self): buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr()) - load = [UOp(Ops.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(4)] + load = [UOp(Ops.LOAD, dtypes.float, (buf.index(UOp.const(dtypes.int, i)),)) for i in range(4)] sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)) + sink = float4_rewrite(sink.sink()) assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 1 def test_two_load_fold(self): buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr()) - load = [UOp(Ops.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(8)] + load = [UOp(Ops.LOAD, dtypes.float, (buf.index(UOp.const(dtypes.int, i)),)) for i in range(8)] sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)) sink = float4_rewrite(sink.sink()) assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 2 @@ -612,7 +614,7 @@ class TestLoadStoreFolder(unittest.TestCase): def test_simple_load_fold_gated(self): buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr()) gate = UOp(Ops.DEFINE_VAR, dtypes.bool) - load = [UOp(Ops.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate)) for i in range(4)] + load = [UOp(Ops.LOAD, dtypes.float, (buf.index(UOp.const(dtypes.int, i), gate),)) for i in range(4)] sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)) sink = float4_rewrite(sink.sink()) assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 1 @@ -623,14 +625,15 @@ class TestLoadStoreFolder(unittest.TestCase): buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr()) gate = UOp.variable("g1", False, True, dtypes.bool) gate2 = UOp.variable("g2", False, True, dtypes.bool) - load = [UOp(Ops.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate if i == 0 else gate2)) for i in range(4)] + load = [UOp(Ops.LOAD, dtypes.float, (buf.index(UOp.const(dtypes.int, i), gate if i == 0 else gate2), + UOp.const(dtypes.float, 0))) for i in range(4)] sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)) sink = float4_rewrite(sink.sink()) assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 3 def test_simple_store_fold(self): buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr()) - load = [UOp(Ops.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0))) for i in range(4)] + load = [UOp(Ops.STORE, dtypes.float, (buf.index(UOp.const(dtypes.int, i)), UOp.const(dtypes.float, 0))) for i in range(4)] sink = UOp(Ops.SINK, dtypes.void, tuple(load)) sink = float4_rewrite(sink) assert len([x for x in sink.sparents if x.op is Ops.STORE]) == 1 @@ -638,7 +641,7 @@ class TestLoadStoreFolder(unittest.TestCase): def test_simple_store_fold_gate(self): buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr()) gate = UOp.variable("g1", False, True, dtypes.bool) - load = [UOp(Ops.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate)) for i in range(4)] + load = [UOp(Ops.STORE, dtypes.float, (buf.index(UOp.const(dtypes.int, i)), UOp.const(dtypes.float, 0), gate)) for i in range(4)] sink = UOp(Ops.SINK, dtypes.void, tuple(load)) sink = float4_rewrite(sink) assert len([x for x in sink.sparents if x.op is Ops.STORE]) == 1 @@ -650,7 +653,8 @@ class TestLoadStoreFolder(unittest.TestCase): buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr()) gate = UOp.variable("g1", False, True, dtypes.bool) gate2 = UOp.variable("g2", False, True, dtypes.bool) - load = [UOp(Ops.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)] + load = [UOp(Ops.STORE, dtypes.float, (buf.index(UOp.const(dtypes.int, i), gate if i == 0 else gate2), + UOp.const(dtypes.float, i))) for i in range(4)] sink = UOp(Ops.SINK, dtypes.void, tuple(load)) sink = float4_rewrite(sink) assert len([x for x in sink.sparents if x.op is Ops.STORE]) == 3 @@ -663,10 +667,10 @@ class TestIFUOps(unittest.TestCase): lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 4)) gate = valid&(lidx.ne(2)) idx = UOp.const(dtypes.int, 0) - st = UOp(Ops.STORE, dtypes.void, (sbuf, idx, UOp.const(dtypes.float, 42))) + st = UOp(Ops.STORE, dtypes.void, (sbuf.index(idx), UOp.const(dtypes.float, 42))) barrier = UOp(Ops.BARRIER, dtypes.void, (st,)) - lbuf = UOp(Ops.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, 0), barrier)) - store = UOp(Ops.STORE, dtypes.void, (gbuf, UOp.const(dtypes.int, 0), lbuf, gate)) + lbuf = UOp(Ops.LOAD, dtypes.float, (sbuf.index(UOp.const(dtypes.int, 0)), barrier)) + store = UOp(Ops.STORE, dtypes.void, (gbuf.index(UOp.const(dtypes.int, 0), gate), lbuf)) sink = UOp(Ops.SINK, dtypes.void, (store,)) sink = full_graph_rewrite(sink) if_uops = [u for u in sink.parents if u.op is Ops.IF] @@ -683,8 +687,8 @@ class TestIFUOps(unittest.TestCase): gate = valid&(lidx.ne(2)) st = UOp(Ops.STORE, dtypes.void, (sbuf, lidx, UOp.const(dtypes.float, 42))) barrier = UOp(Ops.BARRIER, dtypes.void, (st,)) - lbufs = [UOp(Ops.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, i), barrier)) for i in range(4)] - stores = [UOp(Ops.STORE, dtypes.void, (gbuf, UOp.const(dtypes.int, i), lbufs[i], gate)) for i in range(4)] + lbufs = [UOp(Ops.LOAD, dtypes.float, (sbuf.index(UOp.const(dtypes.int, i)), barrier)) for i in range(4)] + stores = [UOp(Ops.STORE, dtypes.void, (gbuf.index(UOp.const(dtypes.int, i), gate), lbufs[i])) for i in range(4)] sink = UOp(Ops.SINK, dtypes.void, tuple(stores)) sink = full_graph_rewrite(sink) if_uops = [u for u in sink.parents if u.op is Ops.IF] diff --git a/test/test_uops.py b/test/test_uops.py index 47b80780d0..7c1b57d5eb 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -32,9 +32,9 @@ def _test_single_value(vals, op, dts): output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1] buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0) buf_loads = [uop(uops, Ops.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)] - loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i], uop(uops, Ops.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts)) + loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0))]) for i, dtype in enumerate(dts)) alu = uop(uops, Ops.ALU, output_dtype, loads, op) - out = uop(uops, Ops.STORE, dtypes.void, (buf_store, uop(uops, Ops.CONST, dtypes.int32, (), 0), alu)) + out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), alu)) buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate() buf2 = [Buffer(Device.DEFAULT, 1, dtype).allocate().copyin(np.array([a], dtype=_to_np_dtype(dtype)).data) for a,dtype in zip(vals, dts)] prg = _uops_to_prg([out]) @@ -49,7 +49,7 @@ def _test_single_value_const(vals, op, dts): buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0) loads = (uop(uops, Ops.CONST, dtype, [], a) for a,dtype in zip(vals, dts)) alu = uop(uops, Ops.ALU, output_dtype, loads, op) - out = uop(uops, Ops.STORE, dtypes.void, (buf_store, uop(uops, Ops.CONST, dtypes.int32, (), 0), alu)) + out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), alu)) buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate() prg = _uops_to_prg([out]) prg.exec([buf]) @@ -61,7 +61,7 @@ def _test_uops_result(output_dtype, uops, res): # uops = [] buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0) # res = output_fn(uops) - out = uop(uops, Ops.STORE, dtypes.void, (buf_store, uop(uops, Ops.CONST, dtypes.int32, (), 0), res)) + out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), res)) buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate() prg = _uops_to_prg([out]) prg.exec([buf]) @@ -309,20 +309,20 @@ class TestLocalAccess(unittest.TestCase): def test_local_basic(self): uops = [] smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.float32.ptr(local=True), (), ('smem', 16)) - st = uop(uops, Ops.STORE, dtypes.void, (smem, uop(uops, Ops.CONST, dtypes.int32, (), 0), uop(uops, Ops.CONST, dtypes.float32, (), 42.0))) + st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.float32, (), 42.0))) barr = uop(uops, Ops.BARRIER, dtypes.void, (st,)) - sres = uop(uops, Ops.LOAD, dtypes.float32, (smem, uop(uops, Ops.CONST, dtypes.int32, (), 0), barr)) + sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr)) self.assertEqual(_test_uops_result(dtypes.float32, uops, sres), 42) @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory") def test_local_indirect(self): uops = [] smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.int32.ptr(local=True), (), ('smem', 16)) - st1 = uop(uops, Ops.STORE, dtypes.void, (smem, uop(uops, Ops.CONST, dtypes.int32, (), 1), uop(uops, Ops.CONST, dtypes.int32, (), 2))) - st2 = uop(uops, Ops.STORE, dtypes.void, (smem, uop(uops, Ops.CONST, dtypes.int32, (), 2), uop(uops, Ops.CONST, dtypes.int32, (), 42))) + st1 = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 1)), uop(uops, Ops.CONST, dtypes.int32, (), 2))) + st2 = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 2)), uop(uops, Ops.CONST, dtypes.int32, (), 42))) barr = uop(uops, Ops.BARRIER, dtypes.void, (st1,st2)) - ofs = uop(uops, Ops.LOAD, dtypes.int32, (smem, uop(uops, Ops.CONST, dtypes.int32, (), 1), barr)) - sres = uop(uops, Ops.LOAD, dtypes.int32, (smem, ofs)) + ofs = uop(uops, Ops.LOAD, dtypes.int32, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 1)), barr)) + sres = uop(uops, Ops.LOAD, dtypes.int32, (smem.index(ofs),)) self.assertEqual(_test_uops_result(dtypes.int32, uops, sres), 42) @unittest.skipUnless(getenv("PTX"), "This only tests assembly backends") @@ -331,7 +331,7 @@ class TestAssembly(unittest.TestCase): g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0) c1 = UOp(Ops.CONST, dtypes.int, (), 2) c2 = UOp(Ops.CONST, dtypes.int, (), 3) - l1 = UOp(Ops.LOAD, dtypes.int, (g1, c1)) + l1 = UOp(Ops.LOAD, dtypes.int, (g1.index(c1),)) a1 = UOp(Ops.ALU, dtypes.int, (l1, c1), BinaryOps.MUL) a2 = UOp(Ops.ALU, dtypes.int, (l1, c2), BinaryOps.MUL) uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer) @@ -343,7 +343,7 @@ class TestAssembly(unittest.TestCase): g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0) c1 = UOp(Ops.CONST, dtypes.int, (), 2) c2 = UOp(Ops.CONST, dtypes.int, (), 3) - l1 = UOp(Ops.LOAD, dtypes.int, (g1, c1)) + l1 = UOp(Ops.LOAD, dtypes.int, (g1.index(c1),)) a1 = UOp(Ops.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV) a2 = UOp(Ops.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV) uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer) diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py index 931baf1405..35b7881648 100644 --- a/test/unit/test_simplify_valid_idx.py +++ b/test/unit/test_simplify_valid_idx.py @@ -7,18 +7,14 @@ from tinygrad.ops import UOp, Ops, simplify_valid def get_gated_load_uop(valid:UOp, idx:UOp): return UOp(Ops.LOAD, dtypes.float, ( - UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0), - idx, - UOp.const(dtypes.float, 0.0), - valid + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0).index(idx, valid), + UOp.const(dtypes.float, 0.0) )) def get_load_image_uop(image_shape:Tuple[int, ...], valid:UOp, idx:Tuple[UOp, UOp]): return UOp(Ops.LOAD, dtypes.float.vec(4), ( - UOp(Ops.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0), - UOp(Ops.VECTORIZE, dtypes.int.vec(2), idx), - UOp(Ops.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0),)*4), - valid + UOp(Ops.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0).index(UOp(Ops.VECTORIZE, dtypes.int.vec(2), idx), valid), + UOp(Ops.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4) )) def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.int, (), (expr, nmax)) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index d50ee7805e..6b88f15216 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -16,7 +16,7 @@ import functools def render(self) -> Tuple[str, ConstType, ConstType]: # NOTE: we need STORE so the ALU op has children glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0) - uops = linearize_uop(full_graph_rewrite(UOp(Ops.STORE, dtypes.void, (glbl, UOp.const(dtypes.int, 0), self)).sink())) + uops = linearize_uop(full_graph_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), self)).sink())) rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[-1] return rewritten_uop.render(simplify=False), rewritten_uop.vmin, rewritten_uop.vmax diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index 1c2562f189..30785d65ea 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from typing import List, Tuple, cast, Optional from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import variable_to_uop -from tinygrad.dtype import dtypes +from tinygrad.dtype import dtypes, ImageDType from tinygrad.ops import KernelInfo, BinaryOps, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element from tinygrad.renderer import Renderer from tinygrad.helpers import all_int, prod, partition, flatten @@ -133,6 +133,14 @@ pm_lowerer = PatternMatcher([ (UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store), ]) +def idx_load_store(x:UOp): + idx = x.src[0].index(x.src[1], x.src[3] if len(x.src) > 3 else None) + v = x.dtype.count if x.op is Ops.LOAD else x.src[2].dtype.count + if v > 1 and not isinstance(x.src[0].dtype, ImageDType): idx = idx.cast(idx.dtype.base.vec(v).ptr(idx.dtype.local)) + post_mask = x.src[4:] if len(x.src) > 3 else (x.src[2:] if x.op is Ops.LOAD else x.src[3:]) + if x.op is Ops.LOAD: return UOp(x.op, x.dtype, (idx,)+post_mask, x.arg) + return UOp(x.op, x.dtype, (idx,x.src[2])+post_mask, x.arg) + def do_reduce(ctx:List[int], root:UOp): acc = UOp(Ops.DEFINE_ACC, root.dtype, (root.const_like(identity_element(root.arg, root.dtype.scalar())),) + tuple(root.src[1:]), (ctx[0],)) @@ -140,6 +148,8 @@ def do_reduce(ctx:List[int], root:UOp): return acc.assign(acc.alu(root.arg, root.src[0])) just_reduce = PatternMatcher([ + # use indexing for LOAD/STORE + (UPat((Ops.LOAD, Ops.STORE), src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store), # do reduce (UPat(Ops.REDUCE, name="root"), do_reduce), ]) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 6990dc2454..3e7c23e679 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -456,17 +456,7 @@ load_store_indexing = PatternMatcher([ UPat.var("val"))), delete_redundant_gates), ]) -def idx_load_store(x:UOp): - idx = x.src[0].index(x.src[1], x.src[3] if len(x.src) > 3 else None) - v = x.dtype.count if x.op is Ops.LOAD else x.src[2].dtype.count - if v > 1 and not isinstance(x.src[0].dtype, ImageDType): idx = idx.cast(idx.dtype.base.vec(v).ptr(idx.dtype.local)) - post_mask = x.src[4:] if len(x.src) > 3 else (x.src[2:] if x.op is Ops.LOAD else x.src[3:]) - if x.op is Ops.LOAD: return UOp(x.op, x.dtype, (idx,)+post_mask, x.arg) - return UOp(x.op, x.dtype, (idx,x.src[2])+post_mask, x.arg) - migrate_indexing = PatternMatcher([ - # use indexing for LOAD/STORE - (UPat((Ops.LOAD, Ops.STORE), src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store), # create gate MUST BE BEFORE expander (UPat(Ops.STORE, name="root"), create_gate), ])