diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py
index ac9c1eb4d2..41c04a1915 100644
--- a/test/test_uop_graph.py
+++ b/test/test_uop_graph.py
@@ -77,7 +77,6 @@ class TestGraphRewrite(unittest.TestCase):
     self.assertEqual(nout.src[1].arg, 3.0)
 
 class TestUOpGraph(TestUOps):
-  # TODO: move to test.helpers
   def test_add_constant_fold(self):
     c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
     c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
@@ -122,15 +121,35 @@ class TestUOpGraph(TestUOps):
 
   def test_cast_vectorized_fold(self):
     d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=(0, True))
-    idx = UOp(UOps.CONST, dtypes.int, arg=0)
+    idx = UOp.const(dtypes.int, 0)
     ld = UOp(UOps.LOAD, dtypes.float.vec(2), (d0, idx))
     cast = UOp(UOps.CAST, dtypes.float.vec(2), (ld,))
     x = UOp(UOps.GEP, dtypes.float, (cast, ), arg=0)
     alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
-    out = UOp(UOps.STORE, dtypes.float, (d0, idx, alu))
+    out = UOp(UOps.STORE, None, (d0, idx, alu))
     g = UOpGraph([out])
     self.assertEqual(len([x for x in g.uops if x.op is UOps.CAST]), 0)
 
+  def test_cast_alu_fold(self):
+    d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bool), arg=(0, True))
+    d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=(1, False))
+    idx = UOp.const(dtypes.int, 0)
+    ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
+    alu = ld.lt(1).cast(dtypes.bool)
+    out = UOp(UOps.STORE, None, (d0, idx, alu))
+    g = UOpGraph([out])
+    self.assertEqual(len([x for x in g.uops if x.op is UOps.CAST]), 0)
+
+  def test_double_cast_fold(self):
+    d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=(0, True))
+    d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=(1, False))
+    idx = UOp.const(dtypes.int, 0)
+    ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
+    alu = ld.cast(dtypes.float).cast(dtypes.float)
+    out = UOp(UOps.STORE, None, (d0, idx, alu))
+    g = UOpGraph([out])
+    self.assertEqual(len([x for x in g.uops if x.op is UOps.CAST]), 1)
+
   def test_depth_2_const_fold(self):
     v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('tmp', 0, 1))
     c2 = UOp(UOps.CONST, dtypes.int, arg=2)
diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py
index 7e9818021f..ebe3fc454e 100644
--- a/tinygrad/codegen/uops.py
+++ b/tinygrad/codegen/uops.py
@@ -91,7 +91,7 @@ class UOp:
 def type_verify(uops):
   for u in uops:
     uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
-    if uop in (UOps.CONST, UOps.DEFINE_ACC):
+    if uop in {UOps.CONST, UOps.DEFINE_ACC}:
       if uop is UOps.DEFINE_ACC:
         assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
         arg = src[0].arg
@@ -99,11 +99,13 @@ def type_verify(uops):
     if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
     if uop is UOps.CAST and dtype is not None and dtype.count > 1: assert len(src) == dtype.count
     if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU: assert src[2].dtype == dtypes.bool and src[3].dtype == dtype
-    if uop is UOps.STORE and len(src) == 4: assert src[3].dtype == dtypes.bool
+    if uop is UOps.STORE:
+      assert dtype is None, f"{uop} dtype must be None, got {dtype}"
+      if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"
     if uop is UOps.ALU:
       if arg in UnaryOps:
         assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
-      elif arg in (BinaryOps.CMPLT, BinaryOps.CMPNE):
+      elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
         assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
         assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
       elif arg is BinaryOps.IDIV: