mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix diamond assigns before mapping tensor UOps to assigns (#9855)
* keep tensor_map until diamond assign fixup
* ctx
This commit is contained in:
@@ -34,7 +34,8 @@ class ProcessReplayWarning(Warning): pass
|
||||
# *** recreators
|
||||
|
||||
def recreate_sched(big_sink:UOp) -> list[UOp]:
|
||||
sched_sink = get_becomes_map(big_sink)[big_sink]
|
||||
becomes_map = get_becomes_map(big_sink)
|
||||
sched_sink = UOp.sink(*[becomes_map.get(x,x) for x in big_sink.src])
|
||||
return dedup(u.arg.ast for u in sched_sink.toposort if u.op is Ops.KERNEL)
|
||||
|
||||
def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:list[Opt], name:str, _) -> str:
|
||||
|
||||
@@ -1950,7 +1950,7 @@ class TestSwizzle(unittest.TestCase):
|
||||
y = x*x.sum((1,)).reciprocal()
|
||||
t = y.pad(((0,1),None)).contiguous()
|
||||
swizzled = swizzle_rewrite(t.lazydata)
|
||||
sched = check_schedule(swizzled.sink(), 3)
|
||||
sched = check_schedule(swizzled, 3)
|
||||
output_buffer = sched[-1].bufs[0]
|
||||
run_schedule(sched)
|
||||
self.assertListEqual(output_buffer.as_buffer().cast("f").tolist(), [0.5, 0.5, 0.5, 0.5, 0., 0.])
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve, merge_views
|
||||
from tinygrad.ops import can_pad, sint, track_rewrites
|
||||
from tinygrad.ops import can_pad, sint, track_rewrites, _substitute
|
||||
from tinygrad.codegen.lowerer import get_contraction_with_reduce
|
||||
from tinygrad.codegen.symbolic import symbolic_simple
|
||||
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, flatten, getenv, pluralize, ContextVar, Context, diskcache_put
|
||||
@@ -440,15 +440,6 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
|
||||
sched_sink = tensor_map[big_sink]
|
||||
type_verify(list(sched_sink.toposort), kernel_spec)
|
||||
|
||||
# map tensors to buffer/const, optionally apply a VIEW on top
|
||||
becomes_map: dict[UOp, UOp] = {}
|
||||
for k,v in tensor_map.items():
|
||||
if (kernel:=tensor_map.get(v.base)) is not None and kernel.base.op is Ops.ASSIGN: v = kernel.view(unwrap(v.st))
|
||||
if k is v: continue
|
||||
op = v.base.op
|
||||
if op in {Ops.BUFFER, Ops.ASSIGN}: becomes_map[k] = v
|
||||
if op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
|
||||
|
||||
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
|
||||
kernel_assign: dict[UOp, UOp] = {}
|
||||
assign_rep: dict[UOp, UOp] = {}
|
||||
@@ -461,14 +452,23 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
|
||||
raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
|
||||
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
|
||||
if assign_rep:
|
||||
sched_sink = sched_sink.substitute(assign_rep)
|
||||
tensor_map = graph_rewrite_map(tensor_map[big_sink], _substitute, assign_rep, bottom_up=True, input_map=tensor_map, name="fix_assign")
|
||||
sched_sink = tensor_map[big_sink]
|
||||
type_verify(list(sched_sink.toposort), kernel_spec)
|
||||
becomes_map[big_sink] = sched_sink
|
||||
|
||||
# display the final graph
|
||||
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
|
||||
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Memory Graph")
|
||||
|
||||
# map tensors to buffer/const/assign, optionally apply a VIEW on top
|
||||
becomes_map: dict[UOp, UOp] = {}
|
||||
for k,v in tensor_map.items():
|
||||
if (kernel:=tensor_map.get(v.base)) is not None and kernel.base.op is Ops.ASSIGN: v = kernel.view(unwrap(v.st))
|
||||
if k is v: continue
|
||||
op = v.base.op
|
||||
if op in {Ops.BUFFER, Ops.ASSIGN}: becomes_map[k] = v
|
||||
if op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
|
||||
|
||||
# capture process replay
|
||||
if CAPTURE_PROCESS_REPLAY:
|
||||
with Context(PICKLE_BUFFERS=0):
|
||||
|
||||
@@ -36,7 +36,7 @@ pm_unbind = PatternMatcher([
|
||||
|
||||
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
|
||||
becomes_map = get_becomes_map(big_sink)
|
||||
sched_sink = becomes_map.pop(big_sink)
|
||||
sched_sink = UOp.sink(*[becomes_map.get(x,x) for x in big_sink.src])
|
||||
|
||||
# bfs toposort
|
||||
children: dict[UOp, list[UOp]] = {}
|
||||
|
||||
Reference in New Issue
Block a user