fix diamond assigns before mapping tensor UOps to assigns (#9855)

* keep tensor_map until diamond assign fixup

* ctx
This commit is contained in:
qazal
2025-04-18 14:17:43 +03:00
committed by GitHub
parent a37d921917
commit b58decac0c
4 changed files with 16 additions and 15 deletions

View File

@@ -34,7 +34,8 @@ class ProcessReplayWarning(Warning): pass
# *** recreators
def recreate_sched(big_sink:UOp) -> list[UOp]:
sched_sink = get_becomes_map(big_sink)[big_sink]
becomes_map = get_becomes_map(big_sink)
sched_sink = UOp.sink(*[becomes_map.get(x,x) for x in big_sink.src])
return dedup(u.arg.ast for u in sched_sink.toposort if u.op is Ops.KERNEL)
def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:list[Opt], name:str, _) -> str:

View File

@@ -1950,7 +1950,7 @@ class TestSwizzle(unittest.TestCase):
y = x*x.sum((1,)).reciprocal()
t = y.pad(((0,1),None)).contiguous()
swizzled = swizzle_rewrite(t.lazydata)
sched = check_schedule(swizzled.sink(), 3)
sched = check_schedule(swizzled, 3)
output_buffer = sched[-1].bufs[0]
run_schedule(sched)
self.assertListEqual(output_buffer.as_buffer().cast("f").tolist(), [0.5, 0.5, 0.5, 0.5, 0., 0.])

View File

@@ -1,7 +1,7 @@
from collections import defaultdict, deque
from dataclasses import dataclass
from tinygrad.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve, merge_views
from tinygrad.ops import can_pad, sint, track_rewrites
from tinygrad.ops import can_pad, sint, track_rewrites, _substitute
from tinygrad.codegen.lowerer import get_contraction_with_reduce
from tinygrad.codegen.symbolic import symbolic_simple
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, flatten, getenv, pluralize, ContextVar, Context, diskcache_put
@@ -440,15 +440,6 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
sched_sink = tensor_map[big_sink]
type_verify(list(sched_sink.toposort), kernel_spec)
# map tensors to buffer/const, optionally apply a VIEW on top
becomes_map: dict[UOp, UOp] = {}
for k,v in tensor_map.items():
if (kernel:=tensor_map.get(v.base)) is not None and kernel.base.op is Ops.ASSIGN: v = kernel.view(unwrap(v.st))
if k is v: continue
op = v.base.op
if op in {Ops.BUFFER, Ops.ASSIGN}: becomes_map[k] = v
if op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
kernel_assign: dict[UOp, UOp] = {}
assign_rep: dict[UOp, UOp] = {}
@@ -461,14 +452,23 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
if assign_rep:
sched_sink = sched_sink.substitute(assign_rep)
tensor_map = graph_rewrite_map(tensor_map[big_sink], _substitute, assign_rep, bottom_up=True, input_map=tensor_map, name="fix_assign")
sched_sink = tensor_map[big_sink]
type_verify(list(sched_sink.toposort), kernel_spec)
becomes_map[big_sink] = sched_sink
# display the final graph
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Memory Graph")
# map tensors to buffer/const/assign, optionally apply a VIEW on top
becomes_map: dict[UOp, UOp] = {}
for k,v in tensor_map.items():
if (kernel:=tensor_map.get(v.base)) is not None and kernel.base.op is Ops.ASSIGN: v = kernel.view(unwrap(v.st))
if k is v: continue
op = v.base.op
if op in {Ops.BUFFER, Ops.ASSIGN}: becomes_map[k] = v
if op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
# capture process replay
if CAPTURE_PROCESS_REPLAY:
with Context(PICKLE_BUFFERS=0):

View File

@@ -36,7 +36,7 @@ pm_unbind = PatternMatcher([
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
becomes_map = get_becomes_map(big_sink)
sched_sink = becomes_map.pop(big_sink)
sched_sink = UOp.sink(*[becomes_map.get(x,x) for x in big_sink.src])
# bfs toposort
children: dict[UOp, list[UOp]] = {}