check simple_pads in all views (#4614)

This commit is contained in:
qazal
2024-05-16 19:34:39 +08:00
committed by GitHub
parent 0b464df605
commit 13200c6894
2 changed files with 11 additions and 13 deletions

View File

@@ -799,22 +799,19 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(out, 1))
np.testing.assert_equal(out.numpy(), [2, 0])
# TODO: should not shuffle unsafe pad ops through any pads, even if buffer is shrunk overall (#3437)
def test_shrink_pad_unsafe(self):
a = Tensor.ones((3, )).contiguous().realize()
out = a.exp2().shrink(((0, 1),)).pad(((0, 1),)).contiguous()
run_schedule(check_schedule(out, 1))
with self.assertRaises(AssertionError):
np.testing.assert_equal(out.numpy(), [2, 0])
run_schedule(check_schedule(out, 2))
np.testing.assert_equal(out.numpy(), [2, 0])
def test_base_change_shrink_pad(self):
a = Tensor.ones(3, 3).contiguous().realize()
b = a.exp2()
c = b[:-1, :-1]
d = c.pad(((0, 1), (0, 1))) * 2
run_schedule(check_schedule(d, 1))
with self.assertRaises(AssertionError): # TODO unsafe pads
np.testing.assert_equal(d.numpy(), np.pad(np.exp2(a.numpy())[:-1, :-1], ((0, 1), (0, 1)))*2)
run_schedule(check_schedule(d, 2))
np.testing.assert_equal(d.numpy(), np.pad(np.exp2(a.numpy())[:-1, :-1], ((0, 1), (0, 1)))*2)
def test_base_change_expand_pad(self):
a = Tensor.ones(3, 3).contiguous().realize()

View File

@@ -128,12 +128,13 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
if GRAPH: log_lazybuffer(buf, scheduled)
# view
if buf.base != buf:
# realize all places where the buffer is expanded
if prod(buf.base.st.shape) < prod(buf.st.shape):
if len(buf.st.views) == 1 and buf.st.views[-1].mask is not None and all_int(buf.base.st.shape) and \
prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
simple_pads.add(buf.base)
elif buf.base.op is UnaryOps.CAST and isinstance(buf.base.srcs[0].dtype, ImageDType) and isinstance(buf.base.arg, ImageDType):
# fuse some pads
if len(buf.st.views) == 1 and buf.st.views[-1].mask is not None and all_int(buf.base.st.shape) and \
prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
simple_pads.add(buf.base)
# realize all expands
elif prod(buf.base.st.shape) < prod(buf.st.shape):
if buf.base.op is UnaryOps.CAST and isinstance(buf.base.srcs[0].dtype, ImageDType) and isinstance(buf.base.arg, ImageDType):
pass # don't realize image to image casts. this is part of a larger problem
else:
realizes[buf.base] = None