mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
more explicit DONT_PUSH_VIEWS [pr] (#9479)
* more explicit DONT_PUSH_VIEWS [pr] * update tests to not handcode ast * lint * test_recursive_swizzle and test_simple_store_reshape
This commit is contained in:
@@ -14,7 +14,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, merge_views, GroupOp
|
||||
from tinygrad.codegen.symbolic import symbolic_simple
|
||||
from tinygrad.spec import type_verify, shape_spec
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, all_same, temp
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
|
||||
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, sym
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
|
||||
from extra.models.llama import precompute_freqs_cis
|
||||
@@ -1857,44 +1857,31 @@ class TestIndexing(unittest.TestCase):
|
||||
def test_recursive_swizzle(self):
|
||||
a = Tensor([1,2,3,4]).realize()
|
||||
for _ in range(24): a = a + a
|
||||
ast = a.schedule()[0].ast
|
||||
swizzle = ast.src[0].src[2].reshape((4, 1))
|
||||
new_uop = swizzle_rewrite(swizzle)
|
||||
new_uop = swizzle_rewrite(a.lazydata.reshape((4, 1)))
|
||||
self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1)))
|
||||
self.assertEqual(swizzle_cnt(new_uop), 0)
|
||||
|
||||
def test_no_rewrite_elementwise(self):
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),))
|
||||
rsink = graph_rewrite(sink, view_right)
|
||||
self.assertEqual(rsink.key, sink.key)
|
||||
a = Tensor.empty(32, 32)
|
||||
b = Tensor.empty(32, 32)
|
||||
sink = (a+b).schedule()[0].ast
|
||||
self.assertEqual(swizzle_cnt(sink), 0)
|
||||
|
||||
def test_simple_store_reshape(self):
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1)))
|
||||
r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
|
||||
r = r + r.const_like(2).replace(src=(unwrap(r.st).to_uop(),))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
|
||||
rsink = graph_rewrite(sink, view_right)
|
||||
# this AST first needs to swizzle, but it doesn't have implicit movementops
|
||||
self.assertEqual(swizzle_cnt(sink), 1)
|
||||
verify_ast(rsink)
|
||||
a = Tensor.empty(32, 32).sum(axis=1)+Tensor.empty(1,32)
|
||||
ast = a.schedule()[0].ast
|
||||
self.assertEqual(ast.shape, (32, 1))
|
||||
self.assertEqual(a.lazydata.shape, (1, 32))
|
||||
|
||||
def test_no_reshape_reduceop(self):
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1)))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
|
||||
rsink = graph_rewrite(sink, view_right)
|
||||
verify_ast(sink)
|
||||
self.assertEqual(sink.key, rsink.key)
|
||||
a = Tensor.empty(32, 32).sum(axis=(1,)).contiguous()
|
||||
ast = a.schedule()[0].ast
|
||||
self.assertEqual(ast.shape, (32, 1))
|
||||
self.assertEqual(a.lazydata.shape, (32,))
|
||||
|
||||
@track_rewrites(named=True)
|
||||
def swizzle_rewrite(u:UOp) -> UOp: return graph_rewrite(graph_rewrite(u, view_left), view_right)
|
||||
def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.VIEW and len(x.src) != 0])
|
||||
def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.VIEW and len(x.src) != 0 and x.src[0].op is not Ops.BUFFER])
|
||||
|
||||
class TestSwizzle(unittest.TestCase):
|
||||
def test_swizzle_simple(self):
|
||||
|
||||
Reference in New Issue
Block a user