mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
uop graph tests + type_verify cleanup (#5292)
* test_cast_alu_fold * test_double_cast_fold + these should assert
This commit is contained in:
@@ -77,7 +77,6 @@ class TestGraphRewrite(unittest.TestCase):
|
||||
self.assertEqual(nout.src[1].arg, 3.0)
|
||||
|
||||
class TestUOpGraph(TestUOps):
|
||||
# TODO: move to test.helpers
|
||||
def test_add_constant_fold(self):
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
||||
@@ -122,15 +121,35 @@ class TestUOpGraph(TestUOps):
|
||||
|
||||
def test_cast_vectorized_fold(self):
|
||||
d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=(0, True))
|
||||
idx = UOp(UOps.CONST, dtypes.int, arg=0)
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld = UOp(UOps.LOAD, dtypes.float.vec(2), (d0, idx))
|
||||
cast = UOp(UOps.CAST, dtypes.float.vec(2), (ld,))
|
||||
x = UOp(UOps.GEP, dtypes.float, (cast, ), arg=0)
|
||||
alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
|
||||
out = UOp(UOps.STORE, dtypes.float, (d0, idx, alu))
|
||||
out = UOp(UOps.STORE, None, (d0, idx, alu))
|
||||
g = UOpGraph([out])
|
||||
self.assertEqual(len([x for x in g.uops if x.op is UOps.CAST]), 0)
|
||||
|
||||
def test_cast_alu_fold(self):
|
||||
d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bool), arg=(0, True))
|
||||
d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=(1, False))
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
|
||||
alu = ld.lt(1).cast(dtypes.bool)
|
||||
out = UOp(UOps.STORE, None, (d0, idx, alu))
|
||||
g = UOpGraph([out])
|
||||
self.assertEqual(len([x for x in g.uops if x.op is UOps.CAST]), 0)
|
||||
|
||||
def test_double_cast_fold(self):
|
||||
d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=(0, True))
|
||||
d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=(1, False))
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
|
||||
alu = ld.cast(dtypes.float).cast(dtypes.float)
|
||||
out = UOp(UOps.STORE, None, (d0, idx, alu))
|
||||
g = UOpGraph([out])
|
||||
self.assertEqual(len([x for x in g.uops if x.op is UOps.CAST]), 1)
|
||||
|
||||
def test_depth_2_const_fold(self):
|
||||
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('tmp', 0, 1))
|
||||
c2 = UOp(UOps.CONST, dtypes.int, arg=2)
|
||||
|
||||
Reference in New Issue
Block a user