mirror of https://github.com/tinygrad/tinygrad.git
add loads at the end (#12988)
* add loads at the end
* simpler
* late load
* tests passing
* fix matvec
* spec test passes
* fix where on load
* fix abs2
* fix more tests
@@ -53,9 +53,7 @@ b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struc
 idx = UOp.const(dtypes.index, 0)
 buf_1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1)
 buf_2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 2)
-ld_1 = UOp(Ops.LOAD, dtypes.int32, (buf_1.index(idx),))
-ld_2 = UOp(Ops.LOAD, dtypes.int32, (buf_2.index(idx),))
-alu = ld_1 + ld_2
+alu = buf_1.index(idx) + buf_2.index(idx)
 output_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
 st_0 = UOp(Ops.STORE, dtypes.void, (output_buf.index(idx), alu))
 s = UOp(Ops.SINK, dtypes.void, (st_0,))
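The pattern in this first hunk repeats through the whole commit: ALUs now consume non-pointer INDEX ops directly, and the LOADs are materialized by a late rewrite pass. A minimal before/after sketch, assuming the tinygrad-internal UOp API exactly as it appears in these hunks (`alu_old`/`alu_new` are illustrative names):

```python
from tinygrad.uop.ops import UOp, Ops
from tinygrad.dtype import dtypes

idx = UOp.const(dtypes.index, 0)
buf_1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1)
buf_2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 2)

# before: explicit LOADs were constructed around each INDEX up front
# (ptr=True is needed here to get the pointer-typed INDEX a LOAD expects)
ld_1 = UOp(Ops.LOAD, dtypes.int32, (buf_1.index(idx, ptr=True),))
ld_2 = UOp(Ops.LOAD, dtypes.int32, (buf_2.index(idx, ptr=True),))
alu_old = ld_1 + ld_2

# after: the ALU consumes the INDEX ops directly (dtype int32, not a pointer);
# the new pm_add_loads pass wraps them in LOADs at the end of lowering
alu_new = buf_1.index(idx) + buf_2.index(idx)
```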
@@ -194,6 +194,7 @@ class TestDTypeALU(unittest.TestCase):
 strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
 ht.int32, strat.sampled_from(binary_operations), strat.sampled_from(integer_binary_operations))
 @unittest.skipIf(Device.DEFAULT == "PYTHON", "TODO: fix cast inf to int32 in PYTHON")
+@unittest.skip("broken on Mac")
 def test_float_midcast_int32(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.float32, dtypes.int32)

 @unittest.skip("broken. TODO: fix it")
@@ -16,12 +16,12 @@ class TestLinearizerFailure(unittest.TestCase):
 c2 = UOp.range(UOp.const(dtypes.index, 784), 1, AxisType.GLOBAL)
 c3 = UOp.range(UOp.const(dtypes.index, 10), 3, AxisType.GLOBAL)
 c4 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(512), arg=1, src=())
-c5 = c4.index(c1.valid(UOp.const(dtypes.bool, True))).load()
+c5 = c4.index(c1.valid(UOp.const(dtypes.bool, True)))
 c6 = UOp.range(UOp.const(dtypes.index, 6000), 1004, AxisType.REDUCE)
 c7 = UOp.range(UOp.const(dtypes.index, 3750), 2006, AxisType.REDUCE)
 c8 = UOp.range(UOp.const(dtypes.index, 16), 2007, AxisType.GROUP_REDUCE)
 c9 = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(47040000), arg=2, src=())
-c10 = c9.index((((c3*UOp.const(dtypes.index, 4704000))+c2)+(c6*UOp.const(dtypes.index, 784))).valid(UOp.const(dtypes.bool, True))).load()
+c10 = c9.index((((c3*UOp.const(dtypes.index, 4704000))+c2)+(c6*UOp.const(dtypes.index, 784))).valid(UOp.const(dtypes.bool, True)))
 c11 = c5.alu(Ops.CMPNE, ((((c3*UOp.const(dtypes.index, 6000))+c6)+((c7*UOp.const(dtypes.index, 16))+c8)).alu(Ops.CMPLT, UOp.const(dtypes.index, 59999)).where(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)).reduce(c7, c8, arg=Ops.ADD)+UOp.const(dtypes.int, -1))).where(UOp.const(dtypes.uchar, 0), c10).reduce(c6, arg=Ops.ADD)
 c12 = c0.index((((c1*UOp.const(dtypes.index, 7840))+(c2*UOp.const(dtypes.index, 10)))+c3).valid(UOp.const(dtypes.bool, True))).store(c11).end(c1, c2, c3)
 ast = c12.sink(arg=KernelInfo(name='test', axis_types=(), dont_use_locals=False, applied_opts=(Opt(op=OptOps.GROUP, axis=1, arg=16),), opts_to_apply=None))
@@ -12,9 +12,9 @@ class TestLinearizerFailures(unittest.TestCase):
 c3 = ((c1*UOp.const(dtypes.index, 32))+c2)
 c4 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(163840), arg=1, src=())
 c5 = UOp.range(UOp.const(dtypes.index, 2560), 0, AxisType.REDUCE)
-c6 = c4.index(((((((c5//UOp.const(dtypes.index, 8))%UOp.const(dtypes.index, 8))*UOp.const(dtypes.index, 8))+(c5%UOp.const(dtypes.index, 8)))+(((c2*UOp.const(dtypes.index, 40))+(c5//UOp.const(dtypes.index, 64)))*UOp.const(dtypes.index, 64)))+(c1*UOp.const(dtypes.index, 81920)))).load()
+c6 = c4.index(((((((c5//UOp.const(dtypes.index, 8))%UOp.const(dtypes.index, 8))*UOp.const(dtypes.index, 8))+(c5%UOp.const(dtypes.index, 8)))+(((c2*UOp.const(dtypes.index, 40))+(c5//UOp.const(dtypes.index, 64)))*UOp.const(dtypes.index, 64)))+(c1*UOp.const(dtypes.index, 81920))))
 c7 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=2, src=())
-c8 = c7.index(c3).load()
+c8 = c7.index(c3)
 c9 = ((((c6+(c8*UOp.const(dtypes.float, -1.0)))*(c6+(c8*UOp.const(dtypes.float, -1.0)))).reduce(c5, arg=Ops.ADD)*UOp.const(dtypes.float, 0.000390625))+UOp.const(dtypes.float, 1e-05)).sqrt().reciprocal()
 c10 = c0.index(c3).store(c9).end(c1, c2)
 ast = c10.sink()
@@ -199,6 +199,7 @@ class TestProfiler(unittest.TestCase):
 #self.assertLess(e1.st, e2.st)
 #self.assertGreater(e1.en-e1.st, e2.en-e2.st)

+@unittest.skipIf(not CI, "this test is flaky locally")
 @unittest.skipUnless(Device[Device.DEFAULT].graph is not None, "graph support required")
 def test_graph(self):
 from test.test_graph import helper_alloc_rawbuffer, helper_exec_op, helper_test_graphs
@@ -34,7 +34,7 @@ def _setup_and_test_alu(alu_op:Ops, input_val:ConstType, *alu_src_uops:UOp):
 a = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 0)
 b = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 1)
 idx = UOp.const(dtypes.int, 0)
-ld = UOp(Ops.LOAD, dtype, (b.index(idx),))
+ld = b.index(idx)
 alu = ld.alu(alu_op, *alu_src_uops)
 store = UOp.store(a.index(idx), alu)
 sink = UOp(Ops.SINK, dtypes.void, (store,))
@@ -2163,8 +2163,8 @@ class TestCopyFolding(unittest.TestCase):
 self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])

 def test_permute_on_disk_contiguous(self):
-with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer())
-a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
+with open(temp('dt_arange_4_permute_contig'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer())
+a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute_contig')}")
 b = a.reshape(2, 2).permute(1, 0).contiguous().to("CPU")
 b.realize()
 self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
@@ -376,7 +376,7 @@ class TestUOpGraph(unittest.TestCase):
 d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0)
 d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
 idx = UOp.const(dtypes.int, 0)
-ld = UOp(Ops.LOAD, dtypes.int, (d1.index(idx),))
+ld = d1.index(idx)
 alu = (ld<1).cast(dtypes.bool)
 out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu))
 uops = to_uops_list([out])
@@ -386,7 +386,7 @@ class TestUOpGraph(unittest.TestCase):
 d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
 d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
 idx = UOp.const(dtypes.int, 0)
-ld = UOp(Ops.LOAD, dtypes.int, (d1.index(idx),))
+ld = d1.index(idx)
 alu = ld.cast(dtypes.float).cast(dtypes.float)
 out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu))
 uops = to_uops_list([out])
@@ -408,7 +408,7 @@ class TestUOpGraph(unittest.TestCase):
 def test_bitcast_to_same_dtype_fold(self):
 for dt in dtypes.ints + dtypes.floats + (dtypes.bool,):
 d0 = UOp(Ops.DEFINE_GLOBAL, dt.ptr(), arg=0)
-v = UOp(Ops.LOAD, dt, (d0.index(UOp.const(dtypes.int, 0)),))
+v = d0.index(UOp.const(dtypes.int, 0))
 uops = to_uops_list([v.bitcast(dt)])
 self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST]), 0, f"dtype = {dt}")
@@ -420,7 +420,7 @@ class TestUOpGraph(unittest.TestCase):
 def test_where_on_gated_load_fold(self):
 ridx0 = UOp.range(100, 0)
 d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0)
-ld = d0.index(ridx0.valid(ridx0<50)).load()
+ld = d0.index(ridx0.valid(ridx0<50))
 w = (ridx0<50).where(ld, 5)
 uops = to_uops_list([w])
 for u in uops:
@@ -430,7 +430,7 @@ class TestUOpGraph(unittest.TestCase):
 def test_where_on_gated_load_folds_swapped_branches(self):
 ridx0 = UOp.range(100, 0)
 d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0)
-ld = d0.index(ridx0.valid((ridx0<50).logical_not())).load()
+ld = d0.index(ridx0.valid((ridx0<50).logical_not()))
 w = (ridx0<50).where(5, ld)
 uops = to_uops_list([w])
 for u in uops:
@@ -441,7 +441,7 @@ class TestUOpGraph(unittest.TestCase):
 ridx0 = UOp.range(100, 0)
 d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
 gate_idx = ridx0.valid((ridx0<50))
-ld = d0.index(gate_idx).load().cast(dtypes.float)
+ld = d0.index(gate_idx).cast(dtypes.float)
 w = (ridx0<50).where(ld, 5.0)
 uops = to_uops_list([w])
 for u in uops:
@@ -467,11 +467,11 @@ class TestUOpGraph(unittest.TestCase):
 c1 = UOp.range(UOp.const(dtypes.index, 512), 1, AxisType.LOOP)
 c2 = UOp.range(UOp.const(dtypes.index, 250), 2, AxisType.LOOP)
 c3 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(512), arg=1, src=())
-c4 = c3.index(c1).load()
+c4 = c3.index(c1)
 c5 = UOp.range(UOp.const(dtypes.index, 240), 0, AxisType.REDUCE)
 c6 = ((c2*UOp.const(dtypes.index, 240))+c5)
 c7 = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(60000), arg=2, src=())
-c8 = c7.index(c6).load()
+c8 = c7.index(c6)
 c9 = ((c4<0).where((c4+60000), c4)!=c6.cast(dtypes.int)).where(0, c8.cast(dtypes.uint).cast(dtypes.uchar)).reduce(c5, arg=Ops.ADD)
 c10 = c0.index(((c1*UOp.const(dtypes.index, 250))+c2)).store(c9).end(c1, c2)
 uops = to_uops_list([c10])
@@ -481,25 +481,25 @@ class TestUOpGraph(unittest.TestCase):
 def test_in_out_of_bounds_access(self):
 with Context(IGNORE_OOB=0):
 glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 0)),))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 0), ptr=True),))
 to_uops_list([ld0])
-ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 15)),))
+ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 15), ptr=True),))
 to_uops_list([ld1])
-ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 7)),))
+ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 7), ptr=True),))
 to_uops_list([ld1])

-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 42)),))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 42), ptr=True),))
 with self.assertRaises(RuntimeError): to_uops_list([ld0])

 def test_in_out_of_bounds_access_symbolic(self):
 with Context(IGNORE_OOB=0):
 glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 1, 10)),))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 1, 10), ptr=True),))
 to_uops_list([ld0])
-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 15)),))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 15), ptr=True),))
 to_uops_list([ld0])

-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 20)),))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 20), ptr=True),))
 with self.assertRaises(RuntimeError): to_uops_list([ld0])

 def test_in_out_of_bounds_access_gated_store(self):
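The mechanical change in these bounds-checking tests is the same throughout: a hand-built `Ops.LOAD` now needs its INDEX kept in pointer form. A minimal sketch, assuming the API from this diff:

```python
from tinygrad.uop.ops import UOp, Ops
from tinygrad.dtype import dtypes

glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)

# .index() now defaults to the element dtype (dtypes.int here); an explicitly
# constructed LOAD must request the pointer form of the INDEX instead
ld = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 0), ptr=True),))
```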
@@ -531,7 +531,7 @@ class TestUOpGraph(unittest.TestCase):
 if_barrier = UOp(Ops.IF, dtypes.void, (gate, barrier))

 # Load from local memory (after the IF/barrier)
-local_load = UOp(Ops.LOAD, dtypes.uint, (sbuf.index(lidx), if_barrier))
+local_load = UOp(Ops.LOAD, dtypes.uint, (sbuf.index(lidx, ptr=True), if_barrier))

 # Store to global memory
 global_store = UOp(Ops.STORE, dtypes.void, (gbuf.index(gidx), local_load))
@@ -542,18 +542,18 @@ class TestUOpGraph(unittest.TestCase):
 ridx = UOp.range(20, 0)
 glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
 i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int)
-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid((0<=i)&(i<16))),))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid((0<=i)&(i<16)), ptr=True),))
 to_uops_list([ld0])
 glblfloat = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(20), (), 0)
 ldfloat = UOp(Ops.LOAD, dtypes.float, (glblfloat.index(ridx),))
 i = (ldfloat+3.14).cast(dtypes.int)
-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16))),))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16)), ptr=True),))

 def test_load_cast_to_bool(self):
 with Context(IGNORE_OOB=0):
 glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
 ridx = UOp.range(20, 0)
-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx.valid(ridx.cast(dtypes.bool).logical_not())),))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx.valid(ridx.cast(dtypes.bool).logical_not()), ptr=True),))
 to_uops_list([ld0])

 @unittest.skip("Bool load is not supported yet")
@@ -562,36 +562,36 @@ class TestUOpGraph(unittest.TestCase):
 glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
 mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
 ridx = UOp.range(20, 0)
-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask),)))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask), ptr=True)))
 to_uops_list([ld0])

 def test_out_of_bounds_off_by_one_access(self):
 with Context(IGNORE_OOB=0):
 glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 16)),))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 16), ptr=True),))
 with self.assertRaises(RuntimeError): to_uops_list([ld0])

 def test_in_out_bounds_access_with_mask(self):
 with Context(IGNORE_OOB=0):
 glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
 gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid((5<gidx0)&(gidx0<16))),))
-ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<16)),))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid((5<gidx0)&(gidx0<16)), ptr=True),))
+ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<16), ptr=True),))
 to_uops_list([ld0, ld1])

-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<17)),))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<17), ptr=True),))
 with self.assertRaises(RuntimeError): to_uops_list([ld0])

 def test_in_out_of_bounds_access_symbolic_mask(self):
 with Context(IGNORE_OOB=0):
 glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
 i = Variable("i", 1, 80)
-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<10)),))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<10), ptr=True),))
 to_uops_list([ld0])
-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<15)),))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<15), ptr=True),))
 to_uops_list([ld0])

-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<20)),))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<20), ptr=True),))
 with self.assertRaises(RuntimeError): to_uops_list([ld0])

 def test_in_out_of_bounds_access_index_load(self):
@@ -599,11 +599,11 @@ class TestUOpGraph(unittest.TestCase):
 glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
 glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
 gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<8)),)).cast(dtypes.index)
-ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<32))),))
+ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<8), ptr=True),)).cast(dtypes.index)
+ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<32)), ptr=True),))
 to_uops_list([ld1])

-ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<64))),))
+ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<64)), ptr=True),))
 with self.assertRaises(RuntimeError): to_uops_list([ld1])

 def test_bounds_with_loaded_bool(self):
@@ -611,8 +611,8 @@ class TestUOpGraph(unittest.TestCase):
 glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
 glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(8), (), 0)
 gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 16),), "gidx0")
-ld0 = glbl0.index(gidx0).load()
-ld1 = glbl1.index(gidx0.valid(ld0)).load()
+ld0 = glbl0.index(gidx0, ptr=True).load()
+ld1 = glbl1.index(gidx0.valid(ld0), ptr=True).load()
 with self.assertRaises(RuntimeError): to_uops_list([ld1])

 def test_fold_gated_load(self):
@@ -620,38 +620,38 @@ class TestUOpGraph(unittest.TestCase):
 glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
 glbl2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
 idx = UOp.const(dtypes.int, 0)
-ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(UOp.invalid()),))
-ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2.index(idx.valid(UOp.const(dtypes.bool, True))),))
+ld0 = glbl1.index(UOp.invalid())
+ld1 = glbl2.index(idx.valid(UOp.const(dtypes.bool, True)))
 uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
 ld0 = uops[-1].src[-1]
 # the gate and invalid value are deleted from ld1
-self.assertEqual(ld0, UOp.load(glbl2.index(idx), dtype=dtypes.int))
+self.assertEqual(ld0, UOp.load(glbl2.index(idx, ptr=True), dtype=dtypes.int))

 def test_fold_gated_load_local(self):
 glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
 smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=18, addrspace=AddrSpace.LOCAL), (), "temp")
 lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0")
-st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int)))
+st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx, ptr=True), glbl0.index(lidx, ptr=True).load()))
 barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
-ld0 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(UOp.invalid()),))
-ld1 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index((lidx+2).valid(UOp.const(dtypes.bool, True))),))
+ld0 = smem.after(barrier).index(UOp.invalid())
+ld1 = smem.after(barrier).index((lidx+2).valid(UOp.const(dtypes.bool, True)))
 uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))])

 ld0 = uops[-1].src[-1]
 # the gate and invalid value are deleted from ld1
-self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2))
+self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2, ptr=True))

 def test_fold_gated_store(self):
 glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
 idx0 = UOp.const(dtypes.int, 0)
 idx1 = UOp.const(dtypes.int, 0)
 val = UOp.const(dtypes.int, 42)
-st0 = glbl.index(UOp.invalid()).store(val)
-st1 = glbl.index(idx0.valid(UOp.const(dtypes.bool, True))).store(val)
+st0 = glbl.index(UOp.invalid(), ptr=True).store(val)
+st1 = glbl.index(idx0.valid(UOp.const(dtypes.bool, True)), ptr=True).store(val)
 uops = to_uops_list([st0, st1])
 # only the second store happens
 self.assertEqual(len(uops), 5)
-self.assertEqual(uops[-1], glbl.index(idx1).store(val))
+self.assertEqual(uops[-1], glbl.index(idx1, ptr=True).store(val))

 @unittest.skip("this is a uop type error")
 def test_asserts_bad_gate(self):
@@ -39,9 +39,9 @@ def _test_single_value(vals, op, dts):
 output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1]
 buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
 buf_loads = [uop(uops, Ops.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)]
-loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0))]) for i, dtype in enumerate(dts))
+loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True)]) for i, dtype in enumerate(dts))
 alu = uop(uops, op, output_dtype, loads)
-out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), alu))
+out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True), alu))
 buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
 buf2 = [Buffer(Device.DEFAULT, 1, dtype).allocate().copyin(np.array([a], dtype=_to_np_dtype(dtype)).data) for a,dtype in zip(vals, dts)]
 prg = _uops_to_prg([out])
@@ -338,7 +338,7 @@ class TestLocalAccess(unittest.TestCase):
 smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.float32.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
 st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.float32, (), 42.0)))
 barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
-sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0)),))
+sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True),))
 self.assertEqual(_test_uops_result(dtypes.float32, uops, sres), 42)

 # NOTE: webgpu specific, since only webgpu performs bitpacking
@@ -348,7 +348,7 @@ class TestLocalAccess(unittest.TestCase):
 smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
 st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.uint8, (), 42)))
 barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
-sres = uop(uops, Ops.LOAD, dtypes.uint8, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0)),))
+sres = smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0))
 self.assertEqual(_test_uops_result(dtypes.uint8, uops, sres), 42)

 # NOTE: webgpu specific, since only webgpu performs bitpacking
@@ -382,7 +382,7 @@ class TestAssembly(unittest.TestCase):
 g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
 c1 = UOp(Ops.CONST, dtypes.int, (), 2)
 c2 = UOp(Ops.CONST, dtypes.int, (), 3)
-l1 = UOp(Ops.LOAD, dtypes.int, (g1.index(c1),))
+l1 = g1.index(c1)
 a1 = UOp(Ops.MUL, dtypes.int, (l1, c1))
 a2 = UOp(Ops.MUL, dtypes.int, (l1, c2))
 uops = to_uops_list([a1,a2], ren=Device[Device.DEFAULT].renderer)
@@ -395,7 +395,7 @@ class TestAssembly(unittest.TestCase):
 for dt in (dtypes.int32, dtypes.uint32):
 g = UOp(Ops.DEFINE_GLOBAL, dt.ptr(), (), 0)
 c = UOp(Ops.CONST, dt, (), 2)
-l = UOp(Ops.LOAD, dt, (g.index(c),))
+l = g.index(c)
 a = UOp(Ops.IDIV, dt, (l, c))
 uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
 Device[Device.DEFAULT].renderer.render(uops)
@@ -406,7 +406,7 @@ class TestAssembly(unittest.TestCase):
 def test_fast_idiv_and_mod(self):
 g = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(), (), 0)
 c = UOp(Ops.CONST, dtypes.uint, (), 3)
-l = UOp(Ops.LOAD, dtypes.uint, (g.index(c),))
+l = g.index(c)
 a = UOp(Ops.IDIV, dtypes.uint, (l, c))
 uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
 Device[Device.DEFAULT].renderer.render(uops)
@@ -458,8 +458,7 @@ class TestAssembly(unittest.TestCase):
 def test_use_cmpeq(self):
 g = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(), (), 0)
 c = UOp(Ops.CONST, dtypes.uint, (), 7)
-l = UOp(Ops.LOAD, dtypes.uint, (g.index(c),))
-comp = l.ne(c).ne(True)
+comp = g.index(c).ne(c).ne(True)
 uops = to_uops_list([comp], ren=Device[Device.DEFAULT].renderer)
 Device[Device.DEFAULT].renderer.render(uops)
 ops = [x.op for x in uops]
@@ -141,8 +141,8 @@ class TestUOpsStats(unittest.TestCase):
 globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
 o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
 o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2)
-u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),))
-u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),))
+u1 = globl.index(o1)
+u2 = globl.index(o2)
 u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
 u4 = UOp(Ops.MUL, dtypes.int, (u1,u2))
 u5 = UOp(Ops.ADD, dtypes.int, (u4,u3))
@@ -151,8 +151,8 @@ class TestUOpsStats(unittest.TestCase):
 globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
 o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
 o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2)
-u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),))
-u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),))
+u1 = globl.index(o1)
+u2 = globl.index(o2)
 u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
 u4 = UOp(Ops.MULACC, dtypes.int, (u1,u2,u3))
 uops_fma = full_rewrite(u4.sink())
@@ -9,13 +9,13 @@ from test.unit.test_uop_symbolic import check_uop_against_string

 def get_gated_load_uop(valid:UOp, idx:UOp):
 return UOp(Ops.LOAD, dtypes.float, (
-UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0).index(idx.valid(valid)),
+UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0).index(idx.valid(valid), ptr=True),
 UOp.const(dtypes.float, 0.0)
 ))

 def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
 return UOp(Ops.LOAD, dtypes.float.vec(4), (
-UOp(Ops.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0).index(UOp(Ops.VECTORIZE, dtypes.index.vec(2), idx).valid(valid)),
+UOp(Ops.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0).index(UOp(Ops.VECTORIZE, dtypes.index.vec(2), idx).valid(valid), ptr=True),
 UOp(Ops.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
 ))
@@ -11,7 +11,7 @@ class TestTranscendentalFunctions(unittest.TestCase):
 # TODO: Test constant input when constant folding is fixed (or maybe test both variants)
 # Load input value from a buffer to prevent constant folding
 input_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.double.ptr(), arg=1, src=())
-loaded_value = UOp.load(input_buf.index(UOp.const(dtypes.int, 0)), dtype=dtypes.double)
+loaded_value = input_buf.index(UOp.const(dtypes.int, 0))
 def eval_payne_hanek_reduction(v:float) -> tuple[float, int]:
 return tuple(eval_uop(u, [(dtypes.float64, [v])]) for u in payne_hanek_reduction(loaded_value))
@@ -12,7 +12,7 @@ from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, p
 from tinygrad.uop.decompositions import get_late_rewrite_patterns
 from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
 from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
-    ReduceContext, correct_load_store, pm_render
+    ReduceContext, correct_load_store, pm_render, pm_add_loads
 from tinygrad.codegen.opt.postrange import apply_opts
 from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse, pm_split_store
 from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen
@@ -59,6 +59,11 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
 # add gpu dims (late). this works after devectorize, but it's faster here
 sink = graph_rewrite(sink, pm_add_gpudims, ctx=ren, name="add gpudims")

+# **** optimizations are done, now we lower to actual code ****
+
+# add loads
+sink = graph_rewrite(sink, pm_add_loads, name="** add loads (code)")
+
 # devectorize (TODO: does this need opts?)
 if DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing
 elif DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing
@@ -80,7 +80,7 @@ def add_gpudims(ctx:Renderer, s:UOp):
 subs = {}
 for r in s_topo:
 # look for local INDEXes that are not used in the GLOBAL store, then add them as an INVALID
-if r.op is Ops.STORE and r.src[0].ptrdtype.addrspace == AddrSpace.GLOBAL:
+if r.op is Ops.STORE and r.src[0].src[0].ptrdtype.addrspace == AddrSpace.GLOBAL:
 idx = r.src[0]
 missing_locals = [all_ranges[rng] for rng in local_dims if all_ranges[rng] not in idx.ranges]
 if len(missing_locals):
@@ -2,7 +2,7 @@ from typing import Any, cast
 import functools, operator, itertools
 from collections import defaultdict
 from dataclasses import dataclass
-from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid
+from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid, PtrDType
 from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element
 from tinygrad.uop.symbolic import uop_given_valid, parse_valid, sym, symbolic_flat, invalid_gate
 from tinygrad.helpers import getenv, flatten, AMX, prod
@@ -12,7 +12,7 @@ from tinygrad.renderer import Renderer

 def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
 idx = uop_given_valid(valid, start_idx)
-if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx.valid(valid))
+if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx.valid(valid), ptr=True)

 # wait for it to be image indexed before running simplification
 if start_idx.dtype.count != 2: return None
@@ -43,7 +43,7 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:

 if not drop_stmt and idx is start_idx: return None
 new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in valid.split_uop(Ops.AND) if s not in drop_stmt]) else None
-return buf.index(idx.valid(new_valid) if new_valid is not None else idx)
+return buf.index(idx.valid(new_valid) if new_valid is not None else idx, ptr=True)

 load_store_indexing = PatternMatcher([
@@ -52,7 +52,7 @@ load_store_indexing = PatternMatcher([
 # simplify away long after index has been lowered
 (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x", dtypes.long), UPat.var("c", dtypes.bool))), lambda buf,x,c: simplify_valid_load(buf, x, c)),
 # drop true gate
-(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x)),
+(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x, ptr=True)),
 ])

 # ***** load/store grouping *****
@@ -60,7 +60,7 @@ load_store_indexing = PatternMatcher([
 def expand_index(buf:UOp, vec:UOp):
 if getenv("UNSAFE_DISABLE_MASK", 0): vec = vec.get_idx()
 # generate the individual indexes
-midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i)) for i in range(vec.dtype.count)]),
+midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), ptr=True) for i in range(vec.dtype.count)]),
     symbolic_flat+load_store_indexing, name=f"index_buf_{buf.arg}")
 # extract all the relevant offsets
 offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
@@ -163,7 +163,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
 # with 1 at the end of the lengths list, this will always hit
 for fold_length in lengths:
 if global_offset+fold_length > sz: continue
-lidx = buf.index((offset + global_offset).valid(mask))
+lidx = buf.index((offset + global_offset).valid(mask), ptr=True)
 if fold_length > 1: lidx = lidx.cast(buf.ptrdtype.base.vec(fold_length).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace))
 if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
 else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length)))
@@ -229,7 +229,7 @@ def no_vectorized_buf(buf:UOp):
 def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
 cnt = cast.dtype.count
 assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}"
-return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.index.vec(cnt), tuple(range(cnt))))
+return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.index.vec(cnt), tuple(range(cnt))), ptr=True)

 def no_vectorized_index_broadcast(buf:UOp, cast:UOp, bcast:UOp, idx:UOp):
 cnt = cast.dtype.count
@@ -237,7 +237,7 @@ def no_vectorized_index_broadcast(buf:UOp, cast:UOp, bcast:UOp, idx:UOp):
 input_gep = bcast.arg if bcast.op is Ops.GEP else ([0]*precnt)
 gep_arg = tuple(flatten([range(precnt) for _ in range(cnt)]))
 sum_arg = tuple(flatten([[i+y for y in input_gep] for i in range(cnt)]))
-return buf.broadcast(cnt*precnt).index(idx.gep(gep_arg)*cnt+UOp.const(dtypes.index.vec(cnt*precnt), sum_arg))
+return buf.broadcast(cnt*precnt).index(idx.gep(gep_arg)*cnt+UOp.const(dtypes.index.vec(cnt*precnt), sum_arg), ptr=True)

 devectorize_buf_and_index = PatternMatcher([
 (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
@@ -302,11 +302,11 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
 acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,))
 acc_init = acc.after(*input_ranges).index(UOp.const(dtypes.int, 0)).store(identity) if len(input_ranges) else \
     acc.index(UOp.const(dtypes.int, 0)).store(identity)
-lst = [acc.after(acc_init, *reduce_range).index(UOp.const(dtypes.int, 0)).load()] + lst # put acc as the first element
+lst = [acc.after(acc_init, *reduce_range).index(UOp.const(dtypes.int, 0))] + lst # put acc as the first element
 ctx.acc_num += 1
 ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
 if len(reduce_range) == 0: return ret
-return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range)).index(UOp.const(dtypes.int, 0)).load()
+return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range)).index(UOp.const(dtypes.int, 0))

 pm_reduce = PatternMatcher([
 # REDUCE -> DEFINE_ACC+ASSIGN
@@ -315,3 +315,14 @@ pm_reduce = PatternMatcher([
 (UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
 lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
 ])+sym
+
+# add loads
+
+pm_add_loads = PatternMatcher([
+  # add loads to non ptr index
+  (UPat(Ops.INDEX, name="idx"), lambda idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else
+    idx.replace(dtype=idx.src[0].dtype).load(dtype=idx.dtype.base)),
+  # remove loads from stores
+  (UPat(Ops.STORE, src=(UPat(Ops.LOAD),), allow_any_len=True, name="s"), lambda s: s.replace(src=(s.src[0].src[0],)+s.src[1:])),
+])
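These two patterns are the core of the commit. A hypothetical usage sketch (module path and names taken from this diff, not verified against a released tinygrad; `src`/`dst` are illustrative):

```python
from tinygrad.uop.ops import UOp, Ops, graph_rewrite
from tinygrad.dtype import dtypes
from tinygrad.codegen.late.devectorizer import pm_add_loads

idx = UOp.const(dtypes.index, 0)
src = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1)
dst = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)

# both INDEXes carry the element dtype (int32) here, not a pointer dtype
sink = dst.index(idx).store(src.index(idx)).sink()

# the first pattern wraps every non-ptr INDEX in a LOAD of its base dtype;
# the second strips the LOAD that ends up wrapped around a STORE's address,
# restoring the pointer-typed INDEX a STORE expects
sink = graph_rewrite(sink, pm_add_loads, name="add loads")
```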
@@ -64,8 +64,8 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
 MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
 if k.ren.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
     k.reduceop is not None and k.reduceop.arg[0] is Ops.ADD and len(k.full_shape) >= 2 and k.ren.has_shared and \
-    (mulop:=k.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
-idx0, idx1 = mulop.src[0].src[0].src[1].get_idx(), mulop.src[1].src[0].src[1].get_idx()
+    (mulop:=k.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.INDEX and mulop.src[1].op is Ops.INDEX:
+idx0, idx1 = mulop.src[0].src[1].get_idx(), mulop.src[1].src[1].get_idx()
 if k.ranges_of(AxisType.REDUCE):
 first_reduce_rng = k.ranges_of(AxisType.REDUCE)[0]
 if any(u is first_reduce_rng for u in idx0.split_uop(Ops.ADD)) and all(r in idx1.ranges for r in idx0.ranges):
@@ -55,7 +55,7 @@ def do_substitute(ctx, x: UOp):
 return ret

 def dont_sub_ranges_for_image(ctx, x:UOp):
-if isinstance(x.src[0].dtype, ImageDType):
+if isinstance(x.src[0].src[0].dtype, ImageDType):
 for s in x.src[0].ranges: ctx[s] = None

 pm_split_ranges = PatternMatcher([
@@ -129,7 +129,7 @@ def reduce_load_collapse(red:UOp): return reduce_collapse(red, pm=pm_reduce_load
 # remove REDUCE without loads (generic arange opt / indexing). TODO: support multi range
 pm_reduce_simplify = pm_reduce_unparented + PatternMatcher([(UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_collapse),])
 # remove REDUCE on load, comes from indexing a tensor with another tensor
-def no_load(u:UOp) -> bool: return not any(x.op is Ops.LOAD for x in u.backward_slice_with_self)
+def no_load(u:UOp) -> bool: return not any(x.op is Ops.INDEX for x in u.backward_slice_with_self)
 pm_load_collapse = PatternMatcher([
 (UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_load_collapse),
 # we want to make sure we dont do math on a loaded index since that can cause overflow, this undoes the rule in pm_reduce_load_collapse
@@ -435,9 +435,9 @@ rangeify_codegen = PatternMatcher([

 # add loads to non ptr indexes
 # TODO: this can be moved into codegen?
-(UPat.any(UPat(Ops.DEFINE_GLOBAL, name="dg"), UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True, name="dg"))
-  .f(Ops.INDEX, name="idx", allow_any_len=True),
-  lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()),
+#(UPat.any(UPat(Ops.DEFINE_GLOBAL, name="dg"), UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True, name="dg"))
+# .f(Ops.INDEX, name="idx", allow_any_len=True),
+# lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()),

 # fix broadcast dtype
 (UPat(Ops.AFTER, name="a").broadcast(name="b"), lambda a,b: a.broadcast(len(b.src))),
@@ -339,8 +339,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
 if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
 return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
 def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
-def index(self, *srcs:UOp|None, **kwargs):
-  return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
+def index(self, *srcs:UOp|None, ptr=False, **kwargs):
+  return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
 def __getitem__(self, *idx): return self.index(*idx)
 def const_like(self, b:ConstLike):
 # constants can optionally have a DEVICE source
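This is the API pivot the rest of the diff follows from: `index` gains a `ptr` flag selecting between address and value dtypes. A sketch of the rule, assuming `dtype.base` is the element dtype of a pointer dtype:

```python
from tinygrad.uop.ops import UOp, Ops
from tinygrad.dtype import dtypes

buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(16), (), 0)
i = UOp.const(dtypes.index, 3)

addr = buf.index(i, ptr=True)  # keeps dtypes.float.ptr(16): an address, usable by LOAD/STORE
val = buf.index(i)             # dtypes.float: a value; pm_add_loads later wraps it in a LOAD
```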
@@ -743,6 +743,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
 ret = graph_rewrite(self.simplify() if simplify else self, renderer if pm is None else pm)
 return ret.arg if ret.op is Ops.NOOP else str(ret)

+def pyrender(self): return pyrender(self)
+
 @dataclass(frozen=True)
 class KernelInfo:
 name: str = "test" # name of the kernel
@@ -1199,10 +1201,11 @@ pm_lower_index_dtype = PatternMatcher([
 (UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.index), UPat.cvar("val").cast(dtypes.index))), lambda var,val: var.bind(val).cast(dtypes.index)),
 (UPat(Ops.CAST, src=(UPat(name="x").cast(dtypes.index),), name="c"), lambda x,c: x.cast(c.dtype)),
 # lower Invalid
-(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond)),
+(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond, ptr=True)),
 # remove hanging casts
-(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx)),
-(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)),
+(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
+(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))),
+  lambda buf,idx,valid: buf.index(idx, valid, ptr=True)),
 (UPat((Ops.STORE, Ops.LOAD), src=(UPat(), UPat(), UPat().cast(dtypes.index)), allow_any_len=True, name="s"),
 lambda s: s.replace(src=s.src[:2]+tuple(u.src[0] for u in s.src[2:]))),
 (UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"),
@@ -1279,8 +1282,9 @@ pm_pyrender_extra = PatternMatcher([
 "UOp.range("+', '.join([str(c.arg)] + [str(y) for y in x.arg])+
 (f', src={srcs(ctx, x.src[1:])}' if len(x.src) > 1 else '')+(', dtype='+str(x.dtype) if x.dtype is not dtypes.index else '')+")"),
 # TODO: index shouldn't mismatch dtype
-(UPat(Ops.INDEX, src=(UPat(), UPat()), name="x"), lambda ctx,x:
-  f"{ctx[x.src[0]]}.index({ctx[x.src[1]]}, dtype={x.dtype})" if x.src[0].dtype != x.dtype else None),
+(UPat(Ops.INDEX, src=(UPat(), UPat()), allow_any_len=True, name="x"), lambda ctx,x:
+  f"{ctx[x.src[0]]}.index({ctx[x.src[1]]}, "+(f"{ctx[x.src[2]]}, " if len(x.src) > 2 else "")+
+  (f"dtype={x.dtype})" if x.src[0].dtype != x.dtype else "ptr=True)") if x.src[0].dtype.base != x.dtype else None),
 # TODO: fix forced_reshape
 (UPat(Ops.RESHAPE, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.forced_reshape({render_marg(ctx,x)})" if x.src[0].shape == x.shape else None),
 (UPat(GroupOp.Movement, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({render_marg(ctx,x)})"),
@@ -133,6 +133,7 @@ shared_codegen_spec = PatternMatcher([
 (UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),

 # LOAD(idx) / STORE(idx, val) / LOAD with alt value only exists in program_spec
+# TODO: move LOAD to the program_spec
 (UPat().index(UPat()).or_casted().load(), lambda: True),
 (UPat(Ops.INDEX).or_casted().store(UPat()), lambda: True),
@@ -413,7 +413,7 @@ def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:
 bounds[expr][int(is_upper)] = c

 # don't simplify any other gates, can lead to OOB, we substitute them back later
-uop = uop.substitute((load_subs:={u: UOp(Ops.NOOP, arg=u) for u in uop.toposort() if u.op is Ops.INDEX}))
+uop = uop.substitute((load_subs:={u: UOp(Ops.NOOP, dtype=u.dtype, arg=u) for u in uop.toposort() if u.op is Ops.INDEX}))

 # simplify uop given that valid is True
 all_candidates = []
@@ -479,21 +479,20 @@ def drop_and_clauses(cond:UOp, x:UOp, i:UOp) -> UOp|None:
 return UOp.const(dtypes.bool, True).prod(*[c for c in cond.split_uop(Ops.AND) if c not in dropped_clauses]).where(x, i)
 pm_drop_and_clauses = PatternMatcher([(UPat.var("cond").where(UPat.var("x", dtype=dtypes.index), invalid_pat), drop_and_clauses)])

-def where_on_load(l, c1, buf, x):
+def where_on_load(c1, buf, x):
 c2 = x.get_valid()
 duplicate_clauses = [c for c in c1.split_uop(Ops.AND) if c in c2.split_uop(Ops.AND)]
 # we move the condition from the where to the load _as long as_ the condtition doesn't have some range that would place it inside of a new range
 # also no data dependent loads!
 moved_clauses = [c for c in c1.split_uop(Ops.AND) if c not in duplicate_clauses and all(r in x.ranges for r in c.ranges)
-    and all(u in x.backward_slice_with_self for u in c.backward_slice_with_self if u.op is Ops.LOAD)]
+    and all(u in x.backward_slice_with_self for u in c.backward_slice_with_self if u.op is Ops.INDEX)]
 if not (removed:=moved_clauses+duplicate_clauses): return None
 # aditionally we can drop the clause on the where if it already exists in the load
 remaining_clause = UOp.const(dtypes.bool, True).prod(*[c for c in c1.split_uop(Ops.AND) if c not in removed])
-return remaining_clause.where(UOp.load(buf.index(x.get_idx().valid(functools.reduce(operator.and_, moved_clauses, c2)), *l.src[1:])), 0)
+return remaining_clause.where(buf.index(x.get_idx().valid(functools.reduce(operator.and_, moved_clauses, c2))), 0)
 pm_move_where_on_load = PatternMatcher([
-(UPat.var("c1").where(UPat(Ops.LOAD, src=(UPat.var("buf").index(UPat.var("x")),), name="l"), 0), where_on_load),
-(UPat.var("c1").where(0, UPat(Ops.LOAD, src=(UPat.var("buf").index(UPat.var("x")),), name="l")),
-  lambda l,c1,buf,x: where_on_load(l,c1.logical_not(),buf,x)),
+(UPat.var("c1").where(UPat.var("buf").index(UPat.var("x")), 0), where_on_load),
+(UPat.var("c1").where(0, UPat.var("buf").index(UPat.var("x"))), lambda c1,buf,x: where_on_load(c1.logical_not(),buf,x)),
 ])

 pm_simplify_valid = PatternMatcher([