pad fusion tests (#4570)

* what breaks

* Revert "what breaks"

This reverts commit e79f679283.

* simplest case

* one unsafe op

* expand+pad, shrink+pad

* safe case

* refactor
This commit is contained in:
qazal
2024-05-15 01:34:46 +08:00
committed by GitHub
parent 7afca52796
commit 355e1c135c

View File

@@ -3,7 +3,9 @@
# NOTE: this has overlap with external_test_opt.py # NOTE: this has overlap with external_test_opt.py
import unittest import unittest
import numpy as np
from typing import List, Optional, Union from typing import List, Optional, Union
from tinygrad.engine.realize import run_schedule
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.ops import BinaryOps, LoadOps, ReduceOps from tinygrad.ops import BinaryOps, LoadOps, ReduceOps
from tinygrad.helpers import DEBUG, flatten from tinygrad.helpers import DEBUG, flatten
@@ -775,5 +777,35 @@ class TestSchedule(unittest.TestCase):
f = (b - d).sum() - e f = (b - d).sum() - e
check_schedule([c, d, e, f], 3) check_schedule([c, d, e, f], 3)
def test_pad_reduce_safe(self):
Tensor.manual_seed(0)
a = Tensor.rand(3, 4, 5).realize()
b = Tensor.rand(3, 4, 5).realize()
out = (a + b).pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()
run_schedule(check_schedule(out, 1))
np.testing.assert_allclose(out.numpy(), np.pad(a.numpy()+b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum())
def test_pad_reduce_usafe(self):
Tensor.manual_seed(0)
a = Tensor.rand(3, 4, 5).realize()
out = a.log2().pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum())
def test_shrink_pad_safe(self):
a = Tensor.ones((3, )).contiguous().realize()
b = Tensor.ones((3, )).contiguous().realize()
out = (a + b).shrink(((0, 1),)).pad(((0, 1),)).contiguous()
run_schedule(check_schedule(out, 1))
np.testing.assert_equal(out.numpy(), [2, 0])
# TODO: should not shuffle unsafe pad ops through any pads, even if buffer is shrunk overall (#3437)
def test_shrink_pad_unsafe(self):
a = Tensor.ones((3, )).contiguous().realize()
out = a.exp2().shrink(((0, 1),)).pad(((0, 1),)).contiguous()
run_schedule(check_schedule(out, 1))
with self.assertRaises(AssertionError):
np.testing.assert_equal(out.numpy(), [2, 0])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main(verbosity=2) unittest.main(verbosity=2)