diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 8d0754915e..11a5964a1a 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -78,17 +78,17 @@ class TestGraphRewrite(unittest.TestCase): self.assertEqual(nout.src[1].arg, 3.0) def test_consts_go_last(self): - a = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('a', 0, 1)) - b = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('b', 0, 1)) - c = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('c', 0, 1)) - d = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('d', 0, 1)) - outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const(2),a), arg=BinaryOps.ADD)] + a = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('a', 0, 1)) + b = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('b', 0, 1)) + c = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('c', 0, 1)) + d = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('d', 0, 1)) + outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const(2), a), arg=BinaryOps.ADD)] for out in outs: sink = graph_rewrite(out, constant_folder) print(sink) self.assertEqual(sink.op, UOps.ALU) self.assertEqual(sink.src[1].op, UOps.CONST) - self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 1) + self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 3) class TestUOpGraph(TestUOps): def test_add_constant_fold(self): @@ -213,13 +213,13 @@ class TestUOpGraph(TestUOps): self.assertEqual(len([x for x in g.uops if x.op is UOps.CAST]), 1) def test_depth_2_const_fold(self): - v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('tmp', 0, 1)) + v = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('tmp', 0, 1)) c2 = UOp(UOps.CONST, dtypes.int, arg=2) c4 = UOp(UOps.CONST, dtypes.int, arg=4) vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD) out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD) g = UOpGraph([out]) - self.assertEqual(len(g.uops), 3) + self.assertEqual(len(g.uops), 5) out = g.uops[-1] self.assertEqual(out.op, UOps.ALU) self.assertEqual(out.arg, BinaryOps.ADD) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index da657eac31..77eee463d3 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -209,6 +209,7 @@ constant_folder = PatternMatcher([ # NOTE: this can be wrong for loaded NaN (NOp.var('x') * 0, lambda x: x.const(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)), (NOp.var('x') - NOp.var('x'), lambda x: x.const(0)), # x-x -> 0 + (UPat(op=UOps.ALU, name='x'), lambda x: x.const(x.vmin.arg) if x.op is not UOps.CONST and x.vmin.arg == x.vmax.arg else None), # lt folding (NOp.var('x').lt(NOp.var('y')), lambda x,y: NOp.const(dtypes.bool, True) if x.vmax.arg < y.vmin.arg else NOp.const(dtypes.bool, False) if x.vmin.arg >= y.vmax.arg else None),