mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix more tests
This commit is contained in:
@@ -348,7 +348,7 @@ class TestLocalAccess(unittest.TestCase):
|
||||
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
|
||||
st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.uint8, (), 42)))
|
||||
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
|
||||
sres = uop(uops, Ops.LOAD, dtypes.uint8, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0)),))
|
||||
sres = smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0))
|
||||
self.assertEqual(_test_uops_result(dtypes.uint8, uops, sres), 42)
|
||||
|
||||
# NOTE: webgpu specific, since only webgpu performs bitpacking
|
||||
@@ -382,7 +382,7 @@ class TestAssembly(unittest.TestCase):
|
||||
g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
|
||||
c1 = UOp(Ops.CONST, dtypes.int, (), 2)
|
||||
c2 = UOp(Ops.CONST, dtypes.int, (), 3)
|
||||
l1 = UOp(Ops.LOAD, dtypes.int, (g1.index(c1),))
|
||||
l1 = g1.index(c1)
|
||||
a1 = UOp(Ops.MUL, dtypes.int, (l1, c1))
|
||||
a2 = UOp(Ops.MUL, dtypes.int, (l1, c2))
|
||||
uops = to_uops_list([a1,a2], ren=Device[Device.DEFAULT].renderer)
|
||||
@@ -395,7 +395,7 @@ class TestAssembly(unittest.TestCase):
|
||||
for dt in (dtypes.int32, dtypes.uint32):
|
||||
g = UOp(Ops.DEFINE_GLOBAL, dt.ptr(), (), 0)
|
||||
c = UOp(Ops.CONST, dt, (), 2)
|
||||
l = UOp(Ops.LOAD, dt, (g.index(c),))
|
||||
l = g.index(c)
|
||||
a = UOp(Ops.IDIV, dt, (l, c))
|
||||
uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
|
||||
Device[Device.DEFAULT].renderer.render(uops)
|
||||
@@ -406,7 +406,7 @@ class TestAssembly(unittest.TestCase):
|
||||
def test_fast_idiv_and_mod(self):
|
||||
g = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(), (), 0)
|
||||
c = UOp(Ops.CONST, dtypes.uint, (), 3)
|
||||
l = UOp(Ops.LOAD, dtypes.uint, (g.index(c),))
|
||||
l = g.index(c)
|
||||
a = UOp(Ops.IDIV, dtypes.uint, (l, c))
|
||||
uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
|
||||
Device[Device.DEFAULT].renderer.render(uops)
|
||||
@@ -458,8 +458,7 @@ class TestAssembly(unittest.TestCase):
|
||||
def test_use_cmpeq(self):
|
||||
g = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(), (), 0)
|
||||
c = UOp(Ops.CONST, dtypes.uint, (), 7)
|
||||
l = UOp(Ops.LOAD, dtypes.uint, (g.index(c),))
|
||||
comp = l.ne(c).ne(True)
|
||||
comp = g.index(c).ne(c).ne(True)
|
||||
uops = to_uops_list([comp], ren=Device[Device.DEFAULT].renderer)
|
||||
Device[Device.DEFAULT].renderer.render(uops)
|
||||
ops = [x.op for x in uops]
|
||||
|
||||
Reference in New Issue
Block a user