fix more tests

This commit is contained in:
George Hotz
2025-10-30 10:32:22 +08:00
parent adc15c7497
commit 89d8b79196

View File

@@ -348,7 +348,7 @@ class TestLocalAccess(unittest.TestCase):
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.uint8, (), 42)))
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
sres = uop(uops, Ops.LOAD, dtypes.uint8, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0)),))
sres = smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0))
self.assertEqual(_test_uops_result(dtypes.uint8, uops, sres), 42)
# NOTE: webgpu specific, since only webgpu performs bitpacking
@@ -382,7 +382,7 @@ class TestAssembly(unittest.TestCase):
g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
c1 = UOp(Ops.CONST, dtypes.int, (), 2)
c2 = UOp(Ops.CONST, dtypes.int, (), 3)
l1 = UOp(Ops.LOAD, dtypes.int, (g1.index(c1),))
l1 = g1.index(c1)
a1 = UOp(Ops.MUL, dtypes.int, (l1, c1))
a2 = UOp(Ops.MUL, dtypes.int, (l1, c2))
uops = to_uops_list([a1,a2], ren=Device[Device.DEFAULT].renderer)
@@ -395,7 +395,7 @@ class TestAssembly(unittest.TestCase):
for dt in (dtypes.int32, dtypes.uint32):
g = UOp(Ops.DEFINE_GLOBAL, dt.ptr(), (), 0)
c = UOp(Ops.CONST, dt, (), 2)
l = UOp(Ops.LOAD, dt, (g.index(c),))
l = g.index(c)
a = UOp(Ops.IDIV, dt, (l, c))
uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render(uops)
@@ -406,7 +406,7 @@ class TestAssembly(unittest.TestCase):
def test_fast_idiv_and_mod(self):
g = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(), (), 0)
c = UOp(Ops.CONST, dtypes.uint, (), 3)
l = UOp(Ops.LOAD, dtypes.uint, (g.index(c),))
l = g.index(c)
a = UOp(Ops.IDIV, dtypes.uint, (l, c))
uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render(uops)
@@ -458,8 +458,7 @@ class TestAssembly(unittest.TestCase):
def test_use_cmpeq(self):
g = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(), (), 0)
c = UOp(Ops.CONST, dtypes.uint, (), 7)
l = UOp(Ops.LOAD, dtypes.uint, (g.index(c),))
comp = l.ne(c).ne(True)
comp = g.index(c).ne(c).ne(True)
uops = to_uops_list([comp], ren=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render(uops)
ops = [x.op for x in uops]