mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
add loads at the end (#12988)
* add loads at the end * simpler * late load * tests passing * fix matvec * spec test passes * fix where on load * fix abs2 * fix more tests
This commit is contained in:
@@ -194,6 +194,7 @@ class TestDTypeALU(unittest.TestCase):
|
||||
strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
|
||||
ht.int32, strat.sampled_from(binary_operations), strat.sampled_from(integer_binary_operations))
|
||||
@unittest.skipIf(Device.DEFAULT == "PYTHON", "TODO: fix cast inf to int32 in PYTHON")
|
||||
@unittest.skip("broken on Mac")
|
||||
def test_float_midcast_int32(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.float32, dtypes.int32)
|
||||
|
||||
@unittest.skip("broken. TODO: fix it")
|
||||
|
||||
@@ -16,12 +16,12 @@ class TestLinearizerFailure(unittest.TestCase):
|
||||
c2 = UOp.range(UOp.const(dtypes.index, 784), 1, AxisType.GLOBAL)
|
||||
c3 = UOp.range(UOp.const(dtypes.index, 10), 3, AxisType.GLOBAL)
|
||||
c4 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(512), arg=1, src=())
|
||||
c5 = c4.index(c1.valid(UOp.const(dtypes.bool, True))).load()
|
||||
c5 = c4.index(c1.valid(UOp.const(dtypes.bool, True)))
|
||||
c6 = UOp.range(UOp.const(dtypes.index, 6000), 1004, AxisType.REDUCE)
|
||||
c7 = UOp.range(UOp.const(dtypes.index, 3750), 2006, AxisType.REDUCE)
|
||||
c8 = UOp.range(UOp.const(dtypes.index, 16), 2007, AxisType.GROUP_REDUCE)
|
||||
c9 = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(47040000), arg=2, src=())
|
||||
c10 = c9.index((((c3*UOp.const(dtypes.index, 4704000))+c2)+(c6*UOp.const(dtypes.index, 784))).valid(UOp.const(dtypes.bool, True))).load()
|
||||
c10 = c9.index((((c3*UOp.const(dtypes.index, 4704000))+c2)+(c6*UOp.const(dtypes.index, 784))).valid(UOp.const(dtypes.bool, True)))
|
||||
c11 = c5.alu(Ops.CMPNE, ((((c3*UOp.const(dtypes.index, 6000))+c6)+((c7*UOp.const(dtypes.index, 16))+c8)).alu(Ops.CMPLT, UOp.const(dtypes.index, 59999)).where(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)).reduce(c7, c8, arg=Ops.ADD)+UOp.const(dtypes.int, -1))).where(UOp.const(dtypes.uchar, 0), c10).reduce(c6, arg=Ops.ADD)
|
||||
c12 = c0.index((((c1*UOp.const(dtypes.index, 7840))+(c2*UOp.const(dtypes.index, 10)))+c3).valid(UOp.const(dtypes.bool, True))).store(c11).end(c1, c2, c3)
|
||||
ast = c12.sink(arg=KernelInfo(name='test', axis_types=(), dont_use_locals=False, applied_opts=(Opt(op=OptOps.GROUP, axis=1, arg=16),), opts_to_apply=None))
|
||||
|
||||
@@ -12,9 +12,9 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
c3 = ((c1*UOp.const(dtypes.index, 32))+c2)
|
||||
c4 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(163840), arg=1, src=())
|
||||
c5 = UOp.range(UOp.const(dtypes.index, 2560), 0, AxisType.REDUCE)
|
||||
c6 = c4.index(((((((c5//UOp.const(dtypes.index, 8))%UOp.const(dtypes.index, 8))*UOp.const(dtypes.index, 8))+(c5%UOp.const(dtypes.index, 8)))+(((c2*UOp.const(dtypes.index, 40))+(c5//UOp.const(dtypes.index, 64)))*UOp.const(dtypes.index, 64)))+(c1*UOp.const(dtypes.index, 81920)))).load()
|
||||
c6 = c4.index(((((((c5//UOp.const(dtypes.index, 8))%UOp.const(dtypes.index, 8))*UOp.const(dtypes.index, 8))+(c5%UOp.const(dtypes.index, 8)))+(((c2*UOp.const(dtypes.index, 40))+(c5//UOp.const(dtypes.index, 64)))*UOp.const(dtypes.index, 64)))+(c1*UOp.const(dtypes.index, 81920))))
|
||||
c7 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=2, src=())
|
||||
c8 = c7.index(c3).load()
|
||||
c8 = c7.index(c3)
|
||||
c9 = ((((c6+(c8*UOp.const(dtypes.float, -1.0)))*(c6+(c8*UOp.const(dtypes.float, -1.0)))).reduce(c5, arg=Ops.ADD)*UOp.const(dtypes.float, 0.000390625))+UOp.const(dtypes.float, 1e-05)).sqrt().reciprocal()
|
||||
c10 = c0.index(c3).store(c9).end(c1, c2)
|
||||
ast = c10.sink()
|
||||
|
||||
@@ -199,6 +199,7 @@ class TestProfiler(unittest.TestCase):
|
||||
#self.assertLess(e1.st, e2.st)
|
||||
#self.assertGreater(e1.en-e1.st, e2.en-e2.st)
|
||||
|
||||
@unittest.skipIf(not CI, "this test is flaky locally")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].graph is not None, "graph support required")
|
||||
def test_graph(self):
|
||||
from test.test_graph import helper_alloc_rawbuffer, helper_exec_op, helper_test_graphs
|
||||
|
||||
@@ -34,7 +34,7 @@ def _setup_and_test_alu(alu_op:Ops, input_val:ConstType, *alu_src_uops:UOp):
|
||||
a = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 0)
|
||||
b = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 1)
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld = UOp(Ops.LOAD, dtype, (b.index(idx),))
|
||||
ld = b.index(idx)
|
||||
alu = ld.alu(alu_op, *alu_src_uops)
|
||||
store = UOp.store(a.index(idx), alu)
|
||||
sink = UOp(Ops.SINK, dtypes.void, (store,))
|
||||
|
||||
@@ -2163,8 +2163,8 @@ class TestCopyFolding(unittest.TestCase):
|
||||
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
|
||||
|
||||
def test_permute_on_disk_contiguous(self):
|
||||
with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer())
|
||||
a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
|
||||
with open(temp('dt_arange_4_permute_contig'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer())
|
||||
a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute_contig')}")
|
||||
b = a.reshape(2, 2).permute(1, 0).contiguous().to("CPU")
|
||||
b.realize()
|
||||
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
|
||||
|
||||
@@ -376,7 +376,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0)
|
||||
d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld = UOp(Ops.LOAD, dtypes.int, (d1.index(idx),))
|
||||
ld = d1.index(idx)
|
||||
alu = (ld<1).cast(dtypes.bool)
|
||||
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu))
|
||||
uops = to_uops_list([out])
|
||||
@@ -386,7 +386,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
|
||||
d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld = UOp(Ops.LOAD, dtypes.int, (d1.index(idx),))
|
||||
ld = d1.index(idx)
|
||||
alu = ld.cast(dtypes.float).cast(dtypes.float)
|
||||
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu))
|
||||
uops = to_uops_list([out])
|
||||
@@ -408,7 +408,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
def test_bitcast_to_same_dtype_fold(self):
|
||||
for dt in dtypes.ints + dtypes.floats + (dtypes.bool,):
|
||||
d0 = UOp(Ops.DEFINE_GLOBAL, dt.ptr(), arg=0)
|
||||
v = UOp(Ops.LOAD, dt, (d0.index(UOp.const(dtypes.int, 0)),))
|
||||
v = d0.index(UOp.const(dtypes.int, 0))
|
||||
uops = to_uops_list([v.bitcast(dt)])
|
||||
self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST]), 0, f"dtype = {dt}")
|
||||
|
||||
@@ -420,7 +420,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
def test_where_on_gated_load_fold(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0)
|
||||
ld = d0.index(ridx0.valid(ridx0<50)).load()
|
||||
ld = d0.index(ridx0.valid(ridx0<50))
|
||||
w = (ridx0<50).where(ld, 5)
|
||||
uops = to_uops_list([w])
|
||||
for u in uops:
|
||||
@@ -430,7 +430,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
def test_where_on_gated_load_folds_swapped_branches(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0)
|
||||
ld = d0.index(ridx0.valid((ridx0<50).logical_not())).load()
|
||||
ld = d0.index(ridx0.valid((ridx0<50).logical_not()))
|
||||
w = (ridx0<50).where(5, ld)
|
||||
uops = to_uops_list([w])
|
||||
for u in uops:
|
||||
@@ -441,7 +441,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
gate_idx = ridx0.valid((ridx0<50))
|
||||
ld = d0.index(gate_idx).load().cast(dtypes.float)
|
||||
ld = d0.index(gate_idx).cast(dtypes.float)
|
||||
w = (ridx0<50).where(ld, 5.0)
|
||||
uops = to_uops_list([w])
|
||||
for u in uops:
|
||||
@@ -467,11 +467,11 @@ class TestUOpGraph(unittest.TestCase):
|
||||
c1 = UOp.range(UOp.const(dtypes.index, 512), 1, AxisType.LOOP)
|
||||
c2 = UOp.range(UOp.const(dtypes.index, 250), 2, AxisType.LOOP)
|
||||
c3 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(512), arg=1, src=())
|
||||
c4 = c3.index(c1).load()
|
||||
c4 = c3.index(c1)
|
||||
c5 = UOp.range(UOp.const(dtypes.index, 240), 0, AxisType.REDUCE)
|
||||
c6 = ((c2*UOp.const(dtypes.index, 240))+c5)
|
||||
c7 = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(60000), arg=2, src=())
|
||||
c8 = c7.index(c6).load()
|
||||
c8 = c7.index(c6)
|
||||
c9 = ((c4<0).where((c4+60000), c4)!=c6.cast(dtypes.int)).where(0, c8.cast(dtypes.uint).cast(dtypes.uchar)).reduce(c5, arg=Ops.ADD)
|
||||
c10 = c0.index(((c1*UOp.const(dtypes.index, 250))+c2)).store(c9).end(c1, c2)
|
||||
uops = to_uops_list([c10])
|
||||
@@ -481,25 +481,25 @@ class TestUOpGraph(unittest.TestCase):
|
||||
def test_in_out_of_bounds_access(self):
|
||||
with Context(IGNORE_OOB=0):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 0)),))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 0), ptr=True),))
|
||||
to_uops_list([ld0])
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 15)),))
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 15), ptr=True),))
|
||||
to_uops_list([ld1])
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 7)),))
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 7), ptr=True),))
|
||||
to_uops_list([ld1])
|
||||
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 42)),))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 42), ptr=True),))
|
||||
with self.assertRaises(RuntimeError): to_uops_list([ld0])
|
||||
|
||||
def test_in_out_of_bounds_access_symbolic(self):
|
||||
with Context(IGNORE_OOB=0):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 1, 10)),))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 1, 10), ptr=True),))
|
||||
to_uops_list([ld0])
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 15)),))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 15), ptr=True),))
|
||||
to_uops_list([ld0])
|
||||
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 20)),))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 20), ptr=True),))
|
||||
with self.assertRaises(RuntimeError): to_uops_list([ld0])
|
||||
|
||||
def test_in_out_of_bounds_access_gated_store(self):
|
||||
@@ -531,7 +531,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
if_barrier = UOp(Ops.IF, dtypes.void, (gate, barrier))
|
||||
|
||||
# Load from local memory (after the IF/barrier)
|
||||
local_load = UOp(Ops.LOAD, dtypes.uint, (sbuf.index(lidx), if_barrier))
|
||||
local_load = UOp(Ops.LOAD, dtypes.uint, (sbuf.index(lidx, ptr=True), if_barrier))
|
||||
|
||||
# Store to global memory
|
||||
global_store = UOp(Ops.STORE, dtypes.void, (gbuf.index(gidx), local_load))
|
||||
@@ -542,18 +542,18 @@ class TestUOpGraph(unittest.TestCase):
|
||||
ridx = UOp.range(20, 0)
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||
i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid((0<=i)&(i<16))),))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid((0<=i)&(i<16)), ptr=True),))
|
||||
to_uops_list([ld0])
|
||||
glblfloat = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(20), (), 0)
|
||||
ldfloat = UOp(Ops.LOAD, dtypes.float, (glblfloat.index(ridx),))
|
||||
i = (ldfloat+3.14).cast(dtypes.int)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16))),))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16)), ptr=True),))
|
||||
|
||||
def test_load_cast_to_bool(self):
|
||||
with Context(IGNORE_OOB=0):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
|
||||
ridx = UOp.range(20, 0)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx.valid(ridx.cast(dtypes.bool).logical_not())),))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx.valid(ridx.cast(dtypes.bool).logical_not()), ptr=True),))
|
||||
to_uops_list([ld0])
|
||||
|
||||
@unittest.skip("Bool load is not supported yet")
|
||||
@@ -562,36 +562,36 @@ class TestUOpGraph(unittest.TestCase):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||
mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
|
||||
ridx = UOp.range(20, 0)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask),)))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask), ptr=True)))
|
||||
to_uops_list([ld0])
|
||||
|
||||
def test_out_of_bounds_off_by_one_access(self):
|
||||
with Context(IGNORE_OOB=0):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 16)),))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 16), ptr=True),))
|
||||
with self.assertRaises(RuntimeError): to_uops_list([ld0])
|
||||
|
||||
def test_in_out_bounds_access_with_mask(self):
|
||||
with Context(IGNORE_OOB=0):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||
gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid((5<gidx0)&(gidx0<16))),))
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<16)),))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid((5<gidx0)&(gidx0<16)), ptr=True),))
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<16), ptr=True),))
|
||||
to_uops_list([ld0, ld1])
|
||||
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<17)),))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<17), ptr=True),))
|
||||
with self.assertRaises(RuntimeError): to_uops_list([ld0])
|
||||
|
||||
def test_in_out_of_bounds_access_symbolic_mask(self):
|
||||
with Context(IGNORE_OOB=0):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||
i = Variable("i", 1, 80)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<10)),))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<10), ptr=True),))
|
||||
to_uops_list([ld0])
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<15)),))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<15), ptr=True),))
|
||||
to_uops_list([ld0])
|
||||
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<20)),))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<20), ptr=True),))
|
||||
with self.assertRaises(RuntimeError): to_uops_list([ld0])
|
||||
|
||||
def test_in_out_of_bounds_access_index_load(self):
|
||||
@@ -599,11 +599,11 @@ class TestUOpGraph(unittest.TestCase):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
|
||||
gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<8)),)).cast(dtypes.index)
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<32))),))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<8), ptr=True),)).cast(dtypes.index)
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<32)), ptr=True),))
|
||||
to_uops_list([ld1])
|
||||
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<64))),))
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<64)), ptr=True),))
|
||||
with self.assertRaises(RuntimeError): to_uops_list([ld1])
|
||||
|
||||
def test_bounds_with_loaded_bool(self):
|
||||
@@ -611,8 +611,8 @@ class TestUOpGraph(unittest.TestCase):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
|
||||
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(8), (), 0)
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 16),), "gidx0")
|
||||
ld0 = glbl0.index(gidx0).load()
|
||||
ld1 = glbl1.index(gidx0.valid(ld0)).load()
|
||||
ld0 = glbl0.index(gidx0, ptr=True).load()
|
||||
ld1 = glbl1.index(gidx0.valid(ld0), ptr=True).load()
|
||||
with self.assertRaises(RuntimeError): to_uops_list([ld1])
|
||||
|
||||
def test_fold_gated_load(self):
|
||||
@@ -620,38 +620,38 @@ class TestUOpGraph(unittest.TestCase):
|
||||
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
|
||||
glbl2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(UOp.invalid()),))
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2.index(idx.valid(UOp.const(dtypes.bool, True))),))
|
||||
ld0 = glbl1.index(UOp.invalid())
|
||||
ld1 = glbl2.index(idx.valid(UOp.const(dtypes.bool, True)))
|
||||
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
|
||||
ld0 = uops[-1].src[-1]
|
||||
# the gate and invalid value are deleted from ld1
|
||||
self.assertEqual(ld0, UOp.load(glbl2.index(idx), dtype=dtypes.int))
|
||||
self.assertEqual(ld0, UOp.load(glbl2.index(idx, ptr=True), dtype=dtypes.int))
|
||||
|
||||
def test_fold_gated_load_local(self):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=18, addrspace=AddrSpace.LOCAL), (), "temp")
|
||||
lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0")
|
||||
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int)))
|
||||
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx, ptr=True), glbl0.index(lidx, ptr=True).load()))
|
||||
barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(UOp.invalid()),))
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index((lidx+2).valid(UOp.const(dtypes.bool, True))),))
|
||||
ld0 = smem.after(barrier).index(UOp.invalid())
|
||||
ld1 = smem.after(barrier).index((lidx+2).valid(UOp.const(dtypes.bool, True)))
|
||||
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))])
|
||||
|
||||
ld0 = uops[-1].src[-1]
|
||||
# the gate and invalid value are deleted from ld1
|
||||
self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2))
|
||||
self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2, ptr=True))
|
||||
|
||||
def test_fold_gated_store(self):
|
||||
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
idx0 = UOp.const(dtypes.int, 0)
|
||||
idx1 = UOp.const(dtypes.int, 0)
|
||||
val = UOp.const(dtypes.int, 42)
|
||||
st0 = glbl.index(UOp.invalid()).store(val)
|
||||
st1 = glbl.index(idx0.valid(UOp.const(dtypes.bool, True))).store(val)
|
||||
st0 = glbl.index(UOp.invalid(), ptr=True).store(val)
|
||||
st1 = glbl.index(idx0.valid(UOp.const(dtypes.bool, True)), ptr=True).store(val)
|
||||
uops = to_uops_list([st0, st1])
|
||||
# only the second store happens
|
||||
self.assertEqual(len(uops), 5)
|
||||
self.assertEqual(uops[-1], glbl.index(idx1).store(val))
|
||||
self.assertEqual(uops[-1], glbl.index(idx1, ptr=True).store(val))
|
||||
|
||||
@unittest.skip("this is a uop type error")
|
||||
def test_asserts_bad_gate(self):
|
||||
|
||||
@@ -39,9 +39,9 @@ def _test_single_value(vals, op, dts):
|
||||
output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1]
|
||||
buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
|
||||
buf_loads = [uop(uops, Ops.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)]
|
||||
loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0))]) for i, dtype in enumerate(dts))
|
||||
loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True)]) for i, dtype in enumerate(dts))
|
||||
alu = uop(uops, op, output_dtype, loads)
|
||||
out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), alu))
|
||||
out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True), alu))
|
||||
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
|
||||
buf2 = [Buffer(Device.DEFAULT, 1, dtype).allocate().copyin(np.array([a], dtype=_to_np_dtype(dtype)).data) for a,dtype in zip(vals, dts)]
|
||||
prg = _uops_to_prg([out])
|
||||
@@ -338,7 +338,7 @@ class TestLocalAccess(unittest.TestCase):
|
||||
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.float32.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
|
||||
st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.float32, (), 42.0)))
|
||||
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
|
||||
sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0)),))
|
||||
sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True),))
|
||||
self.assertEqual(_test_uops_result(dtypes.float32, uops, sres), 42)
|
||||
|
||||
# NOTE: webgpu specific, since only webgpu performs bitpacking
|
||||
@@ -348,7 +348,7 @@ class TestLocalAccess(unittest.TestCase):
|
||||
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
|
||||
st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.uint8, (), 42)))
|
||||
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
|
||||
sres = uop(uops, Ops.LOAD, dtypes.uint8, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0)),))
|
||||
sres = smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0))
|
||||
self.assertEqual(_test_uops_result(dtypes.uint8, uops, sres), 42)
|
||||
|
||||
# NOTE: webgpu specific, since only webgpu performs bitpacking
|
||||
@@ -382,7 +382,7 @@ class TestAssembly(unittest.TestCase):
|
||||
g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
|
||||
c1 = UOp(Ops.CONST, dtypes.int, (), 2)
|
||||
c2 = UOp(Ops.CONST, dtypes.int, (), 3)
|
||||
l1 = UOp(Ops.LOAD, dtypes.int, (g1.index(c1),))
|
||||
l1 = g1.index(c1)
|
||||
a1 = UOp(Ops.MUL, dtypes.int, (l1, c1))
|
||||
a2 = UOp(Ops.MUL, dtypes.int, (l1, c2))
|
||||
uops = to_uops_list([a1,a2], ren=Device[Device.DEFAULT].renderer)
|
||||
@@ -395,7 +395,7 @@ class TestAssembly(unittest.TestCase):
|
||||
for dt in (dtypes.int32, dtypes.uint32):
|
||||
g = UOp(Ops.DEFINE_GLOBAL, dt.ptr(), (), 0)
|
||||
c = UOp(Ops.CONST, dt, (), 2)
|
||||
l = UOp(Ops.LOAD, dt, (g.index(c),))
|
||||
l = g.index(c)
|
||||
a = UOp(Ops.IDIV, dt, (l, c))
|
||||
uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
|
||||
Device[Device.DEFAULT].renderer.render(uops)
|
||||
@@ -406,7 +406,7 @@ class TestAssembly(unittest.TestCase):
|
||||
def test_fast_idiv_and_mod(self):
|
||||
g = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(), (), 0)
|
||||
c = UOp(Ops.CONST, dtypes.uint, (), 3)
|
||||
l = UOp(Ops.LOAD, dtypes.uint, (g.index(c),))
|
||||
l = g.index(c)
|
||||
a = UOp(Ops.IDIV, dtypes.uint, (l, c))
|
||||
uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
|
||||
Device[Device.DEFAULT].renderer.render(uops)
|
||||
@@ -458,8 +458,7 @@ class TestAssembly(unittest.TestCase):
|
||||
def test_use_cmpeq(self):
|
||||
g = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(), (), 0)
|
||||
c = UOp(Ops.CONST, dtypes.uint, (), 7)
|
||||
l = UOp(Ops.LOAD, dtypes.uint, (g.index(c),))
|
||||
comp = l.ne(c).ne(True)
|
||||
comp = g.index(c).ne(c).ne(True)
|
||||
uops = to_uops_list([comp], ren=Device[Device.DEFAULT].renderer)
|
||||
Device[Device.DEFAULT].renderer.render(uops)
|
||||
ops = [x.op for x in uops]
|
||||
|
||||
@@ -141,8 +141,8 @@ class TestUOpsStats(unittest.TestCase):
|
||||
globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
|
||||
o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
|
||||
o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2)
|
||||
u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),))
|
||||
u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),))
|
||||
u1 = globl.index(o1)
|
||||
u2 = globl.index(o2)
|
||||
u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
|
||||
u4 = UOp(Ops.MUL, dtypes.int, (u1,u2))
|
||||
u5 = UOp(Ops.ADD, dtypes.int, (u4,u3))
|
||||
@@ -151,8 +151,8 @@ class TestUOpsStats(unittest.TestCase):
|
||||
globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
|
||||
o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
|
||||
o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2)
|
||||
u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),))
|
||||
u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),))
|
||||
u1 = globl.index(o1)
|
||||
u2 = globl.index(o2)
|
||||
u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
|
||||
u4 = UOp(Ops.MULACC, dtypes.int, (u1,u2,u3))
|
||||
uops_fma = full_rewrite(u4.sink())
|
||||
|
||||
@@ -9,13 +9,13 @@ from test.unit.test_uop_symbolic import check_uop_against_string
|
||||
|
||||
def get_gated_load_uop(valid:UOp, idx:UOp):
|
||||
return UOp(Ops.LOAD, dtypes.float, (
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0).index(idx.valid(valid)),
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0).index(idx.valid(valid), ptr=True),
|
||||
UOp.const(dtypes.float, 0.0)
|
||||
))
|
||||
|
||||
def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
|
||||
return UOp(Ops.LOAD, dtypes.float.vec(4), (
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0).index(UOp(Ops.VECTORIZE, dtypes.index.vec(2), idx).valid(valid)),
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0).index(UOp(Ops.VECTORIZE, dtypes.index.vec(2), idx).valid(valid), ptr=True),
|
||||
UOp(Ops.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
|
||||
))
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ class TestTranscendentalFunctions(unittest.TestCase):
|
||||
# TODO: Test constant input when constant folding is fixed (or maybe test both variants)
|
||||
# Load input value from a buffer to prevent constant folding
|
||||
input_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.double.ptr(), arg=1, src=())
|
||||
loaded_value = UOp.load(input_buf.index(UOp.const(dtypes.int, 0)), dtype=dtypes.double)
|
||||
loaded_value = input_buf.index(UOp.const(dtypes.int, 0))
|
||||
def eval_payne_hanek_reduction(v:float) -> tuple[float, int]:
|
||||
return tuple(eval_uop(u, [(dtypes.float64, [v])]) for u in payne_hanek_reduction(loaded_value))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user