Tensor.assign is store+after [pr] (#15288)

* Tensor.assign is store+after [pr]

* put that back
This commit is contained in:
chenyu
2026-03-16 04:04:55 -04:00
committed by GitHub
parent 08662bc4ab
commit a0d1444790
9 changed files with 81 additions and 70 deletions

View File

@@ -29,12 +29,13 @@ def apply_after(ctx:AllocCtx, u:UOp):
while base.op is Ops.AFTER: base = base.src[0]
ctx.buffer_map[u] = base
# CONTIGUOUS and ASSIGN + parents are the only nodes that get updated
# CONTIGUOUS and AFTER+STORE + parents are the only nodes that get updated
add_tags = PatternMatcher([
(UPat(Ops.COPY, name="u"), disk_copy_is_buffer),
# no tag on copies that are assigned
(UPat(Ops.ASSIGN, src=(UPat(), UPat(Ops.COPY, name="c")), name="a"),
lambda a,c: a.replace(src=(a.src[0], c.rtag(())), tag=a.tag+c.tag) if a.tag and c.tag else None),
# no tag on copies that are assigned via STORE+AFTER — merge COPY tag into AFTER
(UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat(name="dest"), UPat(Ops.COPY, name="c")))), name="a"),
lambda a,c,dest: a.replace(src=(a.src[0], a.src[1].replace(src=(dest, c.rtag(())))), tag=a.tag+c.tag) if a.tag and c.tag else None),
(UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE)), name="x"), tag_uop),
(UPat(Ops.AFTER, name="u"), apply_after),
(UPat({Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"), tag_uop),
(UPat(GroupOp.All, name="x"), lambda ctx,x: tag_uop(ctx,x) if x in ctx.bases else None),
@@ -60,11 +61,10 @@ def replace_contig_with_store_after(u:UOp):
buf = _buffer_like(u)
return buf.after(buf.store(u.src[0])).rtag(u.tag)
def replace_assign_with_contig(u:UOp):
def replace_store_after_with_contig(u:UOp, src:UOp):
assigned_to = u
while assigned_to.op in {Ops.ASSIGN, Ops.BITCAST, Ops.AFTER}: assigned_to = assigned_to.src[0].base
if assigned_to.op is not Ops.BUFFER:
return u.src[1].contiguous(tag=u.tag)
if assigned_to.op is not Ops.BUFFER: return src.contiguous(tag=u.tag)
def contiguous_mops_to_view(c:UOp):
"""CONTIGUOUS(MOPS(BUFFER)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to a contiguous range."""
@@ -129,8 +129,8 @@ pm_early_transform_tensor_graph = PatternMatcher([
# remove extra CONTIGUOUS on ASSIGN/AFTER (only when target is contiguous)
(UPat(Ops.CONTIGUOUS, src=(UPat({Ops.ASSIGN, Ops.AFTER}, name="a"),), name="c"),
lambda a,c: a.replace(tag=(a.tag or ())+(c.tag or ())) if a.src[0].has_buffer_identity() else None),
# replace ASSIGN with CONTIGUOUS
(UPat(Ops.ASSIGN, name="u"), replace_assign_with_contig),
# replace AFTER+STORE with CONTIGUOUS when target is not a buffer
(UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat(), UPat(name="src")))), name="u"), replace_store_after_with_contig),
# replace CONTIGUOUS with STORE+AFTER
(UPat(Ops.CONTIGUOUS, name="u"), replace_contig_with_store_after),
# remove DETACH/CONTIGUOUS_BACKWARD (allows more contiguous removal)

View File

@@ -11,7 +11,7 @@ def add_to_ctx(ctx, x:UOp):
pm_ctx = PatternMatcher([
(UPat((Ops.BUFFER, Ops.BIND), name="x"), add_to_ctx),
(UPat((Ops.ASSIGN, Ops.CONTIGUOUS), name="x"),
(UPat((Ops.AFTER, Ops.CONTIGUOUS), name="x"),
lambda ctx,x: add_to_ctx(ctx,x) if not x.op_in_backward_slice_with_self(Ops.PARAM) else None),
# strip UNIQUE from unique consts — they don't need buffer identity inside function bodies
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="x"), lambda ctx,x: x.replace(src=(x.src[1],))),

View File

@@ -68,8 +68,8 @@ def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
# compute the target path (top down)
in_target_path: dict[UOp, bool] = {}
for u in root.toposort(): in_target_path[u] = any(x in targets or in_target_path[x] for x in u.src)
# don't flow through DETACH/ASSIGN or anything not in target path
return list(root.toposort(lambda node: node.op not in {Ops.DETACH, Ops.ASSIGN} and in_target_path[node]))
# don't flow through DETACH or anything not in target path
return list(root.toposort(lambda node: node.op is not Ops.DETACH and in_target_path[node]))
def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
grads = {root: root_grad}

View File

@@ -17,17 +17,13 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
for s in rb.src:
if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp):
# don't realize COPY/BUFFER_VIEW when they are the direct source of ASSIGN — the ASSIGN target buffer is the output
if x.op in {Ops.COPY, Ops.BUFFER_VIEW} and x in ctx \
and not buf.op_in_backward_slice_with_self(Ops.SHRINK, Ops.PERMUTE, Ops.FLIP, Ops.PAD):
del ctx[x]
def realize_store_after_src(ctx:dict[UOp, None], dest:UOp, src:UOp):
# don't realize COPY/BUFFER_VIEW when they are the direct source of STORE+AFTER — the target buffer is the output
if src.op in {Ops.COPY, Ops.BUFFER_VIEW} and src in ctx \
and not dest.op_in_backward_slice_with_self(Ops.SHRINK, Ops.PERMUTE, Ops.FLIP, Ops.PAD):
del ctx[src]
# you don't usually have to do this for assign unless there's a WAR hazard like TestAssign.test_assign_double_diamond_reduce
if buf.base in x.backward_slice_with_self: ctx[x] = None
def unrealize_store_src(ctx:dict[UOp, None], x:UOp):
"""Don't realize COPY/BUFFER_VIEW consumed by STORE inside AFTER — bufferize_to_store handles them."""
if x in ctx: del ctx[x]
if dest.base in src.backward_slice_with_self: ctx[src] = None
pm_generate_realize_map = PatternMatcher([
# always realize
@@ -36,10 +32,8 @@ pm_generate_realize_map = PatternMatcher([
(UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE)), allow_any_len=True, name="tr"), realize),
# realize srcs of these
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
# sometimes realize src of assign
(UPat(Ops.ASSIGN, src=(UPat.var("buf"), UPat.var("x"))), realize_assign_src),
# don't realize COPY/BUFFER_VIEW consumed by STORE inside AFTER (like realize_assign_src for ASSIGN)
(UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat(), UPat({Ops.COPY, Ops.BUFFER_VIEW}, name="x"))))), unrealize_store_src),
# sometimes realize/unrealize src of store+after
(UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat.var("dest"), UPat.var("src"))))), realize_store_after_src),
])
@dataclass(frozen=True)

