Fix LazyBuffer SHUFFLE_PAD_OPS to prevent invalid pad movement (#1223)

In addition to div, any op that produces a non-zero output from zero inputs
(e.g. cmpeq, log2, exp2, recip) must be guarded, so that a pad is never pushed past it.
This commit is contained in:
Francis Lam
2023-07-11 15:30:35 -07:00
committed by GitHub
parent f75de602df
commit df86672bd4
2 changed files with 24 additions and 2 deletions

View File

@@ -2,6 +2,7 @@
import numpy as np
import unittest
from tinygrad.lazy import LazyBuffer
from tinygrad.tensor import Tensor
class TestLazyBuffer(unittest.TestCase):
def test_fromcpu_buffer_sharing(self):
@@ -22,5 +23,25 @@ class TestLazyBuffer(unittest.TestCase):
for start in [0, 1]:
helper(a[(slice(start, None, stride),)*ndims])
def test_shuffle_pad_ops_cmpeq(self):
  """Regression check for the cmpeq case of SHUFFLE_PAD_OPS pad movement.

  Concatenates [1] with the result of comparing [1] against 0 and expects
  the realized output to be [1, 0].
  """
  head = Tensor([1])
  tail = Tensor([1]).eq(0)  # 1 == 0 -> 0
  actual = head.cat(tail).numpy()
  expected = Tensor([1, 0]).numpy()
  np.testing.assert_allclose(actual, expected)
def test_shuffle_pad_ops_div(self):
  """Regression check for the div case of SHUFFLE_PAD_OPS pad movement.

  Concatenates [1] with [1] / [2.0] and expects the realized output to be
  [1, 0.5].
  """
  head = Tensor([1])
  tail = Tensor([1]).div(Tensor([2.0]))  # 1 / 2.0 -> 0.5
  actual = head.cat(tail).numpy()
  expected = Tensor([1, 0.5]).numpy()
  np.testing.assert_allclose(actual, expected)
def test_shuffle_pad_ops_log(self):
  """Regression check for the log case of SHUFFLE_PAD_OPS pad movement.

  Concatenates [1] with log([1]) and expects the realized output to be
  [1, 0].
  """
  head = Tensor([1])
  tail = Tensor([1]).log()  # log(1) -> 0
  actual = head.cat(tail).numpy()
  expected = Tensor([1, 0]).numpy()
  np.testing.assert_allclose(actual, expected)
def test_shuffle_pad_ops_exp(self):
  """Regression check for the exp case of SHUFFLE_PAD_OPS pad movement.

  Concatenates [1] with exp([1]) and expects the realized output to be
  [1, e].
  """
  head = Tensor([1])
  tail = Tensor([1]).exp()  # exp(1) -> e
  actual = head.cat(tail).numpy()
  expected = Tensor([1, np.e]).numpy()
  np.testing.assert_allclose(actual, expected)
# Run the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
unittest.main()

View File

@@ -283,8 +283,9 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
assert isinstance(bx.op.op, MovementOps)
mops.append((bx.op.op, bx.op.arg))
bx = cast(LazyBuffer, bx.op.src[0])
# NOTE: can't push pads with a div
if not bx.realized and bx.optype == BinaryOps and len(bx.children) <= 1 and len(mops) and (all(x[0] != MovementOps.PAD for x in mops) or all(x.op != BinaryOps.DIV for x in bx.op.get_lazyops())):
# NOTE: can't push pads past anything where f(0, 0) != 0 or f(0) != 0
unsafe_pad_ops = {BinaryOps.DIV, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
if not bx.realized and bx.optype == BinaryOps and len(bx.children) <= 1 and len(mops) and (all(x[0] != MovementOps.PAD for x in mops) or all(x.op not in unsafe_pad_ops for x in bx.op.get_lazyops())):
new_srcs.append(bx.op.replace_with_movement_ops(mops[::-1]))
else:
new_srcs.append(x)