mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
Fix LazyBuffer SHUFFLE_PAD_OPS to prevent invalid pad movement (#1223)
In addition to DIV, any op that produces a non-zero output from zero inputs must be guarded against having a pad moved past it.
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
import numpy as np
|
||||
import unittest
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
class TestLazyBuffer(unittest.TestCase):
|
||||
def test_fromcpu_buffer_sharing(self):
|
||||
@@ -22,5 +23,25 @@ class TestLazyBuffer(unittest.TestCase):
|
||||
for start in [0, 1]:
|
||||
helper(a[(slice(start, None, stride),)*ndims])
|
||||
|
||||
def test_shuffle_pad_ops_cmpeq(self):
  # Regression check for pad movement past CMPEQ: comparing zeros for equality
  # gives a non-zero result, so the concatenated output must still be [1, 0].
  head = Tensor([1])
  is_zero = Tensor([1]).eq(0)
  actual = head.cat(is_zero).numpy()
  expected = Tensor([1, 0]).numpy()
  np.testing.assert_allclose(actual, expected)
|
||||
|
||||
def test_shuffle_pad_ops_div(self):
  # Regression check for pad movement past DIV: 1 / 2.0 concatenated after [1]
  # must come out as [1, 0.5].
  head = Tensor([1])
  quotient = Tensor([1]).div(Tensor([2.0]))
  actual = head.cat(quotient).numpy()
  expected = Tensor([1, 0.5]).numpy()
  np.testing.assert_allclose(actual, expected)
|
||||
|
||||
def test_shuffle_pad_ops_log(self):
  # Regression check for pad movement past log: log(1) is 0, so the
  # concatenated output must be [1, 0].
  head = Tensor([1])
  logged = Tensor([1]).log()
  actual = head.cat(logged).numpy()
  expected = Tensor([1, 0]).numpy()
  np.testing.assert_allclose(actual, expected)
|
||||
|
||||
def test_shuffle_pad_ops_exp(self):
  # Regression check for pad movement past exp: exp(1) is e, so the
  # concatenated output must be [1, e].
  head = Tensor([1])
  exponentiated = Tensor([1]).exp()
  actual = head.cat(exponentiated).numpy()
  expected = Tensor([1, np.e]).numpy()
  np.testing.assert_allclose(actual, expected)
|
||||
|
||||
# Script entry point: run the test cases in this module when executed directly.
if __name__ == "__main__": unittest.main()
|
||||
|
||||
@@ -283,8 +283,9 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
|
||||
assert isinstance(bx.op.op, MovementOps)
|
||||
mops.append((bx.op.op, bx.op.arg))
|
||||
bx = cast(LazyBuffer, bx.op.src[0])
|
||||
# NOTE: can't push pads with a div
|
||||
if not bx.realized and bx.optype == BinaryOps and len(bx.children) <= 1 and len(mops) and (all(x[0] != MovementOps.PAD for x in mops) or all(x.op != BinaryOps.DIV for x in bx.op.get_lazyops())):
|
||||
# NOTE: can't push pads past anything where f(0, 0) != 0 or f(0) != 0
|
||||
unsafe_pad_ops = {BinaryOps.DIV, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
|
||||
if not bx.realized and bx.optype == BinaryOps and len(bx.children) <= 1 and len(mops) and (all(x[0] != MovementOps.PAD for x in mops) or all(x.op not in unsafe_pad_ops for x in bx.op.get_lazyops())):
|
||||
new_srcs.append(bx.op.replace_with_movement_ops(mops[::-1]))
|
||||
else:
|
||||
new_srcs.append(x)
|
||||
|
||||
Reference in New Issue
Block a user