const fold left const operand for ADD and MUL (#4029)

* const fold left const operand for ADD and MUL

* neg have dtype issue
This commit is contained in:
chenyu
2024-04-01 15:09:04 -04:00
committed by GitHub
parent 0e02d074bd
commit 379d52548d
2 changed files with 35 additions and 8 deletions

View File

@@ -13,6 +13,10 @@ class TestSimpleConstFolding(unittest.TestCase):
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) + 0)
def test_add_tensor_zero(self):
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) + Tensor.zeros(4))
def test_literal_zero_add(self):
_check_ast_count(0, 0 + Tensor([1.0, 2, 3, 4]))
def test_tensor_zero_add(self):
_check_ast_count(0, Tensor.zeros(4) + Tensor([1.0, 2, 3, 4]))
def test_sub_literal_zero(self):
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) - 0)
@@ -23,11 +27,19 @@ class TestSimpleConstFolding(unittest.TestCase):
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) * 0)
def test_mul_tensor_zero(self):
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) * Tensor.zeros(4))
def test_literal_zero_mul(self):
_check_ast_count(0, 0 * Tensor([1.0, 2, 3, 4]) * 0)
def test_tensor_zero_mul(self):
_check_ast_count(0, Tensor.zeros(4) * Tensor([1.0, 2, 3, 4]))
def test_mul_literal_one(self):
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) * 1)
def test_mul_tensor_one(self):
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) * Tensor.ones(4))
def test_literal_one_mul(self):
_check_ast_count(0, 1 * Tensor([1.0, 2, 3, 4]))
def test_tensor_one_mul(self):
_check_ast_count(0, Tensor.ones(4) * Tensor([1.0, 2, 3, 4]))
def test_div_literal_one(self):
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) / 1)
@@ -38,11 +50,24 @@ class TestSimpleConstFolding(unittest.TestCase):
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** 0)
def test_pow_tensor_zero(self):
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.zeros(4))
# TODO: fix pow folding with left operand = 0 or 1
@unittest.expectedFailure
def test_literal_zero_pow(self):
_check_ast_count(0, 0 ** Tensor([1.0, 2, 3, 4]))
@unittest.expectedFailure
def test_tensor_zero_pow(self):
_check_ast_count(0, Tensor.zeros(4) ** Tensor([1.0, 2, 3, 4]))
def test_pow_literal_one(self):
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** 1)
def test_pow_tensor_one(self):
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.ones(4))
@unittest.expectedFailure
def test_literal_one_pow(self):
_check_ast_count(0, 1 ** Tensor([1.0, 2, 3, 4]))
@unittest.expectedFailure
def test_tensor_one_pow(self):
_check_ast_count(0, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))
class TestMovedConstFolding(unittest.TestCase):
def test_add_shrunk_zero(self):

View File

@@ -132,14 +132,16 @@ class LazyBuffer:
if op is UnaryOps.NEG: assert srcs[0].dtype != dtypes.bool, "UnaryOps.NEG does not accept dtype bool"
# const folding
if op is BinaryOps.ADD and in_srcs[0].is_unrealized_unpadded_const() and in_srcs[0].base.arg == 0: return self
if op is BinaryOps.SUB and in_srcs[0].is_unrealized_unpadded_const() and in_srcs[0].base.arg == 0: return self
if op is BinaryOps.MUL and in_srcs[0].is_unrealized_unpadded_const():
if (val := in_srcs[0].base.arg) == 1: return self
if val == -1: return self.e(UnaryOps.NEG)
if val == 0: return self.const(0)
if op is BinaryOps.DIV and dtypes.is_float(self.dtype) and in_srcs[0].is_unrealized_unpadded_const() and in_srcs[0].base.arg != 0:
return self.e(BinaryOps.MUL, self.const(1 / in_srcs[0].base.arg))
if op in BinaryOps: x, y = self, in_srcs[0]
if op is BinaryOps.ADD:
if y.is_unrealized_unpadded_const() and y.base.arg == 0: return x
if x.is_unrealized_unpadded_const() and x.base.arg == 0: return y
if op is BinaryOps.SUB and y.is_unrealized_unpadded_const() and y.base.arg == 0: return x
if op is BinaryOps.MUL:
if x.is_unrealized_unpadded_const() and (val := x.base.arg) in (1, 0): return {1: y, 0: y.const(0)}[val]
if y.is_unrealized_unpadded_const() and (val := y.base.arg) in (1, 0): return {1: x, 0: x.const(0)}[val]
if op is BinaryOps.DIV and dtypes.is_float(x.dtype) and y.is_unrealized_unpadded_const() and y.base.arg != 0:
return x.e(BinaryOps.MUL, x.const(1 / y.base.arg))
out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else srcs[-1].dtype
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))