add loads at the end (#12988)

* add loads at the end

* simpler

* late load

* tests passing

* fix matvec

* spec test passes

* fix where on load

* fix abs2

* fix more tests
This commit is contained in:
George Hotz
2025-10-30 10:42:19 +08:00
committed by GitHub
parent 4b001ec723
commit 2da02f1ae1
21 changed files with 120 additions and 101 deletions

View File

@@ -194,6 +194,7 @@ class TestDTypeALU(unittest.TestCase):
strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
ht.int32, strat.sampled_from(binary_operations), strat.sampled_from(integer_binary_operations))
@unittest.skipIf(Device.DEFAULT == "PYTHON", "TODO: fix cast inf to int32 in PYTHON")
@unittest.skip("broken on Mac")
def test_float_midcast_int32(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.float32, dtypes.int32)
@unittest.skip("broken. TODO: fix it")

View File

@@ -16,12 +16,12 @@ class TestLinearizerFailure(unittest.TestCase):
c2 = UOp.range(UOp.const(dtypes.index, 784), 1, AxisType.GLOBAL)
c3 = UOp.range(UOp.const(dtypes.index, 10), 3, AxisType.GLOBAL)
c4 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(512), arg=1, src=())
c5 = c4.index(c1.valid(UOp.const(dtypes.bool, True))).load()
c5 = c4.index(c1.valid(UOp.const(dtypes.bool, True)))
c6 = UOp.range(UOp.const(dtypes.index, 6000), 1004, AxisType.REDUCE)
c7 = UOp.range(UOp.const(dtypes.index, 3750), 2006, AxisType.REDUCE)
c8 = UOp.range(UOp.const(dtypes.index, 16), 2007, AxisType.GROUP_REDUCE)
c9 = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(47040000), arg=2, src=())
c10 = c9.index((((c3*UOp.const(dtypes.index, 4704000))+c2)+(c6*UOp.const(dtypes.index, 784))).valid(UOp.const(dtypes.bool, True))).load()
c10 = c9.index((((c3*UOp.const(dtypes.index, 4704000))+c2)+(c6*UOp.const(dtypes.index, 784))).valid(UOp.const(dtypes.bool, True)))
c11 = c5.alu(Ops.CMPNE, ((((c3*UOp.const(dtypes.index, 6000))+c6)+((c7*UOp.const(dtypes.index, 16))+c8)).alu(Ops.CMPLT, UOp.const(dtypes.index, 59999)).where(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)).reduce(c7, c8, arg=Ops.ADD)+UOp.const(dtypes.int, -1))).where(UOp.const(dtypes.uchar, 0), c10).reduce(c6, arg=Ops.ADD)
c12 = c0.index((((c1*UOp.const(dtypes.index, 7840))+(c2*UOp.const(dtypes.index, 10)))+c3).valid(UOp.const(dtypes.bool, True))).store(c11).end(c1, c2, c3)
ast = c12.sink(arg=KernelInfo(name='test', axis_types=(), dont_use_locals=False, applied_opts=(Opt(op=OptOps.GROUP, axis=1, arg=16),), opts_to_apply=None))

View File

