# Unified diff touching two files:
#   * test/test_schedule.py
#       - imports `temp` from tinygrad.helpers (used by the new disk tests to build file paths)
#       - TestCopyFolding: adds test_shrink_copy, test_expanded_copy, test_permuted_copy,
#         test_permute_on_disk, test_permute_after_shrink and test_permute_after_shrink_on_disk
#         (the last one is decorated with @unittest.expectedFailure)
#       - TestContiguous: adds test_view_does_not_realize and test_contiguous_view_realizes,
#         which assert on the scheduled kernel count and on `lazydata.base.buffer.size`
#   * tinygrad/engine/schedule.py
#       - in the SINK rule of `sym`, the rebuilt source tuple now keeps each `x` as-is instead of
#         `x.base`; the filter (`not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND}`) is unchanged
#
# These '#' lines sit above the first 'diff --git' header and are commentary only; no hunk content
# or hunk line count was changed.
# NOTE(review): the patch text had been collapsed onto a few long lines; line breaks were restored
# from the hunk headers' line counts and indentation was reconstructed assuming the repository's
# 2-space indent -- confirm against the target tree before applying.
diff --git a/test/test_schedule.py b/test/test_schedule.py
index 5c9564966f..e3682cd1c3 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -14,7 +14,7 @@ from tinygrad.dtype import DType, ImageDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
 from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
-from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same
+from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp
 from tinygrad.codegen.kernel import verify_ast
 from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
 from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
@@ -2269,6 +2269,54 @@ class TestCopyFolding(unittest.TestCase):
     a = Tensor.empty(4).lazydata
     check_schedule(a.clone(), 1, filter_sink=False)
 
+  # NOTE: moving copy before view might change this
+  def test_shrink_copy(self):
+    a = Tensor.arange(4)
+    view = a.shrink(((0, 2),))
+    b = view.clone()
+    run_schedule(check_schedule(b, 2, filter_sink=False))
+    self.assertEqual(b.lazydata.base.buffer.size, 2)
+    self.assertEqual(b.lazydata.size, 2)
+    self.assertListEqual(b.tolist(), [0, 1])
+
+  def test_expanded_copy(self):
+    a = Tensor.arange(2)
+    view = a.reshape(2, 1).expand(2, 2)
+    b = view.clone()
+    run_schedule(check_schedule(b, 2, filter_sink=False))
+    self.assertEqual(b.lazydata.base.buffer.size, 2)
+    self.assertEqual(b.lazydata.size, 4)
+    self.assertListEqual(b.tolist(), [[0, 0], [1, 1]])
+
+  def test_permuted_copy(self):
+    a = Tensor.arange(4)
+    b = a.reshape(2, 2).permute(1, 0)
+    b.realize()
+    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
+
+  def test_permute_on_disk(self):
+    with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().lazydata.base.buffer.as_buffer())
+    a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
+    b = a.reshape(2, 2).permute(1, 0).to("CLANG")
+    b.realize()
+    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
+
+  def test_permute_after_shrink(self):
+    a = Tensor.arange(5)
+    b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CLANG")
+    b.realize()
+    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
+
+  # NOTE: disk permute must come after COPY
+  # TODO: this is wrong because of the permute
+  @unittest.expectedFailure
+  def test_permute_after_shrink_on_disk(self):
+    with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).realize().lazydata.base.buffer.as_buffer())
+    a = Tensor.empty(5, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_5_permute')}")
+    b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CLANG")
+    b.realize()
+    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
+
 class TestTensorUOpSpec(unittest.TestCase):
   def test_const_must_be_unmasked(self):
     a = Tensor.ones((4, 4)).pad((2, 2))
@@ -2377,6 +2425,17 @@ class TestContiguous(unittest.TestCase):
     b = a.expand((4, 4)).contiguous().contiguous()
     check_schedule(b, 1)
 
+  def test_view_does_not_realize(self):
+    a = Tensor.empty(4)
+    b = a.expand((4, 4))
+    check_schedule(b, 0)
+    self.assertEqual(b.lazydata.base.buffer.size, 4)
+
+  def test_contiguous_view_realizes(self):
+    a = Tensor.empty(4)
+    b = a.expand((4, 4)).contiguous()
+    check_schedule(b, 1)
+    self.assertEqual(b.lazydata.base.buffer.size, 16)
 
 class TestUOpBecome(unittest.TestCase):
   # the simplest case, if we create a new BUFFER for this UOp
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 07cbb3b761..d8236d47bf 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -388,10 +388,10 @@ sym = symbolic_simple+PatternMatcher([
   # support for using a contiguous permuted view instead of the parent view if one exists
   (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
   (UPat(GroupOp.ALU, name="alu"), replace_contiguous),
-  # remove CONST/BIND/BUFFER/VIEW from SINK
+  # remove CONST/BIND/BUFFER from SINK
   (UPat(Ops.SINK, name="root"),
     lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
-      if (new_src:=tuple(x.base for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
+      if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
 ])
 
 # ** this decides which ops get realized