diff --git a/test/test_schedule.py b/test/test_schedule.py
index b0426630de..9d29bebb34 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -16,7 +16,7 @@ from tinygrad.shape.view import View
 from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
 from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same
 from tinygrad.codegen.kernel import verify_ast
-from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, ops_folding
+from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
 from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
 from extra.models.llama import precompute_freqs_cis

@@ -67,7 +67,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
   np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)

 @track_rewrites(named=True)
-def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+ops_folding, ScheduleContext())
+def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, ScheduleContext())

 class TestSchedule(unittest.TestCase):
   def test_basic_binop_fusion(self):
@@ -1824,7 +1824,7 @@ def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.

 # these pattern matchers should move to engine/schedule.py

-sym = symbolic_simple+PatternMatcher([
+ops_folding = symbolic_simple+PatternMatcher([
   (UPat(Ops.DETACH, name="x"), lambda x:x.src[0]),
 ])

@@ -1842,8 +1842,8 @@ def run_tensor_ast(r:Tensor):
   output = UOp.new_buffer(r.device, r.lazydata.size, r.dtype)
   glbl = UOp(Ops.DEFINE_GLOBAL, output.dtype.ptr(size=output.size), (), 0)
   sink = UOp(Ops.STORE, src=(glbl, ShapeTracker.from_shape(r.lazydata.base.shape).to_uop(), r.lazydata.base)).sink()
-  sink = graph_rewrite(sink, remove_movement_ops+sym+load_buffers+view_left, bufs:=[output])
-  sink = graph_rewrite(sink, remove_movement_ops+sym+view_right)
+  sink = graph_rewrite(sink, remove_movement_ops+ops_folding+load_buffers+view_left, bufs:=[output])
+  sink = graph_rewrite(sink, remove_movement_ops+ops_folding+view_right)
   si = ScheduleItem(sink, tuple(x.buffer for x in bufs), (), ())
   run_schedule([si])
   return output.realized.as_buffer().cast(output.dtype.fmt, r.shape).tolist()
@@ -2336,34 +2336,29 @@ class TestBufferUOp(unittest.TestCase):

 class TestContiguous(unittest.TestCase):
   def test_contiguous_buffer(self):
-    a = Tensor.empty(4).lazydata
-    b = a.alu(Ops.CONTIGUOUS)
-    b = schedule_graph_rewrite(b)
-    self.assertIs(b, a)
+    a = Tensor.empty(4)
+    b = a.contiguous()
+    check_schedule(b, 0)

   def test_contiguous_buffer_view(self):
-    a = Tensor.empty(4).lazydata
-    b = a.reshape((2, 2)).alu(Ops.CONTIGUOUS)
-    b = schedule_graph_rewrite(b)
-    self.assertIs(b, a.buf_uop.view(unwrap(b.st)))
+    a = Tensor.empty(4)
+    b = a.reshape((2, 2)).contiguous()
+    check_schedule(b, 0)

   def test_non_contiguous_buffer_view(self):
-    a = Tensor.empty(4, 1).lazydata
-    b = a.expand((4, 4)).alu(Ops.CONTIGUOUS)
-    b = schedule_graph_rewrite(b)
-    assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {})
+    a = Tensor.empty(4, 1)
+    b = a.expand((4, 4)).contiguous()
+    check_schedule(b, 1)

   def test_size_change_buffer_view(self):
-    a = Tensor.empty(4).lazydata
-    b = a.reshape((1, 1, 4)).shrink(((0, 1), (0, 1), (0, 3))).alu(Ops.CONTIGUOUS)
-    b = schedule_graph_rewrite(b)
-    assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {})
+    a = Tensor.empty(4)
+    b = a.reshape((1, 1, 4)).shrink(((0, 1), (0, 1), (0, 3))).contiguous()
+    check_schedule(b, 1)

   def test_double_contiguous_realizes_once(self):
-    a = Tensor.empty(4, 1).lazydata
-    b = a.expand((4, 4)).alu(Ops.CONTIGUOUS).alu(Ops.CONTIGUOUS)
-    b = schedule_graph_rewrite(b)
-    assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {})
+    a = Tensor.empty(4, 1)
+    b = a.expand((4, 4)).contiguous().contiguous()
+    check_schedule(b, 1)

 if __name__ == '__main__':
   unittest.main(verbosity=2)
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index d5acb3cbe3..497cb5970a 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -242,7 +242,7 @@ if CAPTURE_PROCESS_REPLAY:
 def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2 and u.src[0].op is Ops.BUFFER
 def uval(u:UOp) -> UOp:
   assert is_scheduled(u), f"must be a scheduled op {u}"
-  return r.src[0] if (r:=u.src[1]).op is Ops.CONTIGUOUS and not (r.src[0].base.op is Ops.VIEW and len(r.src[0].base.src) == 2) else r
+  return u.src[1]

 def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], allbufs:dict[UOp, UOp], realizes:dict[UOp, UOp],
                     reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
@@ -340,10 +340,6 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]:

 # **** Schedule creation and BFS toposort

-class UPatScheduled(UPat):
-  def __init__(self, *args, **kwargs):
-    super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
-
 # ** this is schedule level const folding

 def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
@@ -366,8 +362,8 @@ def replace_contiguous(ctx:ScheduleContext, alu:UOp):
     if (replace_src:=ctx.contiguous.get(s, None)) is not None: new_src[i] = replace_src
   if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))

-ops_folding = symbolic_simple+PatternMatcher([
-  # op with size 0 is zero
+sym = symbolic_simple+PatternMatcher([
+  # UOp with size 0 is zero
   (UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
     and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
   # DETACH is a NOOP here
@@ -401,6 +397,10 @@ ops_folding = symbolic_simple+PatternMatcher([

 # ** this decides which ops get realized

+class UPatScheduled(UPat):
+  def __init__(self, *args, **kwargs):
+    super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
+
 def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store

 def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
@@ -494,7 +494,7 @@ remove_movement_ops = PatternMatcher([
 @track_rewrites(named=True)
 def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
   if not skip_check: type_verify(list(big_sink.toposort), tensor_uop_spec)
-  tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+ops_folding, ctx:=ScheduleContext())
+  tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx:=ScheduleContext())
   rev_tensor_map: dict[UOp, list[UOp]] = {}
   for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k)
   # add BUFFER uops
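
Usage sketch (not part of the diff): with the names swapped, the schedule-level simplification rules are imported from tinygrad.engine.schedule as `sym`, and the schedule_graph_rewrite helper in test_schedule.py applies them together with remove_movement_ops. A minimal sketch of driving that rewrite by hand, assuming this revision of tinygrad; every name below comes from the diff except the toy tensor, which is illustrative:

# minimal sketch, assuming this tinygrad revision; mirrors the
# schedule_graph_rewrite helper from test_schedule.py above
from tinygrad import Tensor
from tinygrad.ops import graph_rewrite
from tinygrad.engine.schedule import ScheduleContext, remove_movement_ops, sym

# illustrative input: a CONTIGUOUS over a reshaped view of an unrealized buffer
big_sink = Tensor.empty(4).reshape((2, 2)).contiguous().lazydata
# apply the renamed schedule-level rules plus movement-op removal
rewritten = graph_rewrite(big_sink, remove_movement_ops+sym, ScheduleContext())
print(rewritten.op)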