mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
check simple_pads in all views (#4614)
This commit is contained in:
@@ -799,22 +799,19 @@ class TestSchedule(unittest.TestCase):
|
||||
run_schedule(check_schedule(out, 1))
|
||||
np.testing.assert_equal(out.numpy(), [2, 0])
|
||||
|
||||
# TODO: should not shuffle unsafe pad ops through any pads, even if buffer is shrunk overall (#3437)
|
||||
def test_shrink_pad_unsafe(self):
|
||||
a = Tensor.ones((3, )).contiguous().realize()
|
||||
out = a.exp2().shrink(((0, 1),)).pad(((0, 1),)).contiguous()
|
||||
run_schedule(check_schedule(out, 1))
|
||||
with self.assertRaises(AssertionError):
|
||||
np.testing.assert_equal(out.numpy(), [2, 0])
|
||||
run_schedule(check_schedule(out, 2))
|
||||
np.testing.assert_equal(out.numpy(), [2, 0])
|
||||
|
||||
def test_base_change_shrink_pad(self):
|
||||
a = Tensor.ones(3, 3).contiguous().realize()
|
||||
b = a.exp2()
|
||||
c = b[:-1, :-1]
|
||||
d = c.pad(((0, 1), (0, 1))) * 2
|
||||
run_schedule(check_schedule(d, 1))
|
||||
with self.assertRaises(AssertionError): # TODO unsafe pads
|
||||
np.testing.assert_equal(d.numpy(), np.pad(np.exp2(a.numpy())[:-1, :-1], ((0, 1), (0, 1)))*2)
|
||||
run_schedule(check_schedule(d, 2))
|
||||
np.testing.assert_equal(d.numpy(), np.pad(np.exp2(a.numpy())[:-1, :-1], ((0, 1), (0, 1)))*2)
|
||||
|
||||
def test_base_change_expand_pad(self):
|
||||
a = Tensor.ones(3, 3).contiguous().realize()
|
||||
|
||||
@@ -128,12 +128,13 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
|
||||
if GRAPH: log_lazybuffer(buf, scheduled)
|
||||
# view
|
||||
if buf.base != buf:
|
||||
# realize all places where the buffer is expanded
|
||||
if prod(buf.base.st.shape) < prod(buf.st.shape):
|
||||
if len(buf.st.views) == 1 and buf.st.views[-1].mask is not None and all_int(buf.base.st.shape) and \
|
||||
prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
|
||||
simple_pads.add(buf.base)
|
||||
elif buf.base.op is UnaryOps.CAST and isinstance(buf.base.srcs[0].dtype, ImageDType) and isinstance(buf.base.arg, ImageDType):
|
||||
# fuse some pads
|
||||
if len(buf.st.views) == 1 and buf.st.views[-1].mask is not None and all_int(buf.base.st.shape) and \
|
||||
prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
|
||||
simple_pads.add(buf.base)
|
||||
# realize all expands
|
||||
elif prod(buf.base.st.shape) < prod(buf.st.shape):
|
||||
if buf.base.op is UnaryOps.CAST and isinstance(buf.base.srcs[0].dtype, ImageDType) and isinstance(buf.base.arg, ImageDType):
|
||||
pass # don't realize image to image casts. this is part of a larger problem
|
||||
else:
|
||||
realizes[buf.base] = None
|
||||
|
||||
Reference in New Issue
Block a user