@@ -12,9 +12,9 @@ class TestLinearizerFailures(unittest.TestCase):
c3 = ((c1*UOp.const(dtypes.index, 32))+c2)
c4 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(163840), arg=1, src=())
c5 = UOp.range(UOp.const(dtypes.index, 2560), 0, AxisType.REDUCE)
c6 = c4.index(((((((c5//UOp.const(dtypes.index, 8))%UOp.const(dtypes.index, 8))*UOp.const(dtypes.index, 8))+(c5%UOp.const(dtypes.index, 8)))+(((c2*UOp.const(dtypes.index, 40))+(c5//UOp.const(dtypes.index, 64)))*UOp.const(dtypes.index, 64)))+(c1*UOp.const(dtypes.index, 81920)))).load()
c6 = c4.index(((((((c5//UOp.const(dtypes.index, 8))%UOp.const(dtypes.index, 8))*UOp.const(dtypes.index, 8))+(c5%UOp.const(dtypes.index, 8)))+(((c2*UOp.const(dtypes.index, 40))+(c5//UOp.const(dtypes.index, 64)))*UOp.const(dtypes.index, 64)))+(c1*UOp.const(dtypes.index, 81920))))
c7 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=2, src=())
c8 = c7.index(c3).load()
c8 = c7.index(c3)
c9 = ((((c6+(c8*UOp.const(dtypes.float, -1.0)))*(c6+(c8*UOp.const(dtypes.float, -1.0)))).reduce(c5, arg=Ops.ADD)*UOp.const(dtypes.float, 0.000390625))+UOp.const(dtypes.float, 1e-05)).sqrt().reciprocal()
c10 = c0.index(c3).store(c9).end(c1, c2)
ast = c10.sink()

View File

@@ -199,6 +199,7 @@ class TestProfiler(unittest.TestCase):
#self.assertLess(e1.st, e2.st)
#self.assertGreater(e1.en-e1.st, e2.en-e2.st)
@unittest.skipIf(not CI, "this test is flaky locally")
@unittest.skipUnless(Device[Device.DEFAULT].graph is not None, "graph support required")
def test_graph(self):
from test.test_graph import helper_alloc_rawbuffer, helper_exec_op, helper_test_graphs

View File

@@ -34,7 +34,7 @@ def _setup_and_test_alu(alu_op:Ops, input_val:ConstType, *alu_src_uops:UOp):
a = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 0)
b = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 1)
idx = UOp.const(dtypes.int, 0)
ld = UOp(Ops.LOAD, dtype, (b.index(idx),))
ld = b.index(idx)
alu = ld.alu(alu_op, *alu_src_uops)
store = UOp.store(a.index(idx), alu)
sink = UOp(Ops.SINK, dtypes.void, (store,))

View File

@@ -2163,8 +2163,8 @@ class TestCopyFolding(unittest.TestCase):
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
def test_permute_on_disk_contiguous(self):
with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer())
a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
with open(temp('dt_arange_4_permute_contig'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer())
a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute_contig')}")
b = a.reshape(2, 2).permute(1, 0).contiguous().to("CPU")
b.realize()
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])

View File

@@ -376,7 +376,7 @@ class TestUOpGraph(unittest.TestCase):
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0)
d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
idx = UOp.const(dtypes.int, 0)
ld = UOp(Ops.LOAD, dtypes.int, (d1.index(idx),))
ld = d1.index(idx)
alu = (ld<1).cast(dtypes.bool)
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu))
uops = to_uops_list([out])
@@ -386,7 +386,7 @@ class TestUOpGraph(unittest.TestCase):
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
idx = UOp.const(dtypes.int, 0)
ld = UOp(Ops.LOAD, dtypes.int, (d1.index(idx),))
ld = d1.index(idx)
alu = ld.cast(dtypes.float).cast(dtypes.float)
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu))
uops = to_uops_list([out])
@@ -408,7 +408,7 @@ class TestUOpGraph(unittest.TestCase):
def test_bitcast_to_same_dtype_fold(self):
for dt in dtypes.ints + dtypes.floats + (dtypes.bool,):
d0 = UOp(Ops.DEFINE_GLOBAL, dt.ptr(), arg=0)
v = UOp(Ops.LOAD, dt, (d0.index(UOp.const(dtypes.int, 0)),))
v = d0.index(UOp.const(dtypes.int, 0))
uops = to_uops_list([v.bitcast(dt)])
self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST]), 0, f"dtype = {dt}")
@@ -420,7 +420,7 @@ class TestUOpGraph(unittest.TestCase):
def test_where_on_gated_load_fold(self):
ridx0 = UOp.range(100, 0)
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0)
ld = d0.index(ridx0.valid(ridx0<50)).load()
ld = d0.index(ridx0.valid(ridx0<50))
w = (ridx0<50).where(ld, 5)
uops = to_uops_list([w])
for u in uops:
@@ -430,7 +430,7 @@ class TestUOpGraph(unittest.TestCase):
def test_where_on_gated_load_folds_swapped_branches(self):
ridx0 = UOp.range(100, 0)
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0)
ld = d0.index(ridx0.valid((ridx0<50).logical_not())).load()
ld = d0.index(ridx0.valid((ridx0<50).logical_not()))
w = (ridx0<50).where(5, ld)
uops = to_uops_list([w])
for u in uops:
@@ -441,7 +441,7 @@ class TestUOpGraph(unittest.TestCase):
ridx0 = UOp.range(100, 0)
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
gate_idx = ridx0.valid((ridx0<50))
ld = d0.index(gate_idx).load().cast(dtypes.float)
ld = d0.index(gate_idx).cast(dtypes.float)
w = (ridx0<50).where(ld, 5.0)
uops = to_uops_list([w])
for u in uops:
@@ -467,11 +467,11 @@ class TestUOpGraph(unittest.TestCase):
c1 = UOp.range(UOp.const(dtypes.index, 512), 1, AxisType.LOOP)
c2 = UOp.range(UOp.const(dtypes.index, 250), 2, AxisType.LOOP)
c3 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(512), arg=1, src=())
c4 = c3.index(c1).load()
c4 = c3.index(c1)
c5 = UOp.range(UOp.const(dtypes.index, 240), 0, AxisType.REDUCE)
c6 = ((c2*UOp.const(dtypes.index, 240))+c5)
c7 = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(60000), arg=2, src=())
c8 = c7.index(c6).load()
c8 = c7.index(c6)
c9 = ((c4<0).where((c4+60000), c4)!=c6.cast(dtypes.int)).where(0, c8.cast(dtypes.uint).cast(dtypes.uchar)).reduce(c5, arg=Ops.ADD)
c10 = c0.index(((c1*UOp.const(dtypes.index, 250))+c2)).store(c9).end(c1, c2)
uops = to_uops_list([c10])
@@ -481,25 +481,25 @@ class TestUOpGraph(unittest.TestCase):
def test_in_out_of_bounds_access(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 0)),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 0), ptr=True),))
to_uops_list([ld0])
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 15)),))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 15), ptr=True),))
to_uops_list([ld1])
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 7)),))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 7), ptr=True),))
to_uops_list([ld1])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 42)),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 42), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access_symbolic(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 1, 10)),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 1, 10), ptr=True),))
to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 15)),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 15), ptr=True),))
to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 20)),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 20), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access_gated_store(self):
@@ -531,7 +531,7 @@ class TestUOpGraph(unittest.TestCase):
if_barrier = UOp(Ops.IF, dtypes.void, (gate, barrier))
# Load from local memory (after the IF/barrier)
local_load = UOp(Ops.LOAD, dtypes.uint, (sbuf.index(lidx), if_barrier))
local_load = UOp(Ops.LOAD, dtypes.uint, (sbuf.index(lidx, ptr=True), if_barrier))
# Store to global memory
global_store = UOp(Ops.STORE, dtypes.void, (gbuf.index(gidx), local_load))
@@ -542,18 +542,18 @@ class TestUOpGraph(unittest.TestCase):
ridx = UOp.range(20, 0)
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid((0<=i)&(i<16))),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid((0<=i)&(i<16)), ptr=True),))
to_uops_list([ld0])
glblfloat = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(20), (), 0)
ldfloat = UOp(Ops.LOAD, dtypes.float, (glblfloat.index(ridx),))
i = (ldfloat+3.14).cast(dtypes.int)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16))),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16)), ptr=True),))
def test_load_cast_to_bool(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
ridx = UOp.range(20, 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx.valid(ridx.cast(dtypes.bool).logical_not())),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx.valid(ridx.cast(dtypes.bool).logical_not()), ptr=True),))
to_uops_list([ld0])
@unittest.skip("Bool load is not supported yet")
@@ -562,36 +562,36 @@ class TestUOpGraph(unittest.TestCase):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
ridx = UOp.range(20, 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask),)))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask), ptr=True)))
to_uops_list([ld0])
def test_out_of_bounds_off_by_one_access(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 16)),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 16), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_bounds_access_with_mask(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid((5<gidx0)&(gidx0<16))),))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<16)),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid((5<gidx0)&(gidx0<16)), ptr=True),))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<16), ptr=True),))
to_uops_list([ld0, ld1])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<17)),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<17), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access_symbolic_mask(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
i = Variable("i", 1, 80)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<10)),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<10), ptr=True),))
to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<15)),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<15), ptr=True),))
to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<20)),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<20), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access_index_load(self):
@@ -599,11 +599,11 @@ class TestUOpGraph(unittest.TestCase):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<8)),)).cast(dtypes.index)
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<32))),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<8), ptr=True),)).cast(dtypes.index)
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<32)), ptr=True),))
to_uops_list([ld1])
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<64))),))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<64)), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld1])
def test_bounds_with_loaded_bool(self):
@@ -611,8 +611,8 @@ class TestUOpGraph(unittest.TestCase):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(8), (), 0)
gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 16),), "gidx0")
ld0 = glbl0.index(gidx0).load()
ld1 = glbl1.index(gidx0.valid(ld0)).load()
ld0 = glbl0.index(gidx0, ptr=True).load()
ld1 = glbl1.index(gidx0.valid(ld0), ptr=True).load()
with self.assertRaises(RuntimeError): to_uops_list([ld1])
def test_fold_gated_load(self):
@@ -620,38 +620,38 @@ class TestUOpGraph(unittest.TestCase):
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
glbl2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
idx = UOp.const(dtypes.int, 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(UOp.invalid()),))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2.index(idx.valid(UOp.const(dtypes.bool, True))),))
ld0 = glbl1.index(UOp.invalid())
ld1 = glbl2.index(idx.valid(UOp.const(dtypes.bool, True)))
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
ld0 = uops[-1].src[-1]
# the gate and invalid value are deleted from ld1
self.assertEqual(ld0, UOp.load(glbl2.index(idx), dtype=dtypes.int))
self.assertEqual(ld0, UOp.load(glbl2.index(idx, ptr=True), dtype=dtypes.int))
def test_fold_gated_load_local(self):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=18, addrspace=AddrSpace.LOCAL), (), "temp")
lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0")
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int)))
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx, ptr=True), glbl0.index(lidx, ptr=True).load()))
barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
ld0 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(UOp.invalid()),))
ld1 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index((lidx+2).valid(UOp.const(dtypes.bool, True))),))
ld0 = smem.after(barrier).index(UOp.invalid())
ld1 = smem.after(barrier).index((lidx+2).valid(UOp.const(dtypes.bool, True)))
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))])
ld0 = uops[-1].src[-1]
# the gate and invalid value are deleted from ld1
self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2))
self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2, ptr=True))
def test_fold_gated_store(self):
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
idx0 = UOp.const(dtypes.int, 0)
idx1 = UOp.const(dtypes.int, 0)
val = UOp.const(dtypes.int, 42)
st0 = glbl.index(UOp.invalid()).store(val)
st1 = glbl.index(idx0.valid(UOp.const(dtypes.bool, True))).store(val)
st0 = glbl.index(UOp.invalid(), ptr=True).store(val)
st1 = glbl.index(idx0.valid(UOp.const(dtypes.bool, True)), ptr=True).store(val)
uops = to_uops_list([st0, st1])
# only the second store happens
self.assertEqual(len(uops), 5)
self.assertEqual(uops[-1], glbl.index(idx1).store(val))
self.assertEqual(uops[-1], glbl.index(idx1, ptr=True).store(val))
@unittest.skip("this is a uop type error")
def test_asserts_bad_gate(self):

