mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 22:38:16 -05:00
const fold left const operand for ADD and MUL (#4029)
* const fold left const operand for ADD and MUL * neg has a dtype issue
This commit is contained in:
@@ -13,6 +13,10 @@ class TestSimpleConstFolding(unittest.TestCase):
|
||||
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) + 0)
|
||||
def test_add_tensor_zero(self):
    """Adding an all-zeros tensor should be const-folded to zero kernels."""
    out = Tensor([1.0, 2, 3, 4]) + Tensor.zeros(4)
    _check_ast_count(0, out)
|
||||
def test_literal_zero_add(self):
    """A literal 0 as the left operand of + should be const-folded away."""
    out = 0 + Tensor([1.0, 2, 3, 4])
    _check_ast_count(0, out)
|
||||
def test_tensor_zero_add(self):
    """A zeros tensor as the left operand of + should be const-folded away."""
    out = Tensor.zeros(4) + Tensor([1.0, 2, 3, 4])
    _check_ast_count(0, out)
|
||||
|
||||
def test_sub_literal_zero(self):
    """Subtracting a literal 0 should be const-folded to zero kernels."""
    out = Tensor([1.0, 2, 3, 4]) - 0
    _check_ast_count(0, out)
|
||||
@@ -23,11 +27,19 @@ class TestSimpleConstFolding(unittest.TestCase):
|
||||
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) * 0)
|
||||
def test_mul_tensor_zero(self):
    """Multiplying by an all-zeros tensor should fold to a constant (no kernels)."""
    out = Tensor([1.0, 2, 3, 4]) * Tensor.zeros(4)
    _check_ast_count(0, out)
|
||||
def test_literal_zero_mul(self):
    """Literal 0 multiplications on both sides should fold to zero kernels."""
    out = 0 * Tensor([1.0, 2, 3, 4]) * 0
    _check_ast_count(0, out)
|
||||
def test_tensor_zero_mul(self):
    """A zeros tensor as the left operand of * should fold to zero kernels."""
    out = Tensor.zeros(4) * Tensor([1.0, 2, 3, 4])
    _check_ast_count(0, out)
|
||||
|
||||
def test_mul_literal_one(self):
    """Multiplying by a literal 1 is the identity and should fold away."""
    out = Tensor([1.0, 2, 3, 4]) * 1
    _check_ast_count(0, out)
|
||||
def test_mul_tensor_one(self):
    """Multiplying by an all-ones tensor is the identity and should fold away."""
    out = Tensor([1.0, 2, 3, 4]) * Tensor.ones(4)
    _check_ast_count(0, out)
|
||||
def test_literal_one_mul(self):
    """A literal 1 as the left operand of * should fold away."""
    out = 1 * Tensor([1.0, 2, 3, 4])
    _check_ast_count(0, out)
|
||||
def test_tensor_one_mul(self):
    """A ones tensor as the left operand of * should fold away."""
    out = Tensor.ones(4) * Tensor([1.0, 2, 3, 4])
    _check_ast_count(0, out)
|
||||
|
||||
def test_div_literal_one(self):
    """Dividing by a literal 1 is the identity and should fold away."""
    out = Tensor([1.0, 2, 3, 4]) / 1
    _check_ast_count(0, out)
|
||||
@@ -38,11 +50,24 @@ class TestSimpleConstFolding(unittest.TestCase):
|
||||
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** 0)
|
||||
def test_pow_tensor_zero(self):
    """x ** zeros-tensor should fold to a constant (no kernels)."""
    out = Tensor([1.0, 2, 3, 4]) ** Tensor.zeros(4)
    _check_ast_count(0, out)
|
||||
# TODO: fix pow folding with left operand = 0 or 1
@unittest.expectedFailure
def test_literal_zero_pow(self):
    """0 ** x should fold, but does not yet (see TODO above)."""
    out = 0 ** Tensor([1.0, 2, 3, 4])
    _check_ast_count(0, out)
|
||||
@unittest.expectedFailure
def test_tensor_zero_pow(self):
    """zeros-tensor ** x should fold, but does not yet."""
    out = Tensor.zeros(4) ** Tensor([1.0, 2, 3, 4])
    _check_ast_count(0, out)
|
||||
|
||||
def test_pow_literal_one(self):
    """x ** 1 is the identity and should fold away."""
    out = Tensor([1.0, 2, 3, 4]) ** 1
    _check_ast_count(0, out)
|
||||
def test_pow_tensor_one(self):
    """x ** ones-tensor is the identity and should fold away."""
    out = Tensor([1.0, 2, 3, 4]) ** Tensor.ones(4)
    _check_ast_count(0, out)
|
||||
@unittest.expectedFailure
def test_literal_one_pow(self):
    """1 ** x should fold, but does not yet."""
    out = 1 ** Tensor([1.0, 2, 3, 4])
    _check_ast_count(0, out)
|
||||
@unittest.expectedFailure
def test_tensor_one_pow(self):
    """ones-tensor ** x should fold, but does not yet."""
    out = Tensor.ones(4) ** Tensor([1.0, 2, 3, 4])
    _check_ast_count(0, out)
|
||||
|
||||
class TestMovedConstFolding(unittest.TestCase):
|
||||
def test_add_shrunk_zero(self):
|
||||
|
||||
@@ -132,14 +132,16 @@ class LazyBuffer:
|
||||
if op is UnaryOps.NEG: assert srcs[0].dtype != dtypes.bool, "UnaryOps.NEG does not accept dtype bool"
|
||||
|
||||
# const folding
|
||||
if op is BinaryOps.ADD and in_srcs[0].is_unrealized_unpadded_const() and in_srcs[0].base.arg == 0: return self
|
||||
if op is BinaryOps.SUB and in_srcs[0].is_unrealized_unpadded_const() and in_srcs[0].base.arg == 0: return self
|
||||
if op is BinaryOps.MUL and in_srcs[0].is_unrealized_unpadded_const():
|
||||
if (val := in_srcs[0].base.arg) == 1: return self
|
||||
if val == -1: return self.e(UnaryOps.NEG)
|
||||
if val == 0: return self.const(0)
|
||||
if op is BinaryOps.DIV and dtypes.is_float(self.dtype) and in_srcs[0].is_unrealized_unpadded_const() and in_srcs[0].base.arg != 0:
|
||||
return self.e(BinaryOps.MUL, self.const(1 / in_srcs[0].base.arg))
|
||||
if op in BinaryOps: x, y = self, in_srcs[0]
|
||||
if op is BinaryOps.ADD:
|
||||
if y.is_unrealized_unpadded_const() and y.base.arg == 0: return x
|
||||
if x.is_unrealized_unpadded_const() and x.base.arg == 0: return y
|
||||
if op is BinaryOps.SUB and y.is_unrealized_unpadded_const() and y.base.arg == 0: return x
|
||||
if op is BinaryOps.MUL:
|
||||
if x.is_unrealized_unpadded_const() and (val := x.base.arg) in (1, 0): return {1: y, 0: y.const(0)}[val]
|
||||
if y.is_unrealized_unpadded_const() and (val := y.base.arg) in (1, 0): return {1: x, 0: x.const(0)}[val]
|
||||
if op is BinaryOps.DIV and dtypes.is_float(x.dtype) and y.is_unrealized_unpadded_const() and y.base.arg != 0:
|
||||
return x.e(BinaryOps.MUL, x.const(1 / y.base.arg))
|
||||
|
||||
out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else srcs[-1].dtype
|
||||
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
|
||||
|
||||
Reference in New Issue
Block a user