From e03c0aacf2180469c98e198d0f4463931876bac4 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 17 Mar 2025 20:43:21 +0800 Subject: [PATCH] more explicit DONT_PUSH_VIEWS [pr] (#9479) * more explicit DONT_PUSH_VIEWS [pr] * update tests to not handcode ast * lint * test_recursive_swizzle and test_simple_store_reshape --- test/test_schedule.py | 43 +++++++++++++------------------------ tinygrad/engine/schedule.py | 4 ++-- 2 files changed, 17 insertions(+), 30 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index e09d8e3bd0..d4116362db 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -14,7 +14,7 @@ from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, merge_views, GroupOp from tinygrad.codegen.symbolic import symbolic_simple from tinygrad.spec import type_verify, shape_spec -from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, all_same, temp +from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, sym from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule from extra.models.llama import precompute_freqs_cis @@ -1857,44 +1857,31 @@ class TestIndexing(unittest.TestCase): def test_recursive_swizzle(self): a = Tensor([1,2,3,4]).realize() for _ in range(24): a = a + a - ast = a.schedule()[0].ast - swizzle = ast.src[0].src[2].reshape((4, 1)) - new_uop = swizzle_rewrite(swizzle) + new_uop = swizzle_rewrite(a.lazydata.reshape((4, 1))) self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1))) self.assertEqual(swizzle_cnt(new_uop), 0) def test_no_rewrite_elementwise(self): - bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)] - ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) - ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop())) - sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),)) - rsink = graph_rewrite(sink, view_right) - self.assertEqual(rsink.key, sink.key) + a = Tensor.empty(32, 32) + b = Tensor.empty(32, 32) + sink = (a+b).schedule()[0].ast + self.assertEqual(swizzle_cnt(sink), 0) def test_simple_store_reshape(self): - bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)] - ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) - r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1))) - r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(())) - r = r + r.const_like(2).replace(src=(unwrap(r.st).to_uop(),)) - sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),)) - rsink = graph_rewrite(sink, view_right) - # this AST first needs to swizzle, but it doesn't have implicit movementops - self.assertEqual(swizzle_cnt(sink), 1) - verify_ast(rsink) + a = Tensor.empty(32, 32).sum(axis=1)+Tensor.empty(1,32) + ast = a.schedule()[0].ast + self.assertEqual(ast.shape, (32, 1)) + self.assertEqual(a.lazydata.shape, (1, 32)) def test_no_reshape_reduceop(self): - bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)] - ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) - r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1))) - sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),)) - rsink = graph_rewrite(sink, view_right) - verify_ast(sink) - self.assertEqual(sink.key, rsink.key) + a = Tensor.empty(32, 32).sum(axis=(1,)).contiguous() + ast = a.schedule()[0].ast + self.assertEqual(ast.shape, (32, 1)) + self.assertEqual(a.lazydata.shape, (32,)) @track_rewrites(named=True) def swizzle_rewrite(u:UOp) -> UOp: return graph_rewrite(graph_rewrite(u, view_left), view_right) -def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.VIEW and len(x.src) != 0]) +def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.VIEW and len(x.src) != 0 and x.src[0].op is not Ops.BUFFER]) class TestSwizzle(unittest.TestCase): def test_swizzle_simple(self): diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 4fa1222620..362f50a377 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -113,7 +113,7 @@ sym = symbolic_simple+PatternMatcher([ # **** UOp realization -DONT_PUSH_VIEWS = {Ops.BUFFER, *GroupOp.Buffer, Ops.ASSIGN, Ops.SINK, Ops.CONTIGUOUS} +DONT_PUSH_VIEWS = {Ops.BUFFER, Ops.CONST, Ops.DEFINE_VAR, Ops.DEVICE, Ops.ASSIGN, Ops.SINK, Ops.CONTIGUOUS, Ops.COPY} @dataclass(frozen=True) class GrouperContext: @@ -139,7 +139,7 @@ do_realize = PatternMatcher([ # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW (UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize), # realize before expand or unsafe pad ops - (UPat(Ops.VIEW, name="view", src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE}, name="src"),)), realize_before_view), + (UPat(Ops.VIEW, name="view", src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="src"),)), realize_before_view), # realize before COPY (UPat(Ops.COPY, src=(UPat(), UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW}, name="tr"))), realize), ])