View File

@@ -39,9 +39,9 @@ def _test_single_value(vals, op, dts):
output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1]
buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
buf_loads = [uop(uops, Ops.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)]
loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0))]) for i, dtype in enumerate(dts))
loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True)]) for i, dtype in enumerate(dts))
alu = uop(uops, op, output_dtype, loads)
out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), alu))
out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True), alu))
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
buf2 = [Buffer(Device.DEFAULT, 1, dtype).allocate().copyin(np.array([a], dtype=_to_np_dtype(dtype)).data) for a,dtype in zip(vals, dts)]
prg = _uops_to_prg([out])
@@ -338,7 +338,7 @@ class TestLocalAccess(unittest.TestCase):
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.float32.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.float32, (), 42.0)))
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0)),))
sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True),))
self.assertEqual(_test_uops_result(dtypes.float32, uops, sres), 42)
# NOTE: webgpu specific, since only webgpu performs bitpacking
@@ -348,7 +348,7 @@ class TestLocalAccess(unittest.TestCase):
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.uint8, (), 42)))
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
sres = uop(uops, Ops.LOAD, dtypes.uint8, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0)),))
sres = smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0))
self.assertEqual(_test_uops_result(dtypes.uint8, uops, sres), 42)
# NOTE: webgpu specific, since only webgpu performs bitpacking
@@ -382,7 +382,7 @@ class TestAssembly(unittest.TestCase):
g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
c1 = UOp(Ops.CONST, dtypes.int, (), 2)
c2 = UOp(Ops.CONST, dtypes.int, (), 3)
l1 = UOp(Ops.LOAD, dtypes.int, (g1.index(c1),))
l1 = g1.index(c1)
a1 = UOp(Ops.MUL, dtypes.int, (l1, c1))
a2 = UOp(Ops.MUL, dtypes.int, (l1, c2))
uops = to_uops_list([a1,a2], ren=Device[Device.DEFAULT].renderer)
@@ -395,7 +395,7 @@ class TestAssembly(unittest.TestCase):
for dt in (dtypes.int32, dtypes.uint32):
g = UOp(Ops.DEFINE_GLOBAL, dt.ptr(), (), 0)
c = UOp(Ops.CONST, dt, (), 2)
l = UOp(Ops.LOAD, dt, (g.index(c),))
l = g.index(c)
a = UOp(Ops.IDIV, dt, (l, c))
uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render(uops)
@@ -406,7 +406,7 @@ class TestAssembly(unittest.TestCase):
def test_fast_idiv_and_mod(self):
g = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(), (), 0)
c = UOp(Ops.CONST, dtypes.uint, (), 3)
l = UOp(Ops.LOAD, dtypes.uint, (g.index(c),))
l = g.index(c)
a = UOp(Ops.IDIV, dtypes.uint, (l, c))
uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render(uops)
@@ -458,8 +458,7 @@ class TestAssembly(unittest.TestCase):
def test_use_cmpeq(self):
g = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(), (), 0)
c = UOp(Ops.CONST, dtypes.uint, (), 7)
l = UOp(Ops.LOAD, dtypes.uint, (g.index(c),))
comp = l.ne(c).ne(True)
comp = g.index(c).ne(c).ne(True)
uops = to_uops_list([comp], ren=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render(uops)
ops = [x.op for x in uops]

View File

@@ -141,8 +141,8 @@ class TestUOpsStats(unittest.TestCase):
globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2)
u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),))
u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),))
u1 = globl.index(o1)
u2 = globl.index(o2)
u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
u4 = UOp(Ops.MUL, dtypes.int, (u1,u2))
u5 = UOp(Ops.ADD, dtypes.int, (u4,u3))
@@ -151,8 +151,8 @@ class TestUOpsStats(unittest.TestCase):
globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2)
u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),))
u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),))
u1 = globl.index(o1)
u2 = globl.index(o2)
u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
u4 = UOp(Ops.MULACC, dtypes.int, (u1,u2,u3))
uops_fma = full_rewrite(u4.sink())

View File

@@ -9,13 +9,13 @@ from test.unit.test_uop_symbolic import check_uop_against_string
def get_gated_load_uop(valid:UOp, idx:UOp):
return UOp(Ops.LOAD, dtypes.float, (
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0).index(idx.valid(valid)),
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0).index(idx.valid(valid), ptr=True),
UOp.const(dtypes.float, 0.0)
))
def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
return UOp(Ops.LOAD, dtypes.float.vec(4), (
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0).index(UOp(Ops.VECTORIZE, dtypes.index.vec(2), idx).valid(valid)),
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0).index(UOp(Ops.VECTORIZE, dtypes.index.vec(2), idx).valid(valid), ptr=True),
UOp(Ops.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
))

View File

@@ -11,7 +11,7 @@ class TestTranscendentalFunctions(unittest.TestCase):
# TODO: Test constant input when constant folding is fixed (or maybe test both variants)
# Load input value from a buffer to prevent constant folding
input_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.double.ptr(), arg=1, src=())
loaded_value = UOp.load(input_buf.index(UOp.const(dtypes.int, 0)), dtype=dtypes.double)
loaded_value = input_buf.index(UOp.const(dtypes.int, 0))
def eval_payne_hanek_reduction(v:float) -> tuple[float, int]:
return tuple(eval_uop(u, [(dtypes.float64, [v])]) for u in payne_hanek_reduction(loaded_value))