mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 14:28:09 -05:00
keep VIEW in big_sink + copy of buffer view spec [pr] (#8727)
* keep views in sink [pr] * tests * things from the gpt2 bug
This commit is contained in:
@@ -14,7 +14,7 @@ from tinygrad.dtype import DType, ImageDType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp
|
||||
from tinygrad.codegen.kernel import verify_ast
|
||||
from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
|
||||
@@ -2269,6 +2269,54 @@ class TestCopyFolding(unittest.TestCase):
|
||||
a = Tensor.empty(4).lazydata
|
||||
check_schedule(a.clone(), 1, filter_sink=False)
|
||||
|
||||
# NOTE: moving copy before view might change this
|
||||
def test_shrink_copy(self):
|
||||
a = Tensor.arange(4)
|
||||
view = a.shrink(((0, 2),))
|
||||
b = view.clone()
|
||||
run_schedule(check_schedule(b, 2, filter_sink=False))
|
||||
self.assertEqual(b.lazydata.base.buffer.size, 2)
|
||||
self.assertEqual(b.lazydata.size, 2)
|
||||
self.assertListEqual(b.tolist(), [0, 1])
|
||||
|
||||
def test_expanded_copy(self):
|
||||
a = Tensor.arange(2)
|
||||
view = a.reshape(2, 1).expand(2, 2)
|
||||
b = view.clone()
|
||||
run_schedule(check_schedule(b, 2, filter_sink=False))
|
||||
self.assertEqual(b.lazydata.base.buffer.size, 2)
|
||||
self.assertEqual(b.lazydata.size, 4)
|
||||
self.assertListEqual(b.tolist(), [[0, 0], [1, 1]])
|
||||
|
||||
def test_permuted_copy(self):
|
||||
a = Tensor.arange(4)
|
||||
b = a.reshape(2, 2).permute(1, 0)
|
||||
b.realize()
|
||||
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
|
||||
|
||||
def test_permute_on_disk(self):
|
||||
with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().lazydata.base.buffer.as_buffer())
|
||||
a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
|
||||
b = a.reshape(2, 2).permute(1, 0).to("CLANG")
|
||||
b.realize()
|
||||
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
|
||||
|
||||
def test_permute_after_shrink(self):
|
||||
a = Tensor.arange(5)
|
||||
b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CLANG")
|
||||
b.realize()
|
||||
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
|
||||
|
||||
# NOTE: disk permute must come after COPY
|
||||
# TODO: this is wrong because of the permute
|
||||
@unittest.expectedFailure
|
||||
def test_permute_after_shrink_on_disk(self):
|
||||
with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).realize().lazydata.base.buffer.as_buffer())
|
||||
a = Tensor.empty(5, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_5_permute')}")
|
||||
b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CLANG")
|
||||
b.realize()
|
||||
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
|
||||
|
||||
class TestTensorUOpSpec(unittest.TestCase):
|
||||
def test_const_must_be_unmasked(self):
|
||||
a = Tensor.ones((4, 4)).pad((2, 2))
|
||||
@@ -2377,6 +2425,17 @@ class TestContiguous(unittest.TestCase):
|
||||
b = a.expand((4, 4)).contiguous().contiguous()
|
||||
check_schedule(b, 1)
|
||||
|
||||
def test_view_does_not_realize(self):
|
||||
a = Tensor.empty(4)
|
||||
b = a.expand((4, 4))
|
||||
check_schedule(b, 0)
|
||||
self.assertEqual(b.lazydata.base.buffer.size, 4)
|
||||
|
||||
def test_contiguous_view_realizes(self):
|
||||
a = Tensor.empty(4)
|
||||
b = a.expand((4, 4)).contiguous()
|
||||
check_schedule(b, 1)
|
||||
self.assertEqual(b.lazydata.base.buffer.size, 16)
|
||||
|
||||
class TestUOpBecome(unittest.TestCase):
|
||||
# the simplest case, if we create a new BUFFER for this UOp
|
||||
|
||||
@@ -388,10 +388,10 @@ sym = symbolic_simple+PatternMatcher([
|
||||
# support for using a contiguous permuted view instead of the parent view if one exists
|
||||
(UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
|
||||
(UPat(GroupOp.ALU, name="alu"), replace_contiguous),
|
||||
# remove CONST/BIND/BUFFER/VIEW from SINK
|
||||
# remove CONST/BIND/BUFFER from SINK
|
||||
(UPat(Ops.SINK, name="root"),
|
||||
lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
|
||||
if (new_src:=tuple(x.base for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
|
||||
if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
|
||||
])
|
||||
|
||||
# ** this decides which ops get realized
|
||||
|
||||
Reference in New Issue
Block a user