add loads at the end (#12988)

* add loads at the end

* simpler

* late load

* tests passing

* fix matvec

* spec test passes

* fix where on load

* fix abs2

* fix more tests
George Hotz authored on 2025-10-30 10:42:19 +08:00; committed by GitHub
parent 4b001ec723 · commit 2da02f1ae1
21 changed files with 120 additions and 101 deletions
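The core of the change: UOp.index() now returns a value-typed UOp by default (its dtype is the buffer's base dtype), so graphs can use buf.index(idx) directly where they previously wrapped it in an explicit Ops.LOAD; real LOAD ops are inserted at the end of lowering by the new pm_add_loads matcher, and index(..., ptr=True) keeps the old pointer-typed behaviour where an address is still needed. A minimal sketch of the before/after construction style, modeled on the first hunk below (illustrative, not taken verbatim from the diff):

# sketch: building the same add kernel in the old and the new style
from tinygrad.uop.ops import UOp, Ops
from tinygrad.dtype import dtypes

buf_1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1)
buf_2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 2)
idx = UOp.const(dtypes.index, 0)

# old: explicit loads at graph-construction time
# ld_1 = UOp(Ops.LOAD, dtypes.int32, (buf_1.index(idx),))
# ld_2 = UOp(Ops.LOAD, dtypes.int32, (buf_2.index(idx),))
# alu = ld_1 + ld_2

# new: INDEX carries the value dtype, loads are added at the end of codegen
alu = buf_1.index(idx) + buf_2.index(idx)

output_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
st_0 = UOp(Ops.STORE, dtypes.void, (output_buf.index(idx), alu))
s = UOp(Ops.SINK, dtypes.void, (st_0,))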


@@ -53,9 +53,7 @@ b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struc
idx = UOp.const(dtypes.index, 0)
buf_1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1)
buf_2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 2)
ld_1 = UOp(Ops.LOAD, dtypes.int32, (buf_1.index(idx),))
ld_2 = UOp(Ops.LOAD, dtypes.int32, (buf_2.index(idx),))
alu = ld_1 + ld_2
alu = buf_1.index(idx) + buf_2.index(idx)
output_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
st_0 = UOp(Ops.STORE, dtypes.void, (output_buf.index(idx), alu))
s = UOp(Ops.SINK, dtypes.void, (st_0,))


@@ -194,6 +194,7 @@ class TestDTypeALU(unittest.TestCase):
strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
ht.int32, strat.sampled_from(binary_operations), strat.sampled_from(integer_binary_operations))
@unittest.skipIf(Device.DEFAULT == "PYTHON", "TODO: fix cast inf to int32 in PYTHON")
@unittest.skip("broken on Mac")
def test_float_midcast_int32(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.float32, dtypes.int32)
@unittest.skip("broken. TODO: fix it")


@@ -16,12 +16,12 @@ class TestLinearizerFailure(unittest.TestCase):
c2 = UOp.range(UOp.const(dtypes.index, 784), 1, AxisType.GLOBAL)
c3 = UOp.range(UOp.const(dtypes.index, 10), 3, AxisType.GLOBAL)
c4 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(512), arg=1, src=())
c5 = c4.index(c1.valid(UOp.const(dtypes.bool, True))).load()
c5 = c4.index(c1.valid(UOp.const(dtypes.bool, True)))
c6 = UOp.range(UOp.const(dtypes.index, 6000), 1004, AxisType.REDUCE)
c7 = UOp.range(UOp.const(dtypes.index, 3750), 2006, AxisType.REDUCE)
c8 = UOp.range(UOp.const(dtypes.index, 16), 2007, AxisType.GROUP_REDUCE)
c9 = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(47040000), arg=2, src=())
c10 = c9.index((((c3*UOp.const(dtypes.index, 4704000))+c2)+(c6*UOp.const(dtypes.index, 784))).valid(UOp.const(dtypes.bool, True))).load()
c10 = c9.index((((c3*UOp.const(dtypes.index, 4704000))+c2)+(c6*UOp.const(dtypes.index, 784))).valid(UOp.const(dtypes.bool, True)))
c11 = c5.alu(Ops.CMPNE, ((((c3*UOp.const(dtypes.index, 6000))+c6)+((c7*UOp.const(dtypes.index, 16))+c8)).alu(Ops.CMPLT, UOp.const(dtypes.index, 59999)).where(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)).reduce(c7, c8, arg=Ops.ADD)+UOp.const(dtypes.int, -1))).where(UOp.const(dtypes.uchar, 0), c10).reduce(c6, arg=Ops.ADD)
c12 = c0.index((((c1*UOp.const(dtypes.index, 7840))+(c2*UOp.const(dtypes.index, 10)))+c3).valid(UOp.const(dtypes.bool, True))).store(c11).end(c1, c2, c3)
ast = c12.sink(arg=KernelInfo(name='test', axis_types=(), dont_use_locals=False, applied_opts=(Opt(op=OptOps.GROUP, axis=1, arg=16),), opts_to_apply=None))


@@ -12,9 +12,9 @@ class TestLinearizerFailures(unittest.TestCase):
c3 = ((c1*UOp.const(dtypes.index, 32))+c2)
c4 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(163840), arg=1, src=())
c5 = UOp.range(UOp.const(dtypes.index, 2560), 0, AxisType.REDUCE)
c6 = c4.index(((((((c5//UOp.const(dtypes.index, 8))%UOp.const(dtypes.index, 8))*UOp.const(dtypes.index, 8))+(c5%UOp.const(dtypes.index, 8)))+(((c2*UOp.const(dtypes.index, 40))+(c5//UOp.const(dtypes.index, 64)))*UOp.const(dtypes.index, 64)))+(c1*UOp.const(dtypes.index, 81920)))).load()
c6 = c4.index(((((((c5//UOp.const(dtypes.index, 8))%UOp.const(dtypes.index, 8))*UOp.const(dtypes.index, 8))+(c5%UOp.const(dtypes.index, 8)))+(((c2*UOp.const(dtypes.index, 40))+(c5//UOp.const(dtypes.index, 64)))*UOp.const(dtypes.index, 64)))+(c1*UOp.const(dtypes.index, 81920))))
c7 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=2, src=())
c8 = c7.index(c3).load()
c8 = c7.index(c3)
c9 = ((((c6+(c8*UOp.const(dtypes.float, -1.0)))*(c6+(c8*UOp.const(dtypes.float, -1.0)))).reduce(c5, arg=Ops.ADD)*UOp.const(dtypes.float, 0.000390625))+UOp.const(dtypes.float, 1e-05)).sqrt().reciprocal()
c10 = c0.index(c3).store(c9).end(c1, c2)
ast = c10.sink()


@@ -199,6 +199,7 @@ class TestProfiler(unittest.TestCase):
#self.assertLess(e1.st, e2.st)
#self.assertGreater(e1.en-e1.st, e2.en-e2.st)
@unittest.skipIf(not CI, "this test is flaky locally")
@unittest.skipUnless(Device[Device.DEFAULT].graph is not None, "graph support required")
def test_graph(self):
from test.test_graph import helper_alloc_rawbuffer, helper_exec_op, helper_test_graphs


@@ -34,7 +34,7 @@ def _setup_and_test_alu(alu_op:Ops, input_val:ConstType, *alu_src_uops:UOp):
a = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 0)
b = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 1)
idx = UOp.const(dtypes.int, 0)
ld = UOp(Ops.LOAD, dtype, (b.index(idx),))
ld = b.index(idx)
alu = ld.alu(alu_op, *alu_src_uops)
store = UOp.store(a.index(idx), alu)
sink = UOp(Ops.SINK, dtypes.void, (store,))


@@ -2163,8 +2163,8 @@ class TestCopyFolding(unittest.TestCase):
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
def test_permute_on_disk_contiguous(self):
with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer())
a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
with open(temp('dt_arange_4_permute_contig'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer())
a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute_contig')}")
b = a.reshape(2, 2).permute(1, 0).contiguous().to("CPU")
b.realize()
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])


@@ -376,7 +376,7 @@ class TestUOpGraph(unittest.TestCase):
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0)
d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
idx = UOp.const(dtypes.int, 0)
ld = UOp(Ops.LOAD, dtypes.int, (d1.index(idx),))
ld = d1.index(idx)
alu = (ld<1).cast(dtypes.bool)
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu))
uops = to_uops_list([out])
@@ -386,7 +386,7 @@ class TestUOpGraph(unittest.TestCase):
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
idx = UOp.const(dtypes.int, 0)
ld = UOp(Ops.LOAD, dtypes.int, (d1.index(idx),))
ld = d1.index(idx)
alu = ld.cast(dtypes.float).cast(dtypes.float)
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu))
uops = to_uops_list([out])
@@ -408,7 +408,7 @@ class TestUOpGraph(unittest.TestCase):
def test_bitcast_to_same_dtype_fold(self):
for dt in dtypes.ints + dtypes.floats + (dtypes.bool,):
d0 = UOp(Ops.DEFINE_GLOBAL, dt.ptr(), arg=0)
v = UOp(Ops.LOAD, dt, (d0.index(UOp.const(dtypes.int, 0)),))
v = d0.index(UOp.const(dtypes.int, 0))
uops = to_uops_list([v.bitcast(dt)])
self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST]), 0, f"dtype = {dt}")
@@ -420,7 +420,7 @@ class TestUOpGraph(unittest.TestCase):
def test_where_on_gated_load_fold(self):
ridx0 = UOp.range(100, 0)
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0)
ld = d0.index(ridx0.valid(ridx0<50)).load()
ld = d0.index(ridx0.valid(ridx0<50))
w = (ridx0<50).where(ld, 5)
uops = to_uops_list([w])
for u in uops:
@@ -430,7 +430,7 @@ class TestUOpGraph(unittest.TestCase):
def test_where_on_gated_load_folds_swapped_branches(self):
ridx0 = UOp.range(100, 0)
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0)
ld = d0.index(ridx0.valid((ridx0<50).logical_not())).load()
ld = d0.index(ridx0.valid((ridx0<50).logical_not()))
w = (ridx0<50).where(5, ld)
uops = to_uops_list([w])
for u in uops:
@@ -441,7 +441,7 @@ class TestUOpGraph(unittest.TestCase):
ridx0 = UOp.range(100, 0)
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
gate_idx = ridx0.valid((ridx0<50))
ld = d0.index(gate_idx).load().cast(dtypes.float)
ld = d0.index(gate_idx).cast(dtypes.float)
w = (ridx0<50).where(ld, 5.0)
uops = to_uops_list([w])
for u in uops:
@@ -467,11 +467,11 @@ class TestUOpGraph(unittest.TestCase):
c1 = UOp.range(UOp.const(dtypes.index, 512), 1, AxisType.LOOP)
c2 = UOp.range(UOp.const(dtypes.index, 250), 2, AxisType.LOOP)
c3 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(512), arg=1, src=())
c4 = c3.index(c1).load()
c4 = c3.index(c1)
c5 = UOp.range(UOp.const(dtypes.index, 240), 0, AxisType.REDUCE)
c6 = ((c2*UOp.const(dtypes.index, 240))+c5)
c7 = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(60000), arg=2, src=())
c8 = c7.index(c6).load()
c8 = c7.index(c6)
c9 = ((c4<0).where((c4+60000), c4)!=c6.cast(dtypes.int)).where(0, c8.cast(dtypes.uint).cast(dtypes.uchar)).reduce(c5, arg=Ops.ADD)
c10 = c0.index(((c1*UOp.const(dtypes.index, 250))+c2)).store(c9).end(c1, c2)
uops = to_uops_list([c10])
@@ -481,25 +481,25 @@ class TestUOpGraph(unittest.TestCase):
def test_in_out_of_bounds_access(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 0)),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 0), ptr=True),))
to_uops_list([ld0])
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 15)),))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 15), ptr=True),))
to_uops_list([ld1])
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 7)),))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 7), ptr=True),))
to_uops_list([ld1])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 42)),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 42), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access_symbolic(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 1, 10)),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 1, 10), ptr=True),))
to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 15)),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 15), ptr=True),))
to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 20)),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 20), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access_gated_store(self):
@@ -531,7 +531,7 @@ class TestUOpGraph(unittest.TestCase):
if_barrier = UOp(Ops.IF, dtypes.void, (gate, barrier))
# Load from local memory (after the IF/barrier)
local_load = UOp(Ops.LOAD, dtypes.uint, (sbuf.index(lidx), if_barrier))
local_load = UOp(Ops.LOAD, dtypes.uint, (sbuf.index(lidx, ptr=True), if_barrier))
# Store to global memory
global_store = UOp(Ops.STORE, dtypes.void, (gbuf.index(gidx), local_load))
@@ -542,18 +542,18 @@ class TestUOpGraph(unittest.TestCase):
ridx = UOp.range(20, 0)
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid((0<=i)&(i<16))),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid((0<=i)&(i<16)), ptr=True),))
to_uops_list([ld0])
glblfloat = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(20), (), 0)
ldfloat = UOp(Ops.LOAD, dtypes.float, (glblfloat.index(ridx),))
i = (ldfloat+3.14).cast(dtypes.int)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16))),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16)), ptr=True),))
def test_load_cast_to_bool(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
ridx = UOp.range(20, 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx.valid(ridx.cast(dtypes.bool).logical_not())),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx.valid(ridx.cast(dtypes.bool).logical_not()), ptr=True),))
to_uops_list([ld0])
@unittest.skip("Bool load is not supported yet")
@@ -562,36 +562,36 @@ class TestUOpGraph(unittest.TestCase):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
ridx = UOp.range(20, 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask),)))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask), ptr=True)))
to_uops_list([ld0])
def test_out_of_bounds_off_by_one_access(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 16)),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 16), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_bounds_access_with_mask(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid((5<gidx0)&(gidx0<16))),))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<16)),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid((5<gidx0)&(gidx0<16)), ptr=True),))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<16), ptr=True),))
to_uops_list([ld0, ld1])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<17)),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<17), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access_symbolic_mask(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
i = Variable("i", 1, 80)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<10)),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<10), ptr=True),))
to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<15)),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<15), ptr=True),))
to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<20)),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<20), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access_index_load(self):
@@ -599,11 +599,11 @@ class TestUOpGraph(unittest.TestCase):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<8)),)).cast(dtypes.index)
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<32))),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<8), ptr=True),)).cast(dtypes.index)
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<32)), ptr=True),))
to_uops_list([ld1])
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<64))),))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<64)), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld1])
def test_bounds_with_loaded_bool(self):
@@ -611,8 +611,8 @@ class TestUOpGraph(unittest.TestCase):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(8), (), 0)
gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 16),), "gidx0")
ld0 = glbl0.index(gidx0).load()
ld1 = glbl1.index(gidx0.valid(ld0)).load()
ld0 = glbl0.index(gidx0, ptr=True).load()
ld1 = glbl1.index(gidx0.valid(ld0), ptr=True).load()
with self.assertRaises(RuntimeError): to_uops_list([ld1])
def test_fold_gated_load(self):
@@ -620,38 +620,38 @@ class TestUOpGraph(unittest.TestCase):
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
glbl2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
idx = UOp.const(dtypes.int, 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(UOp.invalid()),))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2.index(idx.valid(UOp.const(dtypes.bool, True))),))
ld0 = glbl1.index(UOp.invalid())
ld1 = glbl2.index(idx.valid(UOp.const(dtypes.bool, True)))
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
ld0 = uops[-1].src[-1]
# the gate and invalid value are deleted from ld1
self.assertEqual(ld0, UOp.load(glbl2.index(idx), dtype=dtypes.int))
self.assertEqual(ld0, UOp.load(glbl2.index(idx, ptr=True), dtype=dtypes.int))
def test_fold_gated_load_local(self):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=18, addrspace=AddrSpace.LOCAL), (), "temp")
lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0")
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int)))
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx, ptr=True), glbl0.index(lidx, ptr=True).load()))
barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
ld0 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(UOp.invalid()),))
ld1 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index((lidx+2).valid(UOp.const(dtypes.bool, True))),))
ld0 = smem.after(barrier).index(UOp.invalid())
ld1 = smem.after(barrier).index((lidx+2).valid(UOp.const(dtypes.bool, True)))
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))])
ld0 = uops[-1].src[-1]
# the gate and invalid value are deleted from ld1
self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2))
self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2, ptr=True))
def test_fold_gated_store(self):
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
idx0 = UOp.const(dtypes.int, 0)
idx1 = UOp.const(dtypes.int, 0)
val = UOp.const(dtypes.int, 42)
st0 = glbl.index(UOp.invalid()).store(val)
st1 = glbl.index(idx0.valid(UOp.const(dtypes.bool, True))).store(val)
st0 = glbl.index(UOp.invalid(), ptr=True).store(val)
st1 = glbl.index(idx0.valid(UOp.const(dtypes.bool, True)), ptr=True).store(val)
uops = to_uops_list([st0, st1])
# only the second store happens
self.assertEqual(len(uops), 5)
self.assertEqual(uops[-1], glbl.index(idx1).store(val))
self.assertEqual(uops[-1], glbl.index(idx1, ptr=True).store(val))
@unittest.skip("this is a uop type error")
def test_asserts_bad_gate(self):


@@ -39,9 +39,9 @@ def _test_single_value(vals, op, dts):
output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1]
buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
buf_loads = [uop(uops, Ops.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)]
loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0))]) for i, dtype in enumerate(dts))
loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True)]) for i, dtype in enumerate(dts))
alu = uop(uops, op, output_dtype, loads)
out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), alu))
out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True), alu))
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
buf2 = [Buffer(Device.DEFAULT, 1, dtype).allocate().copyin(np.array([a], dtype=_to_np_dtype(dtype)).data) for a,dtype in zip(vals, dts)]
prg = _uops_to_prg([out])
@@ -338,7 +338,7 @@ class TestLocalAccess(unittest.TestCase):
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.float32.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.float32, (), 42.0)))
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0)),))
sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True),))
self.assertEqual(_test_uops_result(dtypes.float32, uops, sres), 42)
# NOTE: webgpu specific, since only webgpu performs bitpacking
@@ -348,7 +348,7 @@ class TestLocalAccess(unittest.TestCase):
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.uint8, (), 42)))
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
sres = uop(uops, Ops.LOAD, dtypes.uint8, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0)),))
sres = smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0))
self.assertEqual(_test_uops_result(dtypes.uint8, uops, sres), 42)
# NOTE: webgpu specific, since only webgpu performs bitpacking
@@ -382,7 +382,7 @@ class TestAssembly(unittest.TestCase):
g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
c1 = UOp(Ops.CONST, dtypes.int, (), 2)
c2 = UOp(Ops.CONST, dtypes.int, (), 3)
l1 = UOp(Ops.LOAD, dtypes.int, (g1.index(c1),))
l1 = g1.index(c1)
a1 = UOp(Ops.MUL, dtypes.int, (l1, c1))
a2 = UOp(Ops.MUL, dtypes.int, (l1, c2))
uops = to_uops_list([a1,a2], ren=Device[Device.DEFAULT].renderer)
@@ -395,7 +395,7 @@ class TestAssembly(unittest.TestCase):
for dt in (dtypes.int32, dtypes.uint32):
g = UOp(Ops.DEFINE_GLOBAL, dt.ptr(), (), 0)
c = UOp(Ops.CONST, dt, (), 2)
l = UOp(Ops.LOAD, dt, (g.index(c),))
l = g.index(c)
a = UOp(Ops.IDIV, dt, (l, c))
uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render(uops)
@@ -406,7 +406,7 @@ class TestAssembly(unittest.TestCase):
def test_fast_idiv_and_mod(self):
g = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(), (), 0)
c = UOp(Ops.CONST, dtypes.uint, (), 3)
l = UOp(Ops.LOAD, dtypes.uint, (g.index(c),))
l = g.index(c)
a = UOp(Ops.IDIV, dtypes.uint, (l, c))
uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render(uops)
@@ -458,8 +458,7 @@ class TestAssembly(unittest.TestCase):
def test_use_cmpeq(self):
g = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(), (), 0)
c = UOp(Ops.CONST, dtypes.uint, (), 7)
l = UOp(Ops.LOAD, dtypes.uint, (g.index(c),))
comp = l.ne(c).ne(True)
comp = g.index(c).ne(c).ne(True)
uops = to_uops_list([comp], ren=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render(uops)
ops = [x.op for x in uops]


@@ -141,8 +141,8 @@ class TestUOpsStats(unittest.TestCase):
globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2)
u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),))
u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),))
u1 = globl.index(o1)
u2 = globl.index(o2)
u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
u4 = UOp(Ops.MUL, dtypes.int, (u1,u2))
u5 = UOp(Ops.ADD, dtypes.int, (u4,u3))
@@ -151,8 +151,8 @@ class TestUOpsStats(unittest.TestCase):
globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2)
u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),))
u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),))
u1 = globl.index(o1)
u2 = globl.index(o2)
u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
u4 = UOp(Ops.MULACC, dtypes.int, (u1,u2,u3))
uops_fma = full_rewrite(u4.sink())


@@ -9,13 +9,13 @@ from test.unit.test_uop_symbolic import check_uop_against_string
def get_gated_load_uop(valid:UOp, idx:UOp):
return UOp(Ops.LOAD, dtypes.float, (
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0).index(idx.valid(valid)),
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0).index(idx.valid(valid), ptr=True),
UOp.const(dtypes.float, 0.0)
))
def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
return UOp(Ops.LOAD, dtypes.float.vec(4), (
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0).index(UOp(Ops.VECTORIZE, dtypes.index.vec(2), idx).valid(valid)),
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0).index(UOp(Ops.VECTORIZE, dtypes.index.vec(2), idx).valid(valid), ptr=True),
UOp(Ops.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
))


@@ -11,7 +11,7 @@ class TestTranscendentalFunctions(unittest.TestCase):
# TODO: Test constant input when constant folding is fixed (or maybe test both variants)
# Load input value from a buffer to prevent constant folding
input_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.double.ptr(), arg=1, src=())
loaded_value = UOp.load(input_buf.index(UOp.const(dtypes.int, 0)), dtype=dtypes.double)
loaded_value = input_buf.index(UOp.const(dtypes.int, 0))
def eval_payne_hanek_reduction(v:float) -> tuple[float, int]:
return tuple(eval_uop(u, [(dtypes.float64, [v])]) for u in payne_hanek_reduction(loaded_value))


@@ -12,7 +12,7 @@ from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, p
from tinygrad.uop.decompositions import get_late_rewrite_patterns
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
ReduceContext, correct_load_store, pm_render
ReduceContext, correct_load_store, pm_render, pm_add_loads
from tinygrad.codegen.opt.postrange import apply_opts
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse, pm_split_store
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen
@@ -59,6 +59,11 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
# add gpu dims (late). this works after devectorize, but it's faster here
sink = graph_rewrite(sink, pm_add_gpudims, ctx=ren, name="add gpudims")
# **** optimizations are done, now we lower to actual code ****
# add loads
sink = graph_rewrite(sink, pm_add_loads, name="** add loads (code)")
# devectorize (TODO: does this need opts?)
if DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing
elif DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing


@@ -80,7 +80,7 @@ def add_gpudims(ctx:Renderer, s:UOp):
subs = {}
for r in s_topo:
# look for local INDEXes that are not used in the GLOBAL store, then add them as an INVALID
if r.op is Ops.STORE and r.src[0].ptrdtype.addrspace == AddrSpace.GLOBAL:
if r.op is Ops.STORE and r.src[0].src[0].ptrdtype.addrspace == AddrSpace.GLOBAL:
idx = r.src[0]
missing_locals = [all_ranges[rng] for rng in local_dims if all_ranges[rng] not in idx.ranges]
if len(missing_locals):


@@ -2,7 +2,7 @@ from typing import Any, cast
import functools, operator, itertools
from collections import defaultdict
from dataclasses import dataclass
from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid
from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid, PtrDType
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element
from tinygrad.uop.symbolic import uop_given_valid, parse_valid, sym, symbolic_flat, invalid_gate
from tinygrad.helpers import getenv, flatten, AMX, prod
@@ -12,7 +12,7 @@ from tinygrad.renderer import Renderer
def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
idx = uop_given_valid(valid, start_idx)
if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx.valid(valid))
if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx.valid(valid), ptr=True)
# wait for it to be image indexed before running simplification
if start_idx.dtype.count != 2: return None
@@ -43,7 +43,7 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
if not drop_stmt and idx is start_idx: return None
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in valid.split_uop(Ops.AND) if s not in drop_stmt]) else None
return buf.index(idx.valid(new_valid) if new_valid is not None else idx)
return buf.index(idx.valid(new_valid) if new_valid is not None else idx, ptr=True)
load_store_indexing = PatternMatcher([
@@ -52,7 +52,7 @@ load_store_indexing = PatternMatcher([
# simplify away long after index has been lowered
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x", dtypes.long), UPat.var("c", dtypes.bool))), lambda buf,x,c: simplify_valid_load(buf, x, c)),
# drop true gate
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x, ptr=True)),
])
# ***** load/store grouping *****
@@ -60,7 +60,7 @@ load_store_indexing = PatternMatcher([
def expand_index(buf:UOp, vec:UOp):
if getenv("UNSAFE_DISABLE_MASK", 0): vec = vec.get_idx()
# generate the individual indexes
midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i)) for i in range(vec.dtype.count)]),
midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), ptr=True) for i in range(vec.dtype.count)]),
symbolic_flat+load_store_indexing, name=f"index_buf_{buf.arg}")
# extract all the relevant offsets
offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
@@ -163,7 +163,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
# with 1 at the end of the lengths list, this will always hit
for fold_length in lengths:
if global_offset+fold_length > sz: continue
lidx = buf.index((offset + global_offset).valid(mask))
lidx = buf.index((offset + global_offset).valid(mask), ptr=True)
if fold_length > 1: lidx = lidx.cast(buf.ptrdtype.base.vec(fold_length).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace))
if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length)))
@@ -229,7 +229,7 @@ def no_vectorized_buf(buf:UOp):
def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
cnt = cast.dtype.count
assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}"
return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.index.vec(cnt), tuple(range(cnt))))
return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.index.vec(cnt), tuple(range(cnt))), ptr=True)
def no_vectorized_index_broadcast(buf:UOp, cast:UOp, bcast:UOp, idx:UOp):
cnt = cast.dtype.count
@@ -237,7 +237,7 @@ def no_vectorized_index_broadcast(buf:UOp, cast:UOp, bcast:UOp, idx:UOp):
input_gep = bcast.arg if bcast.op is Ops.GEP else ([0]*precnt)
gep_arg = tuple(flatten([range(precnt) for _ in range(cnt)]))
sum_arg = tuple(flatten([[i+y for y in input_gep] for i in range(cnt)]))
return buf.broadcast(cnt*precnt).index(idx.gep(gep_arg)*cnt+UOp.const(dtypes.index.vec(cnt*precnt), sum_arg))
return buf.broadcast(cnt*precnt).index(idx.gep(gep_arg)*cnt+UOp.const(dtypes.index.vec(cnt*precnt), sum_arg), ptr=True)
devectorize_buf_and_index = PatternMatcher([
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
@@ -302,11 +302,11 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,))
acc_init = acc.after(*input_ranges).index(UOp.const(dtypes.int, 0)).store(identity) if len(input_ranges) else \
acc.index(UOp.const(dtypes.int, 0)).store(identity)
lst = [acc.after(acc_init, *reduce_range).index(UOp.const(dtypes.int, 0)).load()] + lst # put acc as the first element
lst = [acc.after(acc_init, *reduce_range).index(UOp.const(dtypes.int, 0))] + lst # put acc as the first element
ctx.acc_num += 1
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
if len(reduce_range) == 0: return ret
return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range)).index(UOp.const(dtypes.int, 0)).load()
return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range)).index(UOp.const(dtypes.int, 0))
pm_reduce = PatternMatcher([
# REDUCE -> DEFINE_ACC+ASSIGN
@@ -315,3 +315,14 @@ pm_reduce = PatternMatcher([
(UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
])+sym
# add loads
pm_add_loads = PatternMatcher([
# add loads to non ptr index
(UPat(Ops.INDEX, name="idx"), lambda idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else
idx.replace(dtype=idx.src[0].dtype).load(dtype=idx.dtype.base)),
# remove loads from stores
(UPat(Ops.STORE, src=(UPat(Ops.LOAD),), allow_any_len=True, name="s"), lambda s: s.replace(src=(s.src[0].src[0],)+s.src[1:])),
])
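A hedged usage sketch of pm_add_loads in isolation (the full_rewrite hunk above applies it via graph_rewrite right after add gpudims): the first rule gives a non-pointer INDEX its buffer's pointer dtype and wraps it in a LOAD of the value dtype, and the second rule strips that LOAD again when the INDEX is the address operand of a STORE. Illustrative only; the import path follows the import hunk above.

# sketch: running the add-loads pass on a tiny store graph
from tinygrad.uop.ops import UOp, Ops, graph_rewrite
from tinygrad.dtype import dtypes
from tinygrad.codegen.late.devectorizer import pm_add_loads

src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1)
out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
idx = UOp.const(dtypes.index, 0)
st = UOp(Ops.STORE, dtypes.void, (out_buf.index(idx), src_buf.index(idx)))
sink = graph_rewrite(st.sink(), pm_add_loads, name="add loads")
# after the rewrite: src_buf's INDEX is pointer-typed and wrapped in a LOAD,
# while out_buf's INDEX stays a bare pointer-typed address for the STORE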


@@ -64,8 +64,8 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
if k.ren.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
k.reduceop is not None and k.reduceop.arg[0] is Ops.ADD and len(k.full_shape) >= 2 and k.ren.has_shared and \
(mulop:=k.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
idx0, idx1 = mulop.src[0].src[0].src[1].get_idx(), mulop.src[1].src[0].src[1].get_idx()
(mulop:=k.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.INDEX and mulop.src[1].op is Ops.INDEX:
idx0, idx1 = mulop.src[0].src[1].get_idx(), mulop.src[1].src[1].get_idx()
if k.ranges_of(AxisType.REDUCE):
first_reduce_rng = k.ranges_of(AxisType.REDUCE)[0]
if any(u is first_reduce_rng for u in idx0.split_uop(Ops.ADD)) and all(r in idx1.ranges for r in idx0.ranges):


@@ -55,7 +55,7 @@ def do_substitute(ctx, x: UOp):
return ret
def dont_sub_ranges_for_image(ctx, x:UOp):
if isinstance(x.src[0].dtype, ImageDType):
if isinstance(x.src[0].src[0].dtype, ImageDType):
for s in x.src[0].ranges: ctx[s] = None
pm_split_ranges = PatternMatcher([
@@ -129,7 +129,7 @@ def reduce_load_collapse(red:UOp): return reduce_collapse(red, pm=pm_reduce_load
# remove REDUCE without loads (generic arange opt / indexing). TODO: support multi range
pm_reduce_simplify = pm_reduce_unparented + PatternMatcher([(UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_collapse),])
# remove REDUCE on load, comes from indexing a tensor with another tensor
def no_load(u:UOp) -> bool: return not any(x.op is Ops.LOAD for x in u.backward_slice_with_self)
def no_load(u:UOp) -> bool: return not any(x.op is Ops.INDEX for x in u.backward_slice_with_self)
pm_load_collapse = PatternMatcher([
(UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_load_collapse),
# we want to make sure we dont do math on a loaded index since that can cause overflow, this undoes the rule in pm_reduce_load_collapse


@@ -435,9 +435,9 @@ rangeify_codegen = PatternMatcher([
# add loads to non ptr indexes
# TODO: this can be moved into codegen?
(UPat.any(UPat(Ops.DEFINE_GLOBAL, name="dg"), UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True, name="dg"))
.f(Ops.INDEX, name="idx", allow_any_len=True),
lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()),
#(UPat.any(UPat(Ops.DEFINE_GLOBAL, name="dg"), UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True, name="dg"))
# .f(Ops.INDEX, name="idx", allow_any_len=True),
# lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()),
# fix broadcast dtype
(UPat(Ops.AFTER, name="a").broadcast(name="b"), lambda a,b: a.broadcast(len(b.src))),


@@ -339,8 +339,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
def index(self, *srcs:UOp|None, **kwargs):
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def __getitem__(self, *idx): return self.index(*idx)
def const_like(self, b:ConstLike):
# constants can optionally have a DEVICE source
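A minimal sketch of the dtype behaviour this signature change introduces (the buffer here is made up for illustration): without ptr=True the INDEX takes the buffer's base dtype and can feed ALU ops directly, while ptr=True keeps the pointer dtype that the explicit LOAD/STORE constructions and OOB checks in the tests above rely on.

# sketch of the dtype difference introduced by ptr=
from tinygrad.uop.ops import UOp, Ops
from tinygrad.dtype import dtypes

g = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(16), (), 0)
i = UOp.const(dtypes.index, 0)
assert g.index(i).dtype == dtypes.float        # value-typed, usable directly in ALU graphs
assert g.index(i, ptr=True).dtype == g.dtype   # pointer-typed, same dtype as the buffer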
@@ -743,6 +743,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
ret = graph_rewrite(self.simplify() if simplify else self, renderer if pm is None else pm)
return ret.arg if ret.op is Ops.NOOP else str(ret)
def pyrender(self): return pyrender(self)
@dataclass(frozen=True)
class KernelInfo:
name: str = "test" # name of the kernel
@@ -1199,10 +1201,11 @@ pm_lower_index_dtype = PatternMatcher([
(UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.index), UPat.cvar("val").cast(dtypes.index))), lambda var,val: var.bind(val).cast(dtypes.index)),
(UPat(Ops.CAST, src=(UPat(name="x").cast(dtypes.index),), name="c"), lambda x,c: x.cast(c.dtype)),
# lower Invalid
(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond)),
(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond, ptr=True)),
# remove hanging casts
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))),
lambda buf,idx,valid: buf.index(idx, valid, ptr=True)),
(UPat((Ops.STORE, Ops.LOAD), src=(UPat(), UPat(), UPat().cast(dtypes.index)), allow_any_len=True, name="s"),
lambda s: s.replace(src=s.src[:2]+tuple(u.src[0] for u in s.src[2:]))),
(UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"),
@@ -1279,8 +1282,9 @@ pm_pyrender_extra = PatternMatcher([
"UOp.range("+', '.join([str(c.arg)] + [str(y) for y in x.arg])+
(f', src={srcs(ctx, x.src[1:])}' if len(x.src) > 1 else '')+(', dtype='+str(x.dtype) if x.dtype is not dtypes.index else '')+")"),
# TODO: index shouldn't mismatch dtype
(UPat(Ops.INDEX, src=(UPat(), UPat()), name="x"), lambda ctx,x:
f"{ctx[x.src[0]]}.index({ctx[x.src[1]]}, dtype={x.dtype})" if x.src[0].dtype != x.dtype else None),
(UPat(Ops.INDEX, src=(UPat(), UPat()), allow_any_len=True, name="x"), lambda ctx,x:
f"{ctx[x.src[0]]}.index({ctx[x.src[1]]}, "+(f"{ctx[x.src[2]]}, " if len(x.src) > 2 else "")+
(f"dtype={x.dtype})" if x.src[0].dtype != x.dtype else "ptr=True)") if x.src[0].dtype.base != x.dtype else None),
# TODO: fix forced_reshape
(UPat(Ops.RESHAPE, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.forced_reshape({render_marg(ctx,x)})" if x.src[0].shape == x.shape else None),
(UPat(GroupOp.Movement, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({render_marg(ctx,x)})"),


@@ -133,6 +133,7 @@ shared_codegen_spec = PatternMatcher([
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
# LOAD(idx) / STORE(idx, val) / LOAD with alt value only exists in program_spec
# TODO: move LOAD to the program_spec
(UPat().index(UPat()).or_casted().load(), lambda: True),
(UPat(Ops.INDEX).or_casted().store(UPat()), lambda: True),


@@ -413,7 +413,7 @@ def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:
bounds[expr][int(is_upper)] = c
# don't simplify any other gates, can lead to OOB, we substitute them back later
uop = uop.substitute((load_subs:={u: UOp(Ops.NOOP, arg=u) for u in uop.toposort() if u.op is Ops.INDEX}))
uop = uop.substitute((load_subs:={u: UOp(Ops.NOOP, dtype=u.dtype, arg=u) for u in uop.toposort() if u.op is Ops.INDEX}))
# simplify uop given that valid is True
all_candidates = []
@@ -479,21 +479,20 @@ def drop_and_clauses(cond:UOp, x:UOp, i:UOp) -> UOp|None:
return UOp.const(dtypes.bool, True).prod(*[c for c in cond.split_uop(Ops.AND) if c not in dropped_clauses]).where(x, i)
pm_drop_and_clauses = PatternMatcher([(UPat.var("cond").where(UPat.var("x", dtype=dtypes.index), invalid_pat), drop_and_clauses)])
def where_on_load(l, c1, buf, x):
def where_on_load(c1, buf, x):
c2 = x.get_valid()
duplicate_clauses = [c for c in c1.split_uop(Ops.AND) if c in c2.split_uop(Ops.AND)]
# we move the condition from the where to the load _as long as_ the condtition doesn't have some range that would place it inside of a new range
# also no data dependent loads!
moved_clauses = [c for c in c1.split_uop(Ops.AND) if c not in duplicate_clauses and all(r in x.ranges for r in c.ranges)
and all(u in x.backward_slice_with_self for u in c.backward_slice_with_self if u.op is Ops.LOAD)]
and all(u in x.backward_slice_with_self for u in c.backward_slice_with_self if u.op is Ops.INDEX)]
if not (removed:=moved_clauses+duplicate_clauses): return None
# aditionally we can drop the clause on the where if it already exists in the load
remaining_clause = UOp.const(dtypes.bool, True).prod(*[c for c in c1.split_uop(Ops.AND) if c not in removed])
return remaining_clause.where(UOp.load(buf.index(x.get_idx().valid(functools.reduce(operator.and_, moved_clauses, c2)), *l.src[1:])), 0)
return remaining_clause.where(buf.index(x.get_idx().valid(functools.reduce(operator.and_, moved_clauses, c2))), 0)
pm_move_where_on_load = PatternMatcher([
(UPat.var("c1").where(UPat(Ops.LOAD, src=(UPat.var("buf").index(UPat.var("x")),), name="l"), 0), where_on_load),
(UPat.var("c1").where(0, UPat(Ops.LOAD, src=(UPat.var("buf").index(UPat.var("x")),), name="l")),
lambda l,c1,buf,x: where_on_load(l,c1.logical_not(),buf,x)),
(UPat.var("c1").where(UPat.var("buf").index(UPat.var("x")), 0), where_on_load),
(UPat.var("c1").where(0, UPat.var("buf").index(UPat.var("x"))), lambda c1,buf,x: where_on_load(c1.logical_not(),buf,x)),
])
pm_simplify_valid = PatternMatcher([