View File

@@ -111,6 +111,8 @@ def assign_multi(dest:UOp, src:UOp):
if dest.axis != src.axis: raise RuntimeError(f"axis must match in assign {dest.axis} != {src.axis}")
return dest.src[0].assign(src.src[0]).multi(src.axis)
def store_after_multi(dest:UOp, src:UOp):
  # unwrap the MULTI wrapper on the source, express the write as STORE on the dest,
  # sequence it with AFTER, then re-wrap the result on the source's shard axis
  inner_store = dest.store(src.src[0])
  return dest.after(inner_store).multi(src.axis)
def passthrough_multi(root:UOp, multi:UOp):
  # rebuild root over the inner (per-shard) uops: the first src comes from multi,
  # and any remaining MULTI srcs are unwrapped; non-MULTI srcs pass through as-is.
  # the rebuilt op is then re-sharded on multi's axis.
  tail = tuple(s.src[0] if s.op is Ops.MULTI else s for s in root.src[1:])
  rebuilt = UOp(root.op, root.dtype, (multi.src[0],) + tail, root.arg)
  return rebuilt.multi(multi.axis)
@@ -136,6 +138,7 @@ multi_pm = PatternMatcher([
(UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi),
(UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
(UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
(UPat(Ops.AFTER, src=(UPat(Ops.MULTI), UPat(Ops.STORE, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))))), store_after_multi),
(UPat(Ops.COPY, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device"))), copy_multi),
(UPat(Ops.ALLREDUCE, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device")), name="red"),
lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)),

View File

@@ -45,41 +45,41 @@ def found_assign(ctx:dict[UOp, UOp], assign:UOp, src:UOp):
else: break
ctx[x] = assign
# *** fold moved ASSIGNs/AFTERs (hack for openpilot) ***
# *** fold moved AFTERs (hack for openpilot) ***
pm_fold_moved_assign = PatternMatcher([
(UPat(Ops.ASSIGN, src=(UPat(), UPat((*GroupOp.Movement, Ops.CAST), name="src")), name="assign"), found_assign),
(UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat(), UPat((*GroupOp.Movement, Ops.CAST), name="src")))), name="assign"), found_assign),
# replace ALU sources with assign versions found above
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
])
# movement op on INDEX as a PatternMatcher
# TODO: clean up .src[0]._shape is not None
pm_mops = PatternMatcher([
(UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)
if len(idx.src[1:]) == len(r.shape) else None),
if r.src[0]._shape is not None and len(idx.src[1:]) == len(r.shape) else None),
# move movement ops after AFTER (but not when AFTER has a raw STORE with shaped children — from replace_contig_with_store_after)
(UPat(GroupOp.Movement, name="r").after(name="a", allow_any_len=True),
lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg)
if not any(s.op is Ops.STORE and s.src[0]._shape is not None for s in a.src[1:]) else None),
if a.src[0]._shape is not None and not any(s.op is Ops.STORE and s.src[0]._shape is not None for s in a.src[1:]) else None),
(UPat(GroupOp.Movement, name="r").end(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:])),
])
# *****************
# 0. do some cleanup rewrites, mostly copied from the old stuff
def fix_assign_hazard(assign:UOp, target:UOp, src:UOp):
def fix_store_after_hazard(after:UOp, target:UOp, src:UOp):
# PERMUTE and FLIP reorder indices, SHRINK can have overlapping regions when dest is also shrunk
unsafe = {Ops.PERMUTE, Ops.FLIP} | ({Ops.SHRINK} if target.op_in_backward_slice_with_self(Ops.SHRINK) else set())
if any(s.op in unsafe and target.base in s.backward_slice for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS or s.op is Ops.AFTER)):
return assign.replace(src=(target, src.contiguous()))
return after.replace(src=(after.src[0], target.store(src.contiguous())))
def normalize_assign_target_chain(assign:UOp, target:UOp, src:UOp):
def normalize_store_after_target_chain(after:UOp, target:UOp, src:UOp):
root_target = target
while root_target.op in {Ops.ASSIGN, Ops.AFTER}: root_target = root_target.src[0]
while root_target.op is Ops.AFTER: root_target = root_target.src[0]
# when RHS depends on the previous assign result, break with contiguous
if target in src.toposort(): src = src.contiguous()
return assign.replace(src=(root_target, src))
return after.replace(src=(root_target, root_target.store(src)))
def split_reduceop(reduce:UOp, x:UOp):
if prod(reduce.shape) == 0: return None
@@ -166,21 +166,18 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
# copy only to different device
(UPat(Ops.COPY, src=(UPat.var("x"), UPat()), name="copy"), lambda x,copy: x.f(Ops.NOOP) if x.device == copy.device else None),
# ** assign rules **
# ** assign rules (STORE+AFTER) **
# collapse nested ASSIGN to the same buffer (e.g. __iadd__ in __setitem__)
(UPat(Ops.ASSIGN, src=(UPat(name="target"), UPat(Ops.ASSIGN, src=(UPat(name="target"), UPat()), name="src"))), lambda target, src: src),
# move bitcast from assign target to source: a.bitcast(X).assign(src) -> a.assign(src.bitcast(a.dtype))
(UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src"))),
lambda target, src: target.assign(src.bitcast(target.dtype))),
# if assign target is itself an ASSIGN/AFTER chain, canonicalize to the original buffer target
(UPat(Ops.ASSIGN, src=(UPat({Ops.ASSIGN, Ops.AFTER}, name="target"), UPat(name="src")), allow_any_len=True, name="assign"),
normalize_assign_target_chain),
# move bitcast from store+after target to source
(UPat(Ops.AFTER, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(Ops.STORE, src=(UPat(Ops.BITCAST), UPat(name="src"))))),
lambda target, src: target.after(target.store(src.bitcast(target.dtype)))),
# make source contiguous if it has hazardous movement ops on the dest buffer
(UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),
(UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat(name="target"), UPat(name="src")))), name="after"), fix_store_after_hazard),
# normalize target chain: walk through AFTERs to root, insert contiguous if needed
(UPat(Ops.AFTER, src=(UPat(Ops.AFTER, name="target"), UPat(Ops.STORE, src=(UPat(), UPat(name="src")))), name="after"),
lambda after, target, src: normalize_store_after_target_chain(after, target, src)),
# ** size 0 **
@@ -374,10 +371,18 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
sdtype = x.dtype.ptr(size=size, addrspace=x.arg.addrspace)
# AFTER: add END to the existing STORE, return buffer with kernel dependency
if x.src[0].op is Ops.AFTER:
buf = x.src[0].src[0].buf_uop.base
stores = [s for s in x.src[0].src[1:] if s.op is Ops.STORE]
return buf.after(*[s.end(*rngs) for s in stores]) if stores else buf
if (after:=x.src[0]).op is Ops.AFTER:
buf = after.src[0].buf_uop.base
if not (stores := [s for s in after.src[1:] if s.op is Ops.STORE and s.src[0].op is Ops.INDEX]): return buf
# BUFFERIZE(INDEX(...)); store through the underlying global index instead.
ended_stores = []
store_target = stores[0].src[0]
if store_target.src[0].op is Ops.BUFFERIZE and store_target.src[0].src[0].op is Ops.INDEX:
store_target = store_target.src[0].src[0]
if stores[0].src[1] is not store_target: # skip self-assign
end_rngs = sorted(dedup(tuple(store_target.ranges) + tuple(rngs)), key=lambda x: x.arg)
ended_stores.append(store_target.replace(dtype=sdtype).store(stores[0].src[1]).end(*end_rngs))
return buf.after(*ended_stores)
if (assign := x.src[0]).op is Ops.ASSIGN:
assign_target, assign_src = assign.src[0], assign.src[1]
assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index"

