From b3600e4774e45bc7796c5cfc0e8fa38c16bb8eb5 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 13 Mar 2026 22:42:35 -0400 Subject: [PATCH] don't emit assign in transform_precompiled_call [pr] (#15262) --- tinygrad/engine/allocations.py | 3 ++- tinygrad/schedule/indexing.py | 16 ++++++++++++---- tinygrad/schedule/multi.py | 5 +---- tinygrad/uop/spec.py | 3 +++ 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/tinygrad/engine/allocations.py b/tinygrad/engine/allocations.py index 36307dd2e9..0ce15dfb60 100644 --- a/tinygrad/engine/allocations.py +++ b/tinygrad/engine/allocations.py @@ -95,7 +95,8 @@ def transform_precompiled_call(c:UOp) -> UOp|None: if c.src[0].op is Ops.SINK: return None out = _buffer_like(c) input_buffers = tuple(x.contiguous() if x.op not in {Ops.AFTER, Ops.BIND} else x for x in c.src[1:]) - fxn = out.param_like(len(c.src)-1).assign(c.src[0]).sink() + target = out.param_like(len(c.src)-1).shrink_to(c.shape) + fxn = target.after(target.store(c.src[0])).sink() ret = out.after(c.replace(src=(fxn, *input_buffers, out), dtype=dtypes.void, tag=None)) # if the CALL has symbolic shapes, shrink the max-sized output to the actual symbolic shape if any(isinstance(s, UOp) for s in c.shape): ret = ret.shrink(tuple((0, s) for s in c.shape)) diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index 22e54a94bc..7d9438c5af 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -27,7 +27,9 @@ def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp): pm_generate_realize_map = PatternMatcher([ # always realize - (UPat({Ops.COPY, Ops.CONTIGUOUS, Ops.STORE, Ops.ASSIGN}, name="tr"), realize), + (UPat({Ops.COPY, Ops.CONTIGUOUS, Ops.ASSIGN}, name="tr"), realize), + # realize AFTER of STORE+AFTER + (UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE)), allow_any_len=True, name="tr"), realize), # realize srcs of these (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs), # sometimes realize src of assign @@ -58,7 +60,13 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp): new_srcs = [] for s in x.src: new_src = s - if s.op in {Ops.PARAM, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT, Ops.AFTER}: + # TODO: this STORE+AFTER is very explicit, AFTER is the one being realized, and STORE needs to end ranges + if x.op is Ops.AFTER and s.op is Ops.STORE and x in ctx.realize_map: + realized_ranges = ctx.realize_map[x] + assert isinstance(realized_ranges, list), "realize map must contain range list" + closed_ranges = tuple([r for i,r in enumerate(ctx.range_map[x][1]) if i in realized_ranges]) + new_src = s.end(*[r for r in closed_ranges if r.op is Ops.RANGE]) + elif s.op in {Ops.PARAM, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT, Ops.AFTER}: if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0]) elif s in ctx.realize_map: realized_ranges = ctx.realize_map[s] @@ -163,8 +171,8 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]: # no ranges on kernels, they are internal if x.op in {Ops.CALL, Ops.LINEAR}: continue - # no range on after - if x.op is Ops.AFTER: continue + # only STORE+AFTER has range + if x.op is Ops.AFTER and all(s.op is not Ops.STORE for s in x.src[1:]): continue # treat MSTACK/MSELECT like SINK if x.op in {Ops.MSTACK, Ops.MSELECT}: continue diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py index 2a97d75733..5d5413e8c3 100644 --- a/tinygrad/schedule/multi.py +++ b/tinygrad/schedule/multi.py @@ -141,13 +141,10 @@ multi_pm = PatternMatcher([ lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)), # rewrite into calls explicitly for MULTI (UPat(Ops.CALL, name="call"), rewrite_into_call), - (UPat(Ops.CALL, src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi), + (UPat((Ops.CALL, Ops.AFTER, Ops.STORE), src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi), # we just remove the MULTI from CALLs with dtypes.void and assume they are handled by the user for custom kernels (UPat(Ops.CALL, dtype=dtypes.void, name="root", custom_early_reject=set([Ops.MULTI])), lambda root: UOp(root.op, root.dtype, tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src), root.arg)), (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi), - # after CALL - (UPat(Ops.AFTER, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.CALL)), name="a"), - lambda multi,a: a.replace(src=(multi.src[0],)+a.src[1:]).multi(multi.axis)), ])+replace_allreduce diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 62bcf056a1..839e4992d7 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -96,6 +96,9 @@ _tensor_spec = PatternMatcher([ # ASSIGN has a target and a value. It can also optionally depend on other assigns (UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])), + # STORE in tensor graph: store a value into a target + (UPat(Ops.STORE, dtypes.void, (UPat(), UPat())), lambda: True), + # MSELECT chooses one of the multi buffers (UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)),