diff --git a/tinygrad/engine/allocations.py b/tinygrad/engine/allocations.py
index 4a0cbc330d..249841339e 100644
--- a/tinygrad/engine/allocations.py
+++ b/tinygrad/engine/allocations.py
@@ -29,12 +29,13 @@ def apply_after(ctx:AllocCtx, u:UOp):
   while base.op is Ops.AFTER: base = base.src[0]
   ctx.buffer_map[u] = base
 
-# CONTIGUOUS and ASSIGN + parents are the only nodes that get updated
+# CONTIGUOUS and AFTER+STORE + parents are the only nodes that get updated
 add_tags = PatternMatcher([
   (UPat(Ops.COPY, name="u"), disk_copy_is_buffer),
-  # no tag on copies that are assigned
-  (UPat(Ops.ASSIGN, src=(UPat(), UPat(Ops.COPY, name="c")), name="a"),
-   lambda a,c: a.replace(src=(a.src[0], c.rtag(())), tag=a.tag+c.tag) if a.tag and c.tag else None),
+  # no tag on copies that are assigned via STORE+AFTER — merge COPY tag into AFTER
+  (UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat(name="dest"), UPat(Ops.COPY, name="c")))), name="a"),
+   lambda a,c,dest: a.replace(src=(a.src[0], a.src[1].replace(src=(dest, c.rtag(())))), tag=a.tag+c.tag) if a.tag and c.tag else None),
+  (UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE)), name="x"), tag_uop),
   (UPat(Ops.AFTER, name="u"), apply_after),
   (UPat({Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"), tag_uop),
   (UPat(GroupOp.All, name="x"), lambda ctx,x: tag_uop(ctx,x) if x in ctx.bases else None),
@@ -60,11 +61,10 @@ def replace_contig_with_store_after(u:UOp):
   buf = _buffer_like(u)
   return buf.after(buf.store(u.src[0])).rtag(u.tag)
 
-def replace_assign_with_contig(u:UOp):
+def replace_store_after_with_contig(u:UOp, src:UOp):
   assigned_to = u
   while assigned_to.op in {Ops.ASSIGN, Ops.BITCAST, Ops.AFTER}: assigned_to = assigned_to.src[0].base
-  if assigned_to.op is not Ops.BUFFER:
-    return u.src[1].contiguous(tag=u.tag)
+  if assigned_to.op is not Ops.BUFFER: return src.contiguous(tag=u.tag)
 
 def contiguous_mops_to_view(c:UOp):
   """CONTIGUOUS(MOPS(BUFFER)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to a contiguous range."""
@@ -129,8 +129,8 @@ pm_early_transform_tensor_graph = PatternMatcher([
   # remove extra CONTIGUOUS on ASSIGN/AFTER (only when target is contiguous)
   (UPat(Ops.CONTIGUOUS, src=(UPat({Ops.ASSIGN, Ops.AFTER}, name="a"),), name="c"),
    lambda a,c: a.replace(tag=(a.tag or ())+(c.tag or ())) if a.src[0].has_buffer_identity() else None),
-  # replace ASSIGN with CONTIGUOUS
-  (UPat(Ops.ASSIGN, name="u"), replace_assign_with_contig),
+  # replace AFTER+STORE with CONTIGUOUS when target is not a buffer
+  (UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat(), UPat(name="src")))), name="u"), replace_store_after_with_contig),
   # replace CONTIGUOUS with STORE+AFTER
   (UPat(Ops.CONTIGUOUS, name="u"), replace_contig_with_store_after),
   # remove DETACH/CONTIGUOUS_BACKWARD (allows more contiguous removal)
diff --git a/tinygrad/function.py b/tinygrad/function.py
index 6bf3c88706..94a3100b38 100644
--- a/tinygrad/function.py
+++ b/tinygrad/function.py
@@ -11,7 +11,7 @@ def add_to_ctx(ctx, x:UOp):
 
 pm_ctx = PatternMatcher([
   (UPat((Ops.BUFFER, Ops.BIND), name="x"), add_to_ctx),
-  (UPat((Ops.ASSIGN, Ops.CONTIGUOUS), name="x"),
+  (UPat((Ops.AFTER, Ops.CONTIGUOUS), name="x"),
    lambda ctx,x: add_to_ctx(ctx,x) if not x.op_in_backward_slice_with_self(Ops.PARAM) else None),
   # strip UNIQUE from unique consts — they don't need buffer identity inside function bodies
   (UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="x"), lambda ctx,x: x.replace(src=(x.src[1],))),
diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py
index 6e4097bd93..af975bd81f 100644
--- a/tinygrad/gradient.py
+++ b/tinygrad/gradient.py
@@ -68,8 +68,8 @@ def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
   # compute the target path (top down)
   in_target_path: dict[UOp, bool] = {}
   for u in root.toposort(): in_target_path[u] = any(x in targets or in_target_path[x] for x in u.src)
-  # don't flow through DETACH/ASSIGN or anything not in target path
-  return list(root.toposort(lambda node: node.op not in {Ops.DETACH, Ops.ASSIGN} and in_target_path[node]))
+  # don't flow through DETACH or anything not in target path
+  return list(root.toposort(lambda node: node.op is not Ops.DETACH and in_target_path[node]))
 
 def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
   grads = {root: root_grad}
diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py
index 8e2fcb96bb..8a7f001957 100644
--- a/tinygrad/schedule/indexing.py
+++ b/tinygrad/schedule/indexing.py
@@ -17,17 +17,13 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
   for s in rb.src:
     if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
 
-def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp):
-  # don't realize COPY/BUFFER_VIEW when they are the direct source of ASSIGN — the ASSIGN target buffer is the output
-  if x.op in {Ops.COPY, Ops.BUFFER_VIEW} and x in ctx \
-      and not buf.op_in_backward_slice_with_self(Ops.SHRINK, Ops.PERMUTE, Ops.FLIP, Ops.PAD):
-    del ctx[x]
+def realize_store_after_src(ctx:dict[UOp, None], dest:UOp, src:UOp):
+  # don't realize COPY/BUFFER_VIEW when they are the direct source of STORE+AFTER — the target buffer is the output
+  if src.op in {Ops.COPY, Ops.BUFFER_VIEW} and src in ctx \
+      and not dest.op_in_backward_slice_with_self(Ops.SHRINK, Ops.PERMUTE, Ops.FLIP, Ops.PAD):
+    del ctx[src]
   # you don't usually have to do this for assign unless there's a WAR hazard like TestAssign.test_assign_double_diamond_reduce
-  if buf.base in x.backward_slice_with_self: ctx[x] = None
-
-def unrealize_store_src(ctx:dict[UOp, None], x:UOp):
-  """Don't realize COPY/BUFFER_VIEW consumed by STORE inside AFTER — bufferize_to_store handles them."""
-  if x in ctx: del ctx[x]
+  if dest.base in src.backward_slice_with_self: ctx[src] = None
 
 pm_generate_realize_map = PatternMatcher([
   # always realize
@@ -36,10 +32,8 @@ pm_generate_realize_map = PatternMatcher([
   (UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE)), allow_any_len=True, name="tr"), realize),
   # realize srcs of these
   (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
-  # sometimes realize src of assign
-  (UPat(Ops.ASSIGN, src=(UPat.var("buf"), UPat.var("x"))), realize_assign_src),
-  # don't realize COPY/BUFFER_VIEW consumed by STORE inside AFTER (like realize_assign_src for ASSIGN)
-  (UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat(), UPat({Ops.COPY, Ops.BUFFER_VIEW}, name="x"))))), unrealize_store_src),
+  # sometimes realize/unrealize src of store+after
+  (UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat.var("dest"), UPat.var("src"))))), realize_store_after_src),
 ])
 
 @dataclass(frozen=True)
diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py
index 5d5413e8c3..a54f46f574 100644
--- a/tinygrad/schedule/multi.py
+++ b/tinygrad/schedule/multi.py
@@ -111,6 +111,8 @@ def assign_multi(dest:UOp, src:UOp):
   if dest.axis != src.axis: raise RuntimeError(f"axis must match in assign {dest.axis} != {src.axis}")
   return dest.src[0].assign(src.src[0]).multi(src.axis)
 
+def store_after_multi(dest:UOp, src:UOp): return dest.after(dest.store(src.src[0])).multi(src.axis)
+
 def passthrough_multi(root:UOp, multi:UOp): return UOp(root.op, root.dtype,
   (multi.src[0],)+tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src[1:]), root.arg).multi(multi.axis)
 
@@ -136,6 +138,7 @@ multi_pm = PatternMatcher([
   (UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi),
   (UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
   (UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
+  (UPat(Ops.AFTER, src=(UPat(Ops.MULTI), UPat(Ops.STORE, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))))), store_after_multi),
   (UPat(Ops.COPY, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device"))), copy_multi),
   (UPat(Ops.ALLREDUCE, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device")), name="red"),
    lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)),
diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index d2dc748ec3..ac672571b9 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -45,41 +45,41 @@ def found_assign(ctx:dict[UOp, UOp], assign:UOp, src:UOp):
     else: break
   ctx[x] = assign
 
-# *** fold moved ASSIGNs/AFTERs (hack for openpilot) ***
+# *** fold moved AFTERs (hack for openpilot) ***
 
 pm_fold_moved_assign = PatternMatcher([
-  (UPat(Ops.ASSIGN, src=(UPat(), UPat((*GroupOp.Movement, Ops.CAST), name="src")), name="assign"), found_assign),
   (UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat(), UPat((*GroupOp.Movement, Ops.CAST), name="src")))), name="assign"), found_assign),
   # replace ALU sources with assign versions found above
   (UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
 ])
 
 # movement op on INDEX as a PatternMatcher
+# TODO: clean up .src[0]._shape is not None
 pm_mops = PatternMatcher([
   (UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
    lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)
-   if len(idx.src[1:]) == len(r.shape) else None),
+   if r.src[0]._shape is not None and len(idx.src[1:]) == len(r.shape) else None),
   # move movement ops after AFTER (but not when AFTER has a raw STORE with shaped children — from replace_contig_with_store_after)
   (UPat(GroupOp.Movement, name="r").after(name="a", allow_any_len=True),
    lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg)
-   if not any(s.op is Ops.STORE and s.src[0]._shape is not None for s in a.src[1:]) else None),
+   if a.src[0]._shape is not None and not any(s.op is Ops.STORE and s.src[0]._shape is not None for s in a.src[1:]) else None),
   (UPat(GroupOp.Movement, name="r").end(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:])),
 ])
 
 # *****************
 # 0. do some cleanup rewrites, mostly copied from the old stuff
 
-def fix_assign_hazard(assign:UOp, target:UOp, src:UOp):
+def fix_store_after_hazard(after:UOp, target:UOp, src:UOp):
   # PERMUTE and FLIP reorder indices, SHRINK can have overlapping regions when dest is also shrunk
   unsafe = {Ops.PERMUTE, Ops.FLIP} | ({Ops.SHRINK} if target.op_in_backward_slice_with_self(Ops.SHRINK) else set())
   if any(s.op in unsafe and target.base in s.backward_slice for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS or s.op is Ops.AFTER)):
-    return assign.replace(src=(target, src.contiguous()))
+    return after.replace(src=(after.src[0], target.store(src.contiguous())))
 
-def normalize_assign_target_chain(assign:UOp, target:UOp, src:UOp):
+def normalize_store_after_target_chain(after:UOp, target:UOp, src:UOp):
   root_target = target
-  while root_target.op in {Ops.ASSIGN, Ops.AFTER}: root_target = root_target.src[0]
+  while root_target.op is Ops.AFTER: root_target = root_target.src[0]
   # when RHS depends on the previous assign result, break with contiguous
   if target in src.toposort(): src = src.contiguous()
-  return assign.replace(src=(root_target, src))
+  return after.replace(src=(root_target, root_target.store(src)))
 
 def split_reduceop(reduce:UOp, x:UOp):
   if prod(reduce.shape) == 0: return None
@@ -166,21 +166,18 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
   # copy only to different device
   (UPat(Ops.COPY, src=(UPat.var("x"), UPat()), name="copy"), lambda x,copy: x.f(Ops.NOOP) if x.device == copy.device else None),
 
-  # ** assign rules **
+  # ** assign rules (STORE+AFTER) **
 
-  # collapse nested ASSIGN to the same buffer (e.g. __iadd__ in __setitem__)
-  (UPat(Ops.ASSIGN, src=(UPat(name="target"), UPat(Ops.ASSIGN, src=(UPat(name="target"), UPat()), name="src"))), lambda target, src: src),
-
-  # move bitcast from assign target to source: a.bitcast(X).assign(src) -> a.assign(src.bitcast(a.dtype))
-  (UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src"))),
-   lambda target, src: target.assign(src.bitcast(target.dtype))),
-
-  # if assign target is itself an ASSIGN/AFTER chain, canonicalize to the original buffer target
-  (UPat(Ops.ASSIGN, src=(UPat({Ops.ASSIGN, Ops.AFTER}, name="target"), UPat(name="src")), allow_any_len=True, name="assign"),
-   normalize_assign_target_chain),
+  # move bitcast from store+after target to source
+  (UPat(Ops.AFTER, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(Ops.STORE, src=(UPat(Ops.BITCAST), UPat(name="src"))))),
+   lambda target, src: target.after(target.store(src.bitcast(target.dtype)))),
 
   # make source contiguous if it has hazardous movement ops on the dest buffer
-  (UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),
+  (UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat(name="target"), UPat(name="src")))), name="after"), fix_store_after_hazard),
+
+  # normalize target chain: walk through AFTERs to root, insert contiguous if needed
+  (UPat(Ops.AFTER, src=(UPat(Ops.AFTER, name="target"), UPat(Ops.STORE, src=(UPat(), UPat(name="src")))), name="after"),
+   lambda after, target, src: normalize_store_after_target_chain(after, target, src)),
 
   # ** size 0 **
@@ -374,10 +371,18 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
   sdtype = x.dtype.ptr(size=size, addrspace=x.arg.addrspace)
 
   # AFTER: add END to the existing STORE, return buffer with kernel dependency
-  if x.src[0].op is Ops.AFTER:
-    buf = x.src[0].src[0].buf_uop.base
-    stores = [s for s in x.src[0].src[1:] if s.op is Ops.STORE]
-    return buf.after(*[s.end(*rngs) for s in stores]) if stores else buf
+  if (after:=x.src[0]).op is Ops.AFTER:
+    buf = after.src[0].buf_uop.base
+    if not (stores := [s for s in after.src[1:] if s.op is Ops.STORE and s.src[0].op is Ops.INDEX]): return buf
+    # the STORE target may be an INDEX into BUFFERIZE(INDEX(...)); store through the underlying global index instead
+    ended_stores = []
+    store_target = stores[0].src[0]
+    if store_target.src[0].op is Ops.BUFFERIZE and store_target.src[0].src[0].op is Ops.INDEX:
+      store_target = store_target.src[0].src[0]
+    if stores[0].src[1] is not store_target:  # skip self-assign
+      end_rngs = sorted(dedup(tuple(store_target.ranges) + tuple(rngs)), key=lambda x: x.arg)
+      ended_stores.append(store_target.replace(dtype=sdtype).store(stores[0].src[1]).end(*end_rngs))
+    return buf.after(*ended_stores)
   if (assign := x.src[0]).op is Ops.ASSIGN:
     assign_target, assign_src = assign.src[0], assign.src[1]
     assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index"
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 3c3215ee29..4f76db1433 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -308,20 +308,22 @@ class Tensor(OpMixin):
     if is_disk:
       self._buffer().copyin(x._data())
       return self
-    # NOTE: assign_uop is created before AFTER embedding (uses original self.uop),
-    # but AFTER must be embedded before _apply_uop (so subsequent assigns see it)
-    assign_uop = self.uop.assign(x.uop)
+    # STORE+AFTER: STORE is the write effect (void), AFTER wraps the view for correct shape/ranging
+    store_uop = self.uop.store(x.uop)
     base = self.uop.base
-    if base.op in {Ops.BUFFER, Ops.AFTER} and not self.uop.has_buffer_identity():
+    if base.op in {Ops.BUFFER, Ops.AFTER} and self.uop is not base and not self.uop.has_buffer_identity():
+      # view assign: inner AFTER(view, STORE) for correct shape/ranging, outer AFTER(base, inner) for dependency
       original_uop = self.uop
-      assigned_base = base.after(assign_uop)
+      view_after = self.uop.after(store_uop)
+      assigned_base = base.after(view_after)
       _apply_map_to_tensors({base: assigned_base}, name="Embed View Assign", walk=True)
       def replace_view_base(u:UOp) -> UOp:
         return u.replace(src=((assigned_base if u.src[0] is base else replace_view_base(u.src[0])),)+u.src[1:])
       ret = Tensor(replace_view_base(original_uop), device=self.device, requires_grad=self.requires_grad)
-      self.replace(self._apply_uop(lambda *_: assign_uop, x))
+      self.replace(self._apply_uop(lambda *_: replace_view_base(original_uop), x))
       return ret
-    return self.replace(self._apply_uop(lambda *_: assign_uop, x))
+    # simple assign: AFTER wraps self.uop (may be RESHAPE'd buffer) with STORE effect
+    return self.replace(self._apply_uop(lambda *_: self.uop.after(store_uop), x))
 
   def detach(self) -> Tensor:
     """
@@ -1338,7 +1340,8 @@ class Tensor(OpMixin):
             if (t:=tref()) is not None and t is not self and t.uop is not v_uop and t.uop not in v_bw):
         raise RuntimeError("can't setitem on a tensor that already has other uses and requires grad")
       if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
-      if v.uop.op is Ops.ASSIGN: v = v._apply_uop(lambda x: x.src[1])
+      # __iadd__/__isub__ creates AFTER(view, STORE(view, computed)); unwrap to get the computed value
+      if v.uop.op is Ops.AFTER and any(s.op is Ops.STORE for s in v.uop.src[1:]): v = v._apply_uop(lambda x: x.src[1].src[1])
      self.replace(self._getitem(indices, v))
       return
     idx = [indices] if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)) else list(indices)
@@ -1347,14 +1350,14 @@
       if is_disk: raise RuntimeError("advanced setitem is not supported for DISK tensors")
       if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
       self.assign(self._getitem(indices, v))
-    elif is_disk or self.uop.is_realized or self.uop.base.op in (Ops.AFTER, Ops.BUFFER): # basic setitem, self is realized
+    elif is_disk or self.uop.is_realized or self.uop.base.op is Ops.BUFFER or self.uop._base_buffer_is_realized(): # basic setitem
       view = self[indices]
-      if isinstance(v, Tensor) and v.uop.op is Ops.ASSIGN and v.uop in view.uop.base.src: return
+      if isinstance(v, Tensor) and v.uop.op is Ops.AFTER and v.uop in view.uop.base.src: return
       view.assign(v)
     else: # basic setitem, self is not realized
       if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
-      # __iadd__/__isub__ on unrealized views creates a no-op ASSIGN; unwrap to get the computed value
-      if v.uop.op is Ops.ASSIGN: v = v._apply_uop(lambda x: x.src[1])
+      # __iadd__/__isub__ creates AFTER(view, STORE(view, computed)); unwrap to get the computed value
+      if v.uop.op is Ops.AFTER and any(s.op is Ops.STORE for s in v.uop.src[1:]): v = v._apply_uop(lambda x: x.src[1].src[1])
       self.replace(self._getitem(indices, v))
 
   def __delitem__(self, indices) -> None:
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index c7d67d882c..7ed873b892 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -697,6 +697,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     if self.op in {Ops.RESHAPE, Ops.MULTI}: return self.src[0].has_buffer_identity()
     return self.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.PARAM}
 
+  def _base_buffer_is_realized(self) -> bool:
+    """Walk through AFTER chain to find if the underlying buffer is realized (has allocated memory)."""
+    u = self.base
+    while u.op is Ops.AFTER: u = u.src[0]
+    return u.is_realized
+
   @property
   def buffer(self) -> Buffer|MultiBuffer:
     from tinygrad.device import Buffer, MultiBuffer
@@ -1069,7 +1075,6 @@
   def gep(self, i:int|None=None, **kwargs): return UPat(Ops.GEP, None, (self,), (i,) if i is not None else None, **kwargs)
   def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)
   def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, self.match_dtype, (self,)+src, **kwargs)
-  def assign(self, x:UPat, **kwargs): return UPat(Ops.ASSIGN, self.match_dtype, (self,x), **kwargs)
   def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.match_dtype, src=(self,)+src, **kwargs)
   def broadcast(self, **kwargs): return UPat(Ops.VECTORIZE, self.match_dtype, src=self, **kwargs)
   def contiguous(self, *args, **kwargs): return UPat(Ops.CONTIGUOUS, dtype=self.match_dtype, src=(self,)+args, **kwargs)
@@ -1516,7 +1521,7 @@ def render_marg(ctx,x:UOp):
   return f"({','.join(pieces)})" if len(pieces) != 1 else f"({pieces[0]},)"
 
 sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER, Ops.THREEFRY,
-         Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER, Ops.ASSIGN, Ops.DETACH}
+         Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER, Ops.DETACH}
 pm_pyrender_extra = PatternMatcher([
   (UPat(Ops.CONST, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"),
    lambda x,u,d: f"UOp.unique_const({x.dtype}, {x.arg}, device={repr(d.arg)}, unique={u.arg})"),
@@ -1584,7 +1589,7 @@ def pyrender(ast:UOp) -> str:
   cmap = consumer_map_from_toposort(lst)
   not_rendered = {Ops.CONST, Ops.VCONST, Ops.DEVICE}
   always_rendered = {Ops.PARAM, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.VECTORIZE,
-                     Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.WHERE, Ops.END, Ops.ASSIGN}
+                     Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.WHERE, Ops.END}
   to_render: set[UOp] = {ast}
   for u in lst:
diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py
index c6a899359a..b5f3e5c5e8 100644
--- a/tinygrad/uop/spec.py
+++ b/tinygrad/uop/spec.py
@@ -77,8 +77,9 @@ movement_ops = PatternMatcher([
   (UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.index), lambda: True),
   (UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True),
 
-  # AFTER on Movement Op or ASSIGN
-  (UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS, Ops.ASSIGN, Ops.BUFFER})),), allow_any_len=True), lambda: True),
+  # AFTER on Movement Op, BUFFER, COPY, or BITCAST
+  (UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS, Ops.BUFFER, Ops.BITCAST, Ops.COPY})),), allow_any_len=True),
+   lambda: True),
 ])
 
 _tensor_spec = PatternMatcher([
@@ -96,8 +97,8 @@
   # KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
   (UPat(Ops.CALL, src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
 
-  # ASSIGN has a target and a value. It can also optionally depend on other assigns
-  (UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
+  # ASSIGN is used internally by allreduce for precompiled function bodies
+  (UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2),
 
   # MSELECT chooses one of the multi buffers
   (UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)),
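
Reviewer note, not part of the patch: the unifying change above is that tensor-level Ops.ASSIGN(target, value) is replaced throughout by an Ops.STORE write effect wrapped in Ops.AFTER, which both orders downstream readers after the write and keeps presenting the target's view for shape/range inference. Below is a minimal toy sketch of that rewrite shape. The Op and UOp types here are stand-ins for illustration only, not tinygrad's real classes or its PatternMatcher machinery.

# Toy sketch only: stand-in Op/UOp types, not tinygrad's real API.
from dataclasses import dataclass
from enum import Enum, auto

class Op(Enum):
  BUFFER = auto(); STORE = auto(); AFTER = auto(); ASSIGN = auto()

@dataclass(frozen=True)
class UOp:
  op: Op
  src: tuple["UOp", ...] = ()

def rewrite_assign(u: UOp) -> UOp:
  """ASSIGN(target, value) -> AFTER(target, STORE(target, value)).

  STORE carries the write effect; AFTER makes readers of the result
  depend on that write while still presenting the target's view.
  """
  u = UOp(u.op, tuple(rewrite_assign(s) for s in u.src))  # rewrite sources bottom-up
  if u.op is Op.ASSIGN:
    target, value = u.src
    return UOp(Op.AFTER, (target, UOp(Op.STORE, (target, value))))
  return u

buf, val = UOp(Op.BUFFER), UOp(Op.BUFFER)
out = rewrite_assign(UOp(Op.ASSIGN, (buf, val)))
assert out.op is Op.AFTER and out.src[1].op is Op.STORE

The same decomposition explains the tensor.py change: a view assign becomes an inner AFTER(view, STORE) for shape/ranging plus an outer AFTER(base, inner) for the dependency, and __setitem__ unwraps AFTER(view, STORE(view, computed)) back to the computed value.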