View File

@@ -308,20 +308,22 @@ class Tensor(OpMixin):
if is_disk:
self._buffer().copyin(x._data())
return self
# NOTE: assign_uop is created before AFTER embedding (uses original self.uop),
# but AFTER must be embedded before _apply_uop (so subsequent assigns see it)
assign_uop = self.uop.assign(x.uop)
# STORE+AFTER: STORE is the write effect (void), AFTER wraps the view for correct shape/ranging
store_uop = self.uop.store(x.uop)
base = self.uop.base
if base.op in {Ops.BUFFER, Ops.AFTER} and not self.uop.has_buffer_identity():
if base.op in {Ops.BUFFER, Ops.AFTER} and self.uop is not base and not self.uop.has_buffer_identity():
# view assign: inner AFTER(view, STORE) for correct shape/ranging, outer AFTER(base, inner) for dependency
original_uop = self.uop
assigned_base = base.after(assign_uop)
view_after = self.uop.after(store_uop)
assigned_base = base.after(view_after)
_apply_map_to_tensors({base: assigned_base}, name="Embed View Assign", walk=True)
def replace_view_base(u:UOp) -> UOp:
return u.replace(src=((assigned_base if u.src[0] is base else replace_view_base(u.src[0])),)+u.src[1:])
ret = Tensor(replace_view_base(original_uop), device=self.device, requires_grad=self.requires_grad)
self.replace(self._apply_uop(lambda *_: assign_uop, x))
self.replace(self._apply_uop(lambda *_: replace_view_base(original_uop), x))
return ret
return self.replace(self._apply_uop(lambda *_: assign_uop, x))
# simple assign: AFTER wraps self.uop (may be RESHAPE'd buffer) with STORE effect
return self.replace(self._apply_uop(lambda *_: self.uop.after(store_uop), x))
def detach(self) -> Tensor:
"""
@@ -1338,7 +1340,8 @@ class Tensor(OpMixin):
if (t:=tref()) is not None and t is not self and t.uop is not v_uop and t.uop not in v_bw):
raise RuntimeError("can't setitem on a tensor that already has other uses and requires grad")
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
if v.uop.op is Ops.ASSIGN: v = v._apply_uop(lambda x: x.src[1])
# __iadd__/__isub__ creates AFTER(view, STORE(view, computed)); unwrap to get the computed value
if v.uop.op is Ops.AFTER and any(s.op is Ops.STORE for s in v.uop.src[1:]): v = v._apply_uop(lambda x: x.src[1].src[1])
self.replace(self._getitem(indices, v))
return
idx = [indices] if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)) else list(indices)
@@ -1347,14 +1350,14 @@ class Tensor(OpMixin):
if is_disk: raise RuntimeError("advanced setitem is not supported for DISK tensors")
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
self.assign(self._getitem(indices, v))
elif is_disk or self.uop.is_realized or self.uop.base.op in (Ops.AFTER, Ops.BUFFER): # basic setitem, self is realized
elif is_disk or self.uop.is_realized or self.uop.base.op is Ops.BUFFER or self.uop._base_buffer_is_realized(): # basic setitem
view = self[indices]
if isinstance(v, Tensor) and v.uop.op is Ops.ASSIGN and v.uop in view.uop.base.src: return
if isinstance(v, Tensor) and v.uop.op is Ops.AFTER and v.uop in view.uop.base.src: return
view.assign(v)
else: # basic setitem, self is not realized
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
# __iadd__/__isub__ on unrealized views creates a no-op ASSIGN; unwrap to get the computed value
if v.uop.op is Ops.ASSIGN: v = v._apply_uop(lambda x: x.src[1])
# __iadd__/__isub__ creates AFTER(view, STORE(view, computed)); unwrap to get the computed value
if v.uop.op is Ops.AFTER and any(s.op is Ops.STORE for s in v.uop.src[1:]): v = v._apply_uop(lambda x: x.src[1].src[1])
self.replace(self._getitem(indices, v))
def __delitem__(self, indices) -> None:

View File

@@ -697,6 +697,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if self.op in {Ops.RESHAPE, Ops.MULTI}: return self.src[0].has_buffer_identity()
return self.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.PARAM}
def _base_buffer_is_realized(self) -> bool:
  """Return True if the buffer underlying this uop is realized (has allocated memory).

  Starts from the base uop and unwraps any chain of AFTER nodes to reach the
  underlying node, then reports that node's `is_realized`.
  """
  node = self.base
  while node.op is Ops.AFTER:
    node = node.src[0]
  return node.is_realized
@property
def buffer(self) -> Buffer|MultiBuffer:
from tinygrad.device import Buffer, MultiBuffer
@@ -1069,7 +1075,6 @@ class UPat(OpMixin):
def gep(self, i:int|None=None, **kwargs): return UPat(Ops.GEP, None, (self,), (i,) if i is not None else None, **kwargs)
def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)
def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, self.match_dtype, (self,)+src, **kwargs)
def assign(self, x:UPat, **kwargs): return UPat(Ops.ASSIGN, self.match_dtype, (self,x), **kwargs)
def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.match_dtype, src=(self,)+src, **kwargs)
def broadcast(self, **kwargs): return UPat(Ops.VECTORIZE, self.match_dtype, src=self, **kwargs)
def contiguous(self, *args, **kwargs): return UPat(Ops.CONTIGUOUS, dtype=self.match_dtype, src=(self,)+args, **kwargs)
@@ -1516,7 +1521,7 @@ def render_marg(ctx,x:UOp):
return f"({','.join(pieces)})" if len(pieces) != 1 else f"({pieces[0]},)"
sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER, Ops.THREEFRY,
Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER, Ops.ASSIGN, Ops.DETACH}
Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER, Ops.DETACH}
pm_pyrender_extra = PatternMatcher([
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"),
lambda x,u,d: f"UOp.unique_const({x.dtype}, {x.arg}, device={repr(d.arg)}, unique={u.arg})"),
@@ -1584,7 +1589,7 @@ def pyrender(ast:UOp) -> str:
cmap = consumer_map_from_toposort(lst)
not_rendered = {Ops.CONST, Ops.VCONST, Ops.DEVICE}
always_rendered = {Ops.PARAM, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.VECTORIZE,
Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.WHERE, Ops.END, Ops.ASSIGN}
Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.WHERE, Ops.END}
to_render: set[UOp] = {ast}
for u in lst:

View File

@@ -77,8 +77,9 @@ movement_ops = PatternMatcher([
(UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.index), lambda: True),
(UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True),
# AFTER on Movement Op or ASSIGN
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS, Ops.ASSIGN, Ops.BUFFER})),), allow_any_len=True), lambda: True),
# AFTER on Movement Op, BUFFER, COPY, or BITCAST
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS, Ops.BUFFER, Ops.BITCAST, Ops.COPY})),), allow_any_len=True),
lambda: True),
])
_tensor_spec = PatternMatcher([
@@ -96,8 +97,8 @@ _tensor_spec = PatternMatcher([
# KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
(UPat(Ops.CALL, src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
# ASSIGN has a target and a value. It can also optionally depend on other assigns
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
# ASSIGN is used internally by allreduce for precompiled function bodies
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2),
# MSELECT chooses one of the multi buffers
(UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)),