keep VIEW in big_sink + copy of buffer view spec [pr] (#8727)

* keep views in sink [pr]

* tests

* things from the gpt2 bug
This commit is contained in:
qazal
2025-01-23 04:29:30 -05:00
committed by GitHub
parent 6cb74bb630
commit 07ec99001a
2 changed files with 62 additions and 3 deletions

View File

@@ -14,7 +14,7 @@ from tinygrad.dtype import DType, ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp
from tinygrad.codegen.kernel import verify_ast
from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
@@ -2269,6 +2269,54 @@ class TestCopyFolding(unittest.TestCase):
a = Tensor.empty(4).lazydata
check_schedule(a.clone(), 1, filter_sink=False)
# NOTE: moving copy before view might change this
def test_shrink_copy(self):
a = Tensor.arange(4)
view = a.shrink(((0, 2),))
b = view.clone()
run_schedule(check_schedule(b, 2, filter_sink=False))
self.assertEqual(b.lazydata.base.buffer.size, 2)
self.assertEqual(b.lazydata.size, 2)
self.assertListEqual(b.tolist(), [0, 1])
def test_expanded_copy(self):
a = Tensor.arange(2)
view = a.reshape(2, 1).expand(2, 2)
b = view.clone()
run_schedule(check_schedule(b, 2, filter_sink=False))
self.assertEqual(b.lazydata.base.buffer.size, 2)
self.assertEqual(b.lazydata.size, 4)
self.assertListEqual(b.tolist(), [[0, 0], [1, 1]])
def test_permuted_copy(self):
a = Tensor.arange(4)
b = a.reshape(2, 2).permute(1, 0)
b.realize()
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
def test_permute_on_disk(self):
with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().lazydata.base.buffer.as_buffer())
a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
b = a.reshape(2, 2).permute(1, 0).to("CLANG")
b.realize()
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
def test_permute_after_shrink(self):
a = Tensor.arange(5)
b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CLANG")
b.realize()
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
# NOTE: disk permute must come after COPY
# TODO: this is wrong because of the permute
@unittest.expectedFailure
def test_permute_after_shrink_on_disk(self):
with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).realize().lazydata.base.buffer.as_buffer())
a = Tensor.empty(5, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_5_permute')}")
b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CLANG")
b.realize()
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
class TestTensorUOpSpec(unittest.TestCase):
def test_const_must_be_unmasked(self):
a = Tensor.ones((4, 4)).pad((2, 2))
@@ -2377,6 +2425,17 @@ class TestContiguous(unittest.TestCase):
b = a.expand((4, 4)).contiguous().contiguous()
check_schedule(b, 1)
def test_view_does_not_realize(self):
a = Tensor.empty(4)
b = a.expand((4, 4))
check_schedule(b, 0)
self.assertEqual(b.lazydata.base.buffer.size, 4)
def test_contiguous_view_realizes(self):
a = Tensor.empty(4)
b = a.expand((4, 4)).contiguous()
check_schedule(b, 1)
self.assertEqual(b.lazydata.base.buffer.size, 16)
class TestUOpBecome(unittest.TestCase):
# the simplest case, if we create a new BUFFER for this UOp

View File

@@ -388,10 +388,10 @@ sym = symbolic_simple+PatternMatcher([
# support for using a contiguous permuted view instead of the parent view if one exists
(UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
(UPat(GroupOp.ALU, name="alu"), replace_contiguous),
# remove CONST/BIND/BUFFER/VIEW from SINK
# remove CONST/BIND/BUFFER from SINK
(UPat(Ops.SINK, name="root"),
lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
if (new_src:=tuple(x.base for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
])
# ** this decides which ops get realized