more explicit DONT_PUSH_VIEWS [pr] (#9479)

* more explicit DONT_PUSH_VIEWS [pr]

* update tests to not handcode ast

* lint

* test_recursive_swizzle and test_simple_store_reshape
This commit is contained in:
qazal
2025-03-17 20:43:21 +08:00
committed by GitHub
parent 3b00a778ba
commit e03c0aacf2
2 changed files with 17 additions and 30 deletions

View File

@@ -14,7 +14,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, merge_views, GroupOp
from tinygrad.codegen.symbolic import symbolic_simple
from tinygrad.spec import type_verify, shape_spec
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, all_same, temp
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, sym
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
from extra.models.llama import precompute_freqs_cis
@@ -1857,44 +1857,31 @@ class TestIndexing(unittest.TestCase):
def test_recursive_swizzle(self):
# Runs swizzle_rewrite on a (4, 1) reshape of a long chain of additions and checks that the result has the
# ShapeTracker of a plain (4,) -> (4, 1) reshape and a swizzle_cnt of 0.
# NOTE(review): this span is a diff hunk with its +/- markers stripped. The `ast` / `swizzle` / first `new_uop`
# lines look like the removed version (reshape taken from the scheduled AST) and the `a.lazydata` line like
# their replacement -- confirm against the commit before treating this as runnable code.
a = Tensor([1,2,3,4]).realize()
# chain 24 self-additions on top of the realized tensor
for _ in range(24): a = a + a
ast = a.schedule()[0].ast
swizzle = ast.src[0].src[2].reshape((4, 1))
new_uop = swizzle_rewrite(swizzle)
new_uop = swizzle_rewrite(a.lazydata.reshape((4, 1)))
self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1)))
self.assertEqual(swizzle_cnt(new_uop), 0)
def test_no_rewrite_elementwise(self):
# Elementwise add of two (32, 32) inputs.
# NOTE(review): diff hunk with +/- markers stripped. The first six statements (hand-built UOp graph, ending in
# the `rsink.key == sink.key` assertion) look like the removed body; the last four (Tensor-based, asserting
# swizzle_cnt of the scheduled AST is 0) look like the added body -- confirm against the commit.
# hand-built version: three DEFINE_GLOBAL buffers, two LOADs, one STORE of ld1+ld2 under a SINK
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop()))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),))
# applying view_right is expected to leave the graph unchanged (same key)
rsink = graph_rewrite(sink, view_right)
self.assertEqual(rsink.key, sink.key)
# Tensor-based version: take the AST of the first schedule item for a+b
a = Tensor.empty(32, 32)
b = Tensor.empty(32, 32)
sink = (a+b).schedule()[0].ast
self.assertEqual(swizzle_cnt(sink), 0)
def test_simple_store_reshape(self):
# Reduce of a (32, 32) input combined with another operand, checking shapes around the store.
# NOTE(review): diff hunk with +/- markers stripped. The statements from `bufs = ...` through `verify_ast(rsink)`
# look like the removed hand-built body; the last four (Tensor-based) look like the added body -- confirm
# against the commit.
# hand-built version: REDUCE_AXIS (ADD over axes 0 and 1) wrapped in a VIEW of shape (), plus a constant 2
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1)))
r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
r = r + r.const_like(2).replace(src=(unwrap(r.st).to_uop(),))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
rsink = graph_rewrite(sink, view_right)
# this AST first needs to swizzle, but it doesn't have implicit movementops
# the count is taken on the un-rewritten sink; the rewritten rsink is only passed to verify_ast
self.assertEqual(swizzle_cnt(sink), 1)
verify_ast(rsink)
# Tensor-based version: the scheduled AST is asserted to have shape (32, 1) while the tensor's lazydata is (1, 32)
a = Tensor.empty(32, 32).sum(axis=1)+Tensor.empty(1,32)
ast = a.schedule()[0].ast
self.assertEqual(ast.shape, (32, 1))
self.assertEqual(a.lazydata.shape, (1, 32))
def test_no_reshape_reduceop(self):
# Reduce of a (32, 32) input stored directly, with no extra view on the reduce.
# NOTE(review): diff hunk with +/- markers stripped. The statements from `bufs = ...` through the
# `sink.key == rsink.key` assertion look like the removed hand-built body; the last four (Tensor-based) look
# like the added body -- confirm against the commit.
# hand-built version: REDUCE_AXIS (ADD over axes 0 and 1) stored through a (1, 1) ShapeTracker
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1)))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
rsink = graph_rewrite(sink, view_right)
# the original sink is verified, and view_right is expected to leave it unchanged (same key)
verify_ast(sink)
self.assertEqual(sink.key, rsink.key)
# Tensor-based version: the scheduled AST is asserted to have shape (32, 1) while the tensor's lazydata is (32,)
a = Tensor.empty(32, 32).sum(axis=(1,)).contiguous()
ast = a.schedule()[0].ast
self.assertEqual(ast.shape, (32, 1))
self.assertEqual(a.lazydata.shape, (32,))
@track_rewrites(named=True)
def swizzle_rewrite(u:UOp) -> UOp:
  # two passes, in this order: view_left first, then view_right on its result
  after_left = graph_rewrite(u, view_left)
  return graph_rewrite(after_left, view_right)
# Counts the UOps in u.toposort that are Ops.VIEW and have at least one source.
# NOTE(review): the two definitions below look like the removed and added versions of the same line in a diff
# whose +/- markers were stripped -- confirm against the commit. If both were kept in one module, the second
# would replace the first.
def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.VIEW and len(x.src) != 0])
# Same count, but a VIEW whose first source is an Ops.BUFFER is no longer counted.
def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.VIEW and len(x.src) != 0 and x.src[0].op is not Ops.BUFFER])
class TestSwizzle(unittest.TestCase):
def test_swizzle_simple(self):