diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py index 933dfc4a29..fac2407abb 100755 --- a/test/external/process_replay/process_replay.py +++ b/test/external/process_replay/process_replay.py @@ -34,7 +34,8 @@ class ProcessReplayWarning(Warning): pass # *** recreators def recreate_sched(big_sink:UOp) -> list[UOp]: - sched_sink = get_becomes_map(big_sink)[big_sink] + becomes_map = get_becomes_map(big_sink) + sched_sink = UOp.sink(*[becomes_map.get(x,x) for x in big_sink.src]) return dedup(u.arg.ast for u in sched_sink.toposort if u.op is Ops.KERNEL) def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:list[Opt], name:str, _) -> str: diff --git a/test/test_schedule.py b/test/test_schedule.py index a5abc8354b..ecf8dfb7cc 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1950,7 +1950,7 @@ class TestSwizzle(unittest.TestCase): y = x*x.sum((1,)).reciprocal() t = y.pad(((0,1),None)).contiguous() swizzled = swizzle_rewrite(t.lazydata) - sched = check_schedule(swizzled.sink(), 3) + sched = check_schedule(swizzled, 3) output_buffer = sched[-1].bufs[0] run_schedule(sched) self.assertListEqual(output_buffer.as_buffer().cast("f").tolist(), [0.5, 0.5, 0.5, 0.5, 0., 0.]) diff --git a/tinygrad/engine/grouper.py b/tinygrad/engine/grouper.py index cbe4273c0c..56a14e69ac 100644 --- a/tinygrad/engine/grouper.py +++ b/tinygrad/engine/grouper.py @@ -1,7 +1,7 @@ from collections import defaultdict, deque from dataclasses import dataclass from tinygrad.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve, merge_views -from tinygrad.ops import can_pad, sint, track_rewrites +from tinygrad.ops import can_pad, sint, track_rewrites, _substitute from tinygrad.codegen.lowerer import get_contraction_with_reduce from tinygrad.codegen.symbolic import symbolic_simple from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, flatten, getenv, pluralize, ContextVar, Context, diskcache_put @@ -440,15 +440,6 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]: sched_sink = tensor_map[big_sink] type_verify(list(sched_sink.toposort), kernel_spec) - # map tensors to buffer/const, optionally apply a VIEW on top - becomes_map: dict[UOp, UOp] = {} - for k,v in tensor_map.items(): - if (kernel:=tensor_map.get(v.base)) is not None and kernel.base.op is Ops.ASSIGN: v = kernel.view(unwrap(v.st)) - if k is v: continue - op = v.base.op - if op in {Ops.BUFFER, Ops.ASSIGN}: becomes_map[k] = v - if op is Ops.CONST and all_int(v.shape): becomes_map[k] = v - # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign kernel_assign: dict[UOp, UOp] = {} assign_rep: dict[UOp, UOp] = {} @@ -461,14 +452,23 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]: raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER") assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,)) if assign_rep: - sched_sink = sched_sink.substitute(assign_rep) + tensor_map = graph_rewrite_map(tensor_map[big_sink], _substitute, assign_rep, bottom_up=True, input_map=tensor_map, name="fix_assign") + sched_sink = tensor_map[big_sink] type_verify(list(sched_sink.toposort), kernel_spec) - becomes_map[big_sink] = sched_sink # display the final graph if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph") if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Memory Graph") + # map tensors to buffer/const/assign, optionally apply a VIEW on top + becomes_map: dict[UOp, UOp] = {} + for k,v in tensor_map.items(): + if (kernel:=tensor_map.get(v.base)) is not None and kernel.base.op is Ops.ASSIGN: v = kernel.view(unwrap(v.st)) + if k is v: continue + op = v.base.op + if op in {Ops.BUFFER, Ops.ASSIGN}: becomes_map[k] = v + if op is Ops.CONST and all_int(v.shape): becomes_map[k] = v + # capture process replay if CAPTURE_PROCESS_REPLAY: with Context(PICKLE_BUFFERS=0): diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index ee59ac2bae..c3a1fd8157 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -36,7 +36,7 @@ pm_unbind = PatternMatcher([ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]: becomes_map = get_becomes_map(big_sink) - sched_sink = becomes_map.pop(big_sink) + sched_sink = UOp.sink(*[becomes_map.get(x,x) for x in big_sink.src]) # bfs toposort children: dict[UOp, list[UOp]] = {}