From 44a67bf783ca96ea2816809a0f03e4a95b323f0f Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sun, 10 Mar 2024 14:47:24 -0700 Subject: [PATCH] constant folding (#3675) * constant fold * bool math * fix ptx --- test/test_linearizer.py | 5 ++--- test/test_uops.py | 10 ++++++++++ tinygrad/codegen/uops.py | 12 ++++++------ 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 4d4d295824..961869eee9 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -66,7 +66,7 @@ class TestLinearizer(unittest.TestCase): b_bufs = [u.uop for u in lin.uops.uops[-2].vin[1].vin] assert a_bufs == [UOps.LOAD, UOps.CONST] - assert b_bufs == [UOps.CONST, UOps.CONST] + assert b_bufs == [] # [UOps.CONST, UOps.CONST] will be folded def test_upcast_cse(self): # when upcasting, within a subtree, there may be common expressions. @@ -126,7 +126,6 @@ class TestLinearizer(unittest.TestCase): num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU]) assert num_ops == 0, "more alu uops than needed" - @unittest.skip("constant folding not supported yet") def test_constant_fold(self): a, b = Tensor(2), Tensor(3) r = a * b @@ -216,7 +215,7 @@ class TestLinearizer(unittest.TestCase): c0 = UOp(UOps.CONST, dtypes.float, vin=(), arg=0.0) c1 = UOp(UOps.CONST, dtypes.float, vin=(), arg=1.0) assert helper_test_simplify(UOps.ALU, dtypes.float, vin=(UOp(UOps.CONST, dtypes.bool, vin=(), arg=True), c0, c1), - arg=TernaryOps.WHERE).uop == UOps.CONST + arg=TernaryOps.WHERE).arg == c0.arg def helper_realized_ast(r:Tensor): s = create_schedule([r.lazydata]) diff --git a/test/test_uops.py b/test/test_uops.py index 9bdbcfa23c..19e448a5f2 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -146,6 +146,16 @@ class TestExecALU(TestUOps): self.assertEqual(exec_alu(BinaryOps.DIV, dtypes.float32, (7.0, 3.0)), 2+(1.0/3.0)) self.assertEqual(exec_alu(BinaryOps.DIV, dtypes.float32, (7.0, -3.0)), -2-(1.0/3.0)) + def test_bool_neg(self): + self.assertEqual(exec_alu(UnaryOps.NEG, dtypes.bool, (False,)), True) + self.assertEqual(exec_alu(UnaryOps.NEG, dtypes.bool, (True,)), False) + + def test_bool_cmplt(self): + self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (False, False)), False) + self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (False, True)), True) + self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (True, False)), False) + self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (True, True)), False) + def test_overflow(self): self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 250)), 244) self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (256, 0)), 0) diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 37bcbec818..5a7fbaf2a9 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -34,14 +34,16 @@ def hook_overflow(dv, fxn): python_alu = { UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, UnaryOps.EXP2: hook_overflow(math.inf, lambda x: math.exp(x*math.log(2))), - UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: math.sin, UnaryOps.NEG: operator.neg, + UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: math.sin, + UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x, BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPEQ: operator.eq, BinaryOps.CMPLT: operator.lt, BinaryOps.MOD: operator.mod, BinaryOps.DIV: lambda x,y: int(x/y) if isinstance(x, int) else (x/y if y != 0 else math.nan), TernaryOps.WHERE: lambda x,y,z: y if x else z} truncate: Dict[DType, Callable] = { - **{dt:lambda x: x for dt in dtypes.fields().values() if dt == dtypes.bool or dtypes.is_float(dt)}, + dtypes.bool: lambda x: bool(x), + **{dt:lambda x: x for dt in dtypes.fields().values() if dtypes.is_float(dt)}, **{dt:functools.partial(lambda vv,x: x&vv, (1 << (dt.itemsize*8))-1) for dt in dtypes.fields().values() if dtypes.is_unsigned(dt)}, **{dt:functools.partial(lambda vv,aa,x: ((x+aa)&vv)-aa, (1 << (dt.itemsize*8))-1, 1 << (dt.itemsize*8-1)) \ for dt in dtypes.fields().values() if dtypes.is_int(dt) and not dtypes.is_unsigned(dt)}} @@ -89,12 +91,10 @@ class UOpGraph: if arg is BinaryOps.ADD and vin[1].uop is UOps.ALU and vin[1].arg is UnaryOps.NEG: return self.add(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable, insert_before) # constant folding - if arg is UnaryOps.NEG and vin[0].uop is UOps.CONST: - return self.add(UOps.CONST, dtype, arg=-vin[0].arg if dtype != dtypes.bool else not vin[0].arg, insert_before=insert_before) if arg is TernaryOps.WHERE and vin[1] == vin[2]: return vin[1] # a conditional with the same results either way is a noop if arg is TernaryOps.WHERE and vin[0].uop is UOps.CONST: return vin[1] if vin[0].arg else vin[2] - if arg is BinaryOps.MUL and vin[0].uop is UOps.CONST and vin[1].uop is UOps.CONST and dtype is not None and dtypes.is_float(dtype): - return self.add(UOps.CONST, dtype, arg=vin[0].arg * vin[1].arg, insert_before=insert_before) + if all(x.uop is UOps.CONST for x in vin): + return self.add(UOps.CONST, dtype, arg=exec_alu(arg, dtype, [x.arg for x in vin]), insert_before=insert_before) # zero folding for x in [0,1]: if arg is BinaryOps.ADD and vin[x].uop is UOps.CONST and vin[x].arg == 0.0: return vin[1-x]