mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
lazy folding: mul -1 is neg, and neg neg is noop (#4472)
This commit is contained in:
@@ -23,6 +23,11 @@ class TestUnaryOpsConstFolding(unittest.TestCase):
|
||||
_check_ast_count(0, Tensor.ones(4).cast(dtypes.int16))
|
||||
_check_ast_count(0, Tensor.full(4, fill_value=-1).cast(dtypes.uint16))
|
||||
|
||||
def test_neg_folding(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3]).mul(-1).neg())
|
||||
_check_ast_count(0, Tensor([1, 2, 3]).neg().mul(-1))
|
||||
_check_ast_count(0, Tensor([1, 2, 3]).neg().neg())
|
||||
|
||||
class TestBinaryOpsConstFolding(unittest.TestCase):
|
||||
def test_add_literal_zero(self):
|
||||
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) + 0)
|
||||
@@ -56,6 +61,13 @@ class TestBinaryOpsConstFolding(unittest.TestCase):
|
||||
def test_tensor_one_mul(self):
|
||||
_check_ast_count(0, Tensor.ones(4) * Tensor([1.0, 2, 3, 4]))
|
||||
|
||||
def test_bool_tensor_mul_bool(self):
|
||||
_check_ast_count(0, Tensor([True, False]) * True)
|
||||
_check_ast_count(0, Tensor([True, False]) * False)
|
||||
def test_bool_mul_bool_tensor(self):
|
||||
_check_ast_count(0, True * Tensor([True, False]))
|
||||
_check_ast_count(0, False * Tensor([True, False]))
|
||||
|
||||
def test_div_literal_one(self):
|
||||
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) / 1)
|
||||
def test_div_tensor_one(self):
|
||||
|
||||
@@ -149,14 +149,17 @@ class LazyBuffer:
|
||||
# const folding
|
||||
if op in python_alu and all(s.is_unrealized_unmasked_const() for s in srcs):
|
||||
return self.cast(out_dtype).const(exec_alu(op, out_dtype, [s.base.arg for s in srcs]))
|
||||
if op is UnaryOps.NEG and self.base.op is UnaryOps.NEG: return self.base.srcs[0]
|
||||
if op in BinaryOps: x, y = self, in_srcs[0]
|
||||
if op is BinaryOps.ADD:
|
||||
if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
|
||||
if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y
|
||||
if op is BinaryOps.SUB and y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
|
||||
if op is BinaryOps.MUL:
|
||||
if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0): return {1: y, 0: y.const(0)}[val]
|
||||
if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0): return {1: x, 0: x.const(0)}[val]
|
||||
if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0, -1):
|
||||
return y if val == 1 else y.const(0) if val == 0 else y.e(UnaryOps.NEG)
|
||||
if y.is_unrealized_unmasked_const() and (val := float(y.base.arg)) in (1, 0, -1):
|
||||
return x if val == 1 else x.const(0) if val == 0 else x.e(UnaryOps.NEG)
|
||||
if op is BinaryOps.DIV and dtypes.is_float(x.dtype) and y.is_unrealized_unmasked_const() and y.base.arg != 0:
|
||||
return x.e(BinaryOps.MUL, x.const(1 / y.base.arg))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user