diff --git a/test/test_const_folding.py b/test/test_const_folding.py
index 9f3c17faaf..4c57bf01b3 100644
--- a/test/test_const_folding.py
+++ b/test/test_const_folding.py
@@ -13,6 +13,10 @@ class TestSimpleConstFolding(unittest.TestCase):
     _check_ast_count(0, Tensor([1.0, 2, 3, 4]) + 0)
   def test_add_tensor_zero(self):
     _check_ast_count(0, Tensor([1.0, 2, 3, 4]) + Tensor.zeros(4))
+  def test_literal_zero_add(self):
+    _check_ast_count(0, 0 + Tensor([1.0, 2, 3, 4]))
+  def test_tensor_zero_add(self):
+    _check_ast_count(0, Tensor.zeros(4) + Tensor([1.0, 2, 3, 4]))
 
   def test_sub_literal_zero(self):
     _check_ast_count(0, Tensor([1.0, 2, 3, 4]) - 0)
@@ -23,11 +27,19 @@ class TestSimpleConstFolding(unittest.TestCase):
     _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * 0)
   def test_mul_tensor_zero(self):
     _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * Tensor.zeros(4))
+  def test_literal_zero_mul(self):
+    _check_ast_count(0, 0 * Tensor([1.0, 2, 3, 4]) * 0)
+  def test_tensor_zero_mul(self):
+    _check_ast_count(0, Tensor.zeros(4) * Tensor([1.0, 2, 3, 4]))
 
   def test_mul_literal_one(self):
     _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * 1)
   def test_mul_tensor_one(self):
     _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * Tensor.ones(4))
+  def test_literal_one_mul(self):
+    _check_ast_count(0, 1 * Tensor([1.0, 2, 3, 4]))
+  def test_tensor_one_mul(self):
+    _check_ast_count(0, Tensor.ones(4) * Tensor([1.0, 2, 3, 4]))
 
   def test_div_literal_one(self):
     _check_ast_count(0, Tensor([1.0, 2, 3, 4]) / 1)
@@ -38,11 +50,24 @@ class TestSimpleConstFolding(unittest.TestCase):
     _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** 0)
   def test_pow_tensor_zero(self):
     _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.zeros(4))
+  # TODO: fix pow folding with left operand = 0 or 1
+  @unittest.expectedFailure
+  def test_literal_zero_pow(self):
+    _check_ast_count(0, 0 ** Tensor([1.0, 2, 3, 4]))
+  @unittest.expectedFailure
+  def test_tensor_zero_pow(self):
+    _check_ast_count(0, Tensor.zeros(4) ** Tensor([1.0, 2, 3, 4]))
 
   def test_pow_literal_one(self):
     _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** 1)
   def test_pow_tensor_one(self):
     _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.ones(4))
+  @unittest.expectedFailure
+  def test_literal_one_pow(self):
+    _check_ast_count(0, 1 ** Tensor([1.0, 2, 3, 4]))
+  @unittest.expectedFailure
+  def test_tensor_one_pow(self):
+    _check_ast_count(0, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))
 
 class TestMovedConstFolding(unittest.TestCase):
   def test_add_shrunk_zero(self):
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 211aeaf150..df3a138bb6 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -132,14 +132,16 @@ class LazyBuffer:
 
     if op is UnaryOps.NEG: assert srcs[0].dtype != dtypes.bool, "UnaryOps.NEG does not accept dtype bool"
     # const folding
-    if op is BinaryOps.ADD and in_srcs[0].is_unrealized_unpadded_const() and in_srcs[0].base.arg == 0: return self
-    if op is BinaryOps.SUB and in_srcs[0].is_unrealized_unpadded_const() and in_srcs[0].base.arg == 0: return self
-    if op is BinaryOps.MUL and in_srcs[0].is_unrealized_unpadded_const():
-      if (val := in_srcs[0].base.arg) == 1: return self
-      if val == -1: return self.e(UnaryOps.NEG)
-      if val == 0: return self.const(0)
-    if op is BinaryOps.DIV and dtypes.is_float(self.dtype) and in_srcs[0].is_unrealized_unpadded_const() and in_srcs[0].base.arg != 0:
-      return self.e(BinaryOps.MUL, self.const(1 / in_srcs[0].base.arg))
+    if op in BinaryOps: x, y = self, in_srcs[0]
+    if op is BinaryOps.ADD:
+      if y.is_unrealized_unpadded_const() and y.base.arg == 0: return x
+      if x.is_unrealized_unpadded_const() and x.base.arg == 0: return y
+    if op is BinaryOps.SUB and y.is_unrealized_unpadded_const() and y.base.arg == 0: return x
+    if op is BinaryOps.MUL:
+      if x.is_unrealized_unpadded_const() and (val := x.base.arg) in (1, 0): return {1: y, 0: y.const(0)}[val]
+      if y.is_unrealized_unpadded_const() and (val := y.base.arg) in (1, 0): return {1: x, 0: x.const(0)}[val]
+    if op is BinaryOps.DIV and dtypes.is_float(x.dtype) and y.is_unrealized_unpadded_const() and y.base.arg != 0:
+      return x.e(BinaryOps.MUL, x.const(1 / y.base.arg))
     out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else srcs[-1].dtype
     return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
 