don't emit assign in transform_precompiled_call [pr] (#15262)

This commit is contained in:
chenyu
2026-03-13 22:42:35 -04:00
committed by GitHub
parent 4d60312f7f
commit b3600e4774
4 changed files with 18 additions and 9 deletions

View File

@@ -95,7 +95,8 @@ def transform_precompiled_call(c:UOp) -> UOp|None:
if c.src[0].op is Ops.SINK: return None
out = _buffer_like(c)
input_buffers = tuple(x.contiguous() if x.op not in {Ops.AFTER, Ops.BIND} else x for x in c.src[1:])
fxn = out.param_like(len(c.src)-1).assign(c.src[0]).sink()
target = out.param_like(len(c.src)-1).shrink_to(c.shape)
fxn = target.after(target.store(c.src[0])).sink()
ret = out.after(c.replace(src=(fxn, *input_buffers, out), dtype=dtypes.void, tag=None))
# if the CALL has symbolic shapes, shrink the max-sized output to the actual symbolic shape
if any(isinstance(s, UOp) for s in c.shape): ret = ret.shrink(tuple((0, s) for s in c.shape))

View File

@@ -27,7 +27,9 @@ def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp):
pm_generate_realize_map = PatternMatcher([
# always realize
(UPat({Ops.COPY, Ops.CONTIGUOUS, Ops.STORE, Ops.ASSIGN}, name="tr"), realize),
(UPat({Ops.COPY, Ops.CONTIGUOUS, Ops.ASSIGN}, name="tr"), realize),
# realize an AFTER whose second source is a STORE (the STORE+AFTER pair)
(UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE)), allow_any_len=True, name="tr"), realize),
# realize srcs of these
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
# sometimes realize src of assign
@@ -58,7 +60,13 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
new_srcs = []
for s in x.src:
new_src = s
if s.op in {Ops.PARAM, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT, Ops.AFTER}:
# TODO: this STORE+AFTER handling is very explicit: the AFTER is the one being realized, and the STORE needs to end the ranges
if x.op is Ops.AFTER and s.op is Ops.STORE and x in ctx.realize_map:
realized_ranges = ctx.realize_map[x]
assert isinstance(realized_ranges, list), "realize map must contain range list"
closed_ranges = tuple([r for i,r in enumerate(ctx.range_map[x][1]) if i in realized_ranges])
new_src = s.end(*[r for r in closed_ranges if r.op is Ops.RANGE])
elif s.op in {Ops.PARAM, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT, Ops.AFTER}:
if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
elif s in ctx.realize_map:
realized_ranges = ctx.realize_map[s]
@@ -163,8 +171,8 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
# no ranges on kernels, they are internal
if x.op in {Ops.CALL, Ops.LINEAR}: continue
# no range on after
if x.op is Ops.AFTER: continue
# only an AFTER with a STORE among src[1:] has ranges; every other AFTER is skipped
if x.op is Ops.AFTER and all(s.op is not Ops.STORE for s in x.src[1:]): continue
# treat MSTACK/MSELECT like SINK
if x.op in {Ops.MSTACK, Ops.MSELECT}: continue

View File

@@ -141,13 +141,10 @@ multi_pm = PatternMatcher([
lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)),
# rewrite into calls explicitly for MULTI
(UPat(Ops.CALL, name="call"), rewrite_into_call),
(UPat(Ops.CALL, src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
(UPat((Ops.CALL, Ops.AFTER, Ops.STORE), src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
# we just remove the MULTI from CALLs with dtypes.void and assume they are handled by the user for custom kernels
(UPat(Ops.CALL, dtype=dtypes.void, name="root", custom_early_reject=set([Ops.MULTI])), lambda root:
UOp(root.op, root.dtype, tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src), root.arg)),
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
# after CALL
(UPat(Ops.AFTER, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.CALL)), name="a"),
lambda multi,a: a.replace(src=(multi.src[0],)+a.src[1:]).multi(multi.axis)),
])+replace_allreduce

View File

@@ -96,6 +96,9 @@ _tensor_spec = PatternMatcher([
# ASSIGN has a target and a value. It can also optionally depend on other assigns
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
# STORE in tensor graph: store a value into a target
(UPat(Ops.STORE, dtypes.void, (UPat(), UPat())), lambda: True),
# MSELECT chooses one of the multi buffers
(UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)),