mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
don't emit assign in transform_precompiled_call [pr] (#15262)
This commit is contained in:
@@ -95,7 +95,8 @@ def transform_precompiled_call(c:UOp) -> UOp|None:
|
||||
if c.src[0].op is Ops.SINK: return None
|
||||
out = _buffer_like(c)
|
||||
input_buffers = tuple(x.contiguous() if x.op not in {Ops.AFTER, Ops.BIND} else x for x in c.src[1:])
|
||||
fxn = out.param_like(len(c.src)-1).assign(c.src[0]).sink()
|
||||
target = out.param_like(len(c.src)-1).shrink_to(c.shape)
|
||||
fxn = target.after(target.store(c.src[0])).sink()
|
||||
ret = out.after(c.replace(src=(fxn, *input_buffers, out), dtype=dtypes.void, tag=None))
|
||||
# if the CALL has symbolic shapes, shrink the max-sized output to the actual symbolic shape
|
||||
if any(isinstance(s, UOp) for s in c.shape): ret = ret.shrink(tuple((0, s) for s in c.shape))
|
||||
|
||||
@@ -27,7 +27,9 @@ def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp):
|
||||
|
||||
pm_generate_realize_map = PatternMatcher([
|
||||
# always realize
|
||||
(UPat({Ops.COPY, Ops.CONTIGUOUS, Ops.STORE, Ops.ASSIGN}, name="tr"), realize),
|
||||
(UPat({Ops.COPY, Ops.CONTIGUOUS, Ops.ASSIGN}, name="tr"), realize),
|
||||
# realize AFTER of STORE+AFTER
|
||||
(UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE)), allow_any_len=True, name="tr"), realize),
|
||||
# realize srcs of these
|
||||
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
|
||||
# sometimes realize src of assign
|
||||
@@ -58,7 +60,13 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
|
||||
new_srcs = []
|
||||
for s in x.src:
|
||||
new_src = s
|
||||
if s.op in {Ops.PARAM, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT, Ops.AFTER}:
|
||||
# TODO: this STORE+AFTER is very explicit, AFTER is the one being realized, and STORE needs to end ranges
|
||||
if x.op is Ops.AFTER and s.op is Ops.STORE and x in ctx.realize_map:
|
||||
realized_ranges = ctx.realize_map[x]
|
||||
assert isinstance(realized_ranges, list), "realize map must contain range list"
|
||||
closed_ranges = tuple([r for i,r in enumerate(ctx.range_map[x][1]) if i in realized_ranges])
|
||||
new_src = s.end(*[r for r in closed_ranges if r.op is Ops.RANGE])
|
||||
elif s.op in {Ops.PARAM, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT, Ops.AFTER}:
|
||||
if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
|
||||
elif s in ctx.realize_map:
|
||||
realized_ranges = ctx.realize_map[s]
|
||||
@@ -163,8 +171,8 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
||||
# no ranges on kernels, they are internal
|
||||
if x.op in {Ops.CALL, Ops.LINEAR}: continue
|
||||
|
||||
# no range on after
|
||||
if x.op is Ops.AFTER: continue
|
||||
# only STORE+AFTER has range
|
||||
if x.op is Ops.AFTER and all(s.op is not Ops.STORE for s in x.src[1:]): continue
|
||||
|
||||
# treat MSTACK/MSELECT like SINK
|
||||
if x.op in {Ops.MSTACK, Ops.MSELECT}: continue
|
||||
|
||||
@@ -141,13 +141,10 @@ multi_pm = PatternMatcher([
|
||||
lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)),
|
||||
# rewrite into calls explicitly for MULTI
|
||||
(UPat(Ops.CALL, name="call"), rewrite_into_call),
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
|
||||
(UPat((Ops.CALL, Ops.AFTER, Ops.STORE), src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
|
||||
# we just remove the MULTI from CALLs with dtypes.void and assume they are handled by the user for custom kernels
|
||||
(UPat(Ops.CALL, dtype=dtypes.void, name="root", custom_early_reject=set([Ops.MULTI])), lambda root:
|
||||
UOp(root.op, root.dtype, tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src), root.arg)),
|
||||
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
|
||||
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
|
||||
# after CALL
|
||||
(UPat(Ops.AFTER, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.CALL)), name="a"),
|
||||
lambda multi,a: a.replace(src=(multi.src[0],)+a.src[1:]).multi(multi.axis)),
|
||||
])+replace_allreduce
|
||||
|
||||
@@ -96,6 +96,9 @@ _tensor_spec = PatternMatcher([
|
||||
# ASSIGN has a target and a value. It can also optionally depend on other assigns
|
||||
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
|
||||
|
||||
# STORE in tensor graph: store a value into a target
|
||||
(UPat(Ops.STORE, dtypes.void, (UPat(), UPat())), lambda: True),
|
||||
|
||||
# MSELECT chooses one of the multi buffers
|
||||
(UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)),
|
||||
|
||||
|
||||
Reference in New Issue
Block a user