diff --git a/test/test_lazybuffer.py b/test/test_lazybuffer.py index fbef800350..489f994769 100644 --- a/test/test_lazybuffer.py +++ b/test/test_lazybuffer.py @@ -2,6 +2,7 @@ import numpy as np import unittest from tinygrad.lazy import LazyBuffer +from tinygrad.tensor import Tensor class TestLazyBuffer(unittest.TestCase): def test_fromcpu_buffer_sharing(self): @@ -22,5 +23,25 @@ class TestLazyBuffer(unittest.TestCase): for start in [0, 1]: helper(a[(slice(start, None, stride),)*ndims]) + def test_shuffle_pad_ops_cmpeq(self): + y = Tensor([1]).cat(Tensor([1]).eq(0)).numpy() + z = Tensor([1, 0]).numpy() + np.testing.assert_allclose(y, z) + + def test_shuffle_pad_ops_div(self): + y = Tensor([1]).cat(Tensor([1]).div(Tensor([2.0]))).numpy() + z = Tensor([1, 0.5]).numpy() + np.testing.assert_allclose(y, z) + + def test_shuffle_pad_ops_log(self): + y = Tensor([1]).cat(Tensor([1]).log()).numpy() + z = Tensor([1, 0]).numpy() + np.testing.assert_allclose(y, z) + + def test_shuffle_pad_ops_exp(self): + y = Tensor([1]).cat(Tensor([1]).exp()).numpy() + z = Tensor([1, np.e]).numpy() + np.testing.assert_allclose(y, z) + if __name__ == "__main__": unittest.main() diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index 735238449a..b10faf4eb2 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -283,8 +283,9 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]: assert isinstance(bx.op.op, MovementOps) mops.append((bx.op.op, bx.op.arg)) bx = cast(LazyBuffer, bx.op.src[0]) - # NOTE: can't push pads with a div - if not bx.realized and bx.optype == BinaryOps and len(bx.children) <= 1 and len(mops) and (all(x[0] != MovementOps.PAD for x in mops) or all(x.op != BinaryOps.DIV for x in bx.op.get_lazyops())): + # NOTE: can't push pads past anything where f(0, 0) != 0 or f(0) != 0 + unsafe_pad_ops = {BinaryOps.DIV, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP} + if not bx.realized and bx.optype == BinaryOps and len(bx.children) <= 1 and len(mops) and (all(x[0] != MovementOps.PAD for x in mops) or all(x.op not in unsafe_pad_ops for x in bx.op.get_lazyops())): new_srcs.append(bx.op.replace_with_movement_ops(mops[::-1])) else: new_srcs.append(x)