mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Tensor.assign is store+after [pr] (#15288)
* Tensor.assign is store+after [pr]
* put that back
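
The core change: Tensor.assign no longer builds a first-class ASSIGN node in the tensor graph. The write becomes a STORE (the effect itself) wrapped in an AFTER (the ordering node that readers of the buffer go through). A minimal sketch of the two encodings, using a hypothetical MiniUOp stand-in rather than tinygrad's real UOp class:

from dataclasses import dataclass

@dataclass(frozen=True)
class MiniUOp:  # hypothetical stand-in, not tinygrad's UOp
  op: str
  src: tuple["MiniUOp", ...] = ()
  def store(self, val): return MiniUOp("STORE", (self, val))
  def after(self, *deps): return MiniUOp("AFTER", (self,) + deps)
  def assign(self, val): return MiniUOp("ASSIGN", (self, val))

def assign_to_store_after(u: MiniUOp) -> MiniUOp:
  # the shape of the rewrite this PR applies: the STORE carries the write
  # effect, the AFTER orders readers of the buffer behind that write
  if u.op != "ASSIGN": return u
  target, val = u.src
  return target.after(target.store(val))

buf, val = MiniUOp("BUFFER"), MiniUOp("ADD")
old = buf.assign(val)             # old encoding: ASSIGN(BUFFER, ADD)
new = assign_to_store_after(old)  # new encoding: AFTER(BUFFER, STORE(BUFFER, ADD))
assert new.op == "AFTER" and new.src[1].op == "STORE"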

@@ -29,12 +29,13 @@ def apply_after(ctx:AllocCtx, u:UOp):
   while base.op is Ops.AFTER: base = base.src[0]
   ctx.buffer_map[u] = base

-# CONTIGUOUS and ASSIGN + parents are the only nodes that get updated
+# CONTIGUOUS and AFTER+STORE + parents are the only nodes that get updated
 add_tags = PatternMatcher([
   (UPat(Ops.COPY, name="u"), disk_copy_is_buffer),
-  # no tag on copies that are assigned
-  (UPat(Ops.ASSIGN, src=(UPat(), UPat(Ops.COPY, name="c")), name="a"),
-   lambda a,c: a.replace(src=(a.src[0], c.rtag(())), tag=a.tag+c.tag) if a.tag and c.tag else None),
+  # no tag on copies that are assigned via STORE+AFTER — merge COPY tag into AFTER
+  (UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat(name="dest"), UPat(Ops.COPY, name="c")))), name="a"),
+   lambda a,c,dest: a.replace(src=(a.src[0], a.src[1].replace(src=(dest, c.rtag(())))), tag=a.tag+c.tag) if a.tag and c.tag else None),
+  (UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE)), name="x"), tag_uop),
   (UPat(Ops.AFTER, name="u"), apply_after),
   (UPat({Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"), tag_uop),
   (UPat(GroupOp.All, name="x"), lambda ctx,x: tag_uop(ctx,x) if x in ctx.bases else None),

@@ -60,11 +61,10 @@ def replace_contig_with_store_after(u:UOp):
   buf = _buffer_like(u)
   return buf.after(buf.store(u.src[0])).rtag(u.tag)

-def replace_assign_with_contig(u:UOp):
+def replace_store_after_with_contig(u:UOp, src:UOp):
   assigned_to = u
   while assigned_to.op in {Ops.ASSIGN, Ops.BITCAST, Ops.AFTER}: assigned_to = assigned_to.src[0].base
-  if assigned_to.op is not Ops.BUFFER:
-    return u.src[1].contiguous(tag=u.tag)
+  if assigned_to.op is not Ops.BUFFER: return src.contiguous(tag=u.tag)

 def contiguous_mops_to_view(c:UOp):
   """CONTIGUOUS(MOPS(BUFFER)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to a contiguous range."""
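
The chain walk above decides whether a STORE+AFTER can stay a real buffer write: peel wrappers off the target and check for a BUFFER at the root, otherwise degrade to a CONTIGUOUS of the source. A sketch of that walk with a hypothetical MiniUOp (the real code also strips movement ops via .base at each step):

from dataclasses import dataclass

@dataclass(frozen=True)
class MiniUOp:  # hypothetical stand-in, not tinygrad's UOp
  op: str
  src: tuple["MiniUOp", ...] = ()

def root_target(u: MiniUOp) -> MiniUOp:
  # peel ASSIGN/BITCAST/AFTER wrappers off the write target
  while u.op in {"ASSIGN", "BITCAST", "AFTER"}: u = u.src[0]
  return u

buf = MiniUOp("BUFFER")
wrapped = MiniUOp("AFTER", (MiniUOp("BITCAST", (buf,)),))
assert root_target(wrapped).op == "BUFFER"           # keeps the STORE+AFTER form
assert root_target(MiniUOp("CONST")).op != "BUFFER"  # falls back to CONTIGUOUS(src)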

@@ -129,8 +129,8 @@ pm_early_transform_tensor_graph = PatternMatcher([
   # remove extra CONTIGUOUS on ASSIGN/AFTER (only when target is contiguous)
   (UPat(Ops.CONTIGUOUS, src=(UPat({Ops.ASSIGN, Ops.AFTER}, name="a"),), name="c"),
    lambda a,c: a.replace(tag=(a.tag or ())+(c.tag or ())) if a.src[0].has_buffer_identity() else None),
-  # replace ASSIGN with CONTIGUOUS
-  (UPat(Ops.ASSIGN, name="u"), replace_assign_with_contig),
+  # replace AFTER+STORE with CONTIGUOUS when target is not a buffer
+  (UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat(), UPat(name="src")))), name="u"), replace_store_after_with_contig),
   # replace CONTIGUOUS with STORE+AFTER
   (UPat(Ops.CONTIGUOUS, name="u"), replace_contig_with_store_after),
   # remove DETACH/CONTIGUOUS_BACKWARD (allows more contiguous removal)

@@ -11,7 +11,7 @@ def add_to_ctx(ctx, x:UOp):

 pm_ctx = PatternMatcher([
   (UPat((Ops.BUFFER, Ops.BIND), name="x"), add_to_ctx),
-  (UPat((Ops.ASSIGN, Ops.CONTIGUOUS), name="x"),
+  (UPat((Ops.AFTER, Ops.CONTIGUOUS), name="x"),
    lambda ctx,x: add_to_ctx(ctx,x) if not x.op_in_backward_slice_with_self(Ops.PARAM) else None),
   # strip UNIQUE from unique consts — they don't need buffer identity inside function bodies
   (UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="x"), lambda ctx,x: x.replace(src=(x.src[1],))),

@@ -68,8 +68,8 @@ def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
   # compute the target path (top down)
   in_target_path: dict[UOp, bool] = {}
   for u in root.toposort(): in_target_path[u] = any(x in targets or in_target_path[x] for x in u.src)
-  # don't flow through DETACH/ASSIGN or anything not in target path
-  return list(root.toposort(lambda node: node.op not in {Ops.DETACH, Ops.ASSIGN} and in_target_path[node]))
+  # don't flow through DETACH or anything not in target path
+  return list(root.toposort(lambda node: node.op is not Ops.DETACH and in_target_path[node]))

 def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
   grads = {root: root_grad}
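
With assign encoded as STORE+AFTER, the backward pass no longer needs to treat ASSIGN as a gradient barrier; only DETACH cuts the flow. A sketch of a gated topological walk like the one _deepwalk relies on, with hypothetical Node/toposort helpers (tinygrad's gate semantics may differ in detail):

from dataclasses import dataclass

@dataclass(frozen=True)
class Node:  # hypothetical stand-in, not tinygrad's UOp
  op: str
  src: tuple["Node", ...] = ()

def toposort(root: Node, gate=lambda n: True) -> list[Node]:
  # post-order walk; a False gate prunes that node and its whole subtree
  out: list[Node] = []
  seen: set[int] = set()
  def visit(n: Node):
    if id(n) in seen or not gate(n): return
    seen.add(id(n))
    for s in n.src: visit(s)
    out.append(n)
  visit(root)
  return out

leaf = Node("BUFFER")
detached = Node("DETACH", (leaf,))
root = Node("ADD", (detached, Node("MUL", (leaf,))))
walked = toposort(root, gate=lambda n: n.op != "DETACH")
assert all(n.op != "DETACH" for n in walked)  # the DETACH edge is cut, the MUL path survives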

@@ -17,17 +17,13 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
   for s in rb.src:
     if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None

-def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp):
-  # don't realize COPY/BUFFER_VIEW when they are the direct source of ASSIGN — the ASSIGN target buffer is the output
-  if x.op in {Ops.COPY, Ops.BUFFER_VIEW} and x in ctx \
-    and not buf.op_in_backward_slice_with_self(Ops.SHRINK, Ops.PERMUTE, Ops.FLIP, Ops.PAD):
-    del ctx[x]
+def realize_store_after_src(ctx:dict[UOp, None], dest:UOp, src:UOp):
+  # don't realize COPY/BUFFER_VIEW when they are the direct source of STORE+AFTER — the target buffer is the output
+  if src.op in {Ops.COPY, Ops.BUFFER_VIEW} and src in ctx \
+    and not dest.op_in_backward_slice_with_self(Ops.SHRINK, Ops.PERMUTE, Ops.FLIP, Ops.PAD):
+    del ctx[src]
   # you don't usually have to do this for assign unless there's a WAR hazard like TestAssign.test_assign_double_diamond_reduce
-  if buf.base in x.backward_slice_with_self: ctx[x] = None
-
-def unrealize_store_src(ctx:dict[UOp, None], x:UOp):
-  """Don't realize COPY/BUFFER_VIEW consumed by STORE inside AFTER — bufferize_to_store handles them."""
-  if x in ctx: del ctx[x]
+  if dest.base in src.backward_slice_with_self: ctx[src] = None

 pm_generate_realize_map = PatternMatcher([
   # always realize
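
The surviving hazard rule here is write-after-read: if the STORE's destination buffer appears anywhere in the source expression, the source must be realized first so the read completes before the write. A sketch of that membership test with hypothetical MiniUOp/backward_slice helpers:

from dataclasses import dataclass

@dataclass(frozen=True)
class MiniUOp:  # hypothetical stand-in, not tinygrad's UOp
  op: str
  src: tuple["MiniUOp", ...] = ()

def backward_slice_with_self(u: MiniUOp) -> set[int]:
  # every node reachable from u, including u itself (identity by id)
  ids = {id(u)}
  for s in u.src: ids |= backward_slice_with_self(s)
  return ids

dest = MiniUOp("BUFFER")
src = MiniUOp("ADD", (MiniUOp("REDUCE", (dest,)), MiniUOp("CONST")))
realize_map: dict[int, None] = {}
# WAR hazard: src reads dest, so force src to realize before the store to dest
if id(dest) in backward_slice_with_self(src): realize_map[id(src)] = None
assert id(src) in realize_map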

@@ -36,10 +32,8 @@ pm_generate_realize_map = PatternMatcher([
   (UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE)), allow_any_len=True, name="tr"), realize),
   # realize srcs of these
   (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
-  # sometimes realize src of assign
-  (UPat(Ops.ASSIGN, src=(UPat.var("buf"), UPat.var("x"))), realize_assign_src),
-  # don't realize COPY/BUFFER_VIEW consumed by STORE inside AFTER (like realize_assign_src for ASSIGN)
-  (UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat(), UPat({Ops.COPY, Ops.BUFFER_VIEW}, name="x"))))), unrealize_store_src),
+  # sometimes realize/unrealize src of store+after
+  (UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat.var("dest"), UPat.var("src"))))), realize_store_after_src),
 ])

 @dataclass(frozen=True)

@@ -111,6 +111,8 @@ def assign_multi(dest:UOp, src:UOp):
   if dest.axis != src.axis: raise RuntimeError(f"axis must match in assign {dest.axis} != {src.axis}")
   return dest.src[0].assign(src.src[0]).multi(src.axis)

+def store_after_multi(dest:UOp, src:UOp): return dest.after(dest.store(src.src[0])).multi(src.axis)
+
 def passthrough_multi(root:UOp, multi:UOp):
   return UOp(root.op, root.dtype, (multi.src[0],)+tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src[1:]), root.arg).multi(multi.axis)

@@ -136,6 +138,7 @@ multi_pm = PatternMatcher([
   (UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi),
   (UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
   (UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
+  (UPat(Ops.AFTER, src=(UPat(Ops.MULTI), UPat(Ops.STORE, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))))), store_after_multi),
   (UPat(Ops.COPY, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device"))), copy_multi),
   (UPat(Ops.ALLREDUCE, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device")), name="red"),
    lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)),
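
The new multi rule mirrors assign_multi: unwrap the MULTI wrapper on the source, rebuild the STORE+AFTER, and rewrap on the shard axis. A sketch with a hypothetical MiniUOp that mirrors store_after_multi's shape:

from dataclasses import dataclass

@dataclass(frozen=True)
class MiniUOp:  # hypothetical stand-in, not tinygrad's UOp
  op: str
  src: tuple["MiniUOp", ...] = ()
  axis: object = None
  def store(self, v): return MiniUOp("STORE", (self, v))
  def after(self, *d): return MiniUOp("AFTER", (self,) + d)
  def multi(self, axis): return MiniUOp("MULTI", (self,), axis)

dest = MiniUOp("BUFFER").multi(axis=0)  # sharded destination
src = MiniUOp("ADD").multi(axis=0)      # sharded source
# store_after_multi(dest, src): unwrap src, rebuild STORE+AFTER, rewrap on the axis
out = dest.after(dest.store(src.src[0])).multi(src.axis)
assert out.op == "MULTI" and out.axis == 0 and out.src[0].op == "AFTER"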

@@ -45,41 +45,41 @@ def found_assign(ctx:dict[UOp, UOp], assign:UOp, src:UOp):
     else: break
   ctx[x] = assign

-# *** fold moved ASSIGNs/AFTERs (hack for openpilot) ***
+# *** fold moved AFTERs (hack for openpilot) ***
 pm_fold_moved_assign = PatternMatcher([
-  (UPat(Ops.ASSIGN, src=(UPat(), UPat((*GroupOp.Movement, Ops.CAST), name="src")), name="assign"), found_assign),
+  (UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat(), UPat((*GroupOp.Movement, Ops.CAST), name="src")))), name="assign"), found_assign),
   # replace ALU sources with assign versions found above
   (UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
 ])

 # movement op on INDEX as a PatternMatcher
 # TODO: clean up .src[0]._shape is not None
 pm_mops = PatternMatcher([
   (UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
    lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)
-    if len(idx.src[1:]) == len(r.shape) else None),
+    if r.src[0]._shape is not None and len(idx.src[1:]) == len(r.shape) else None),
   # move movement ops after AFTER (but not when AFTER has a raw STORE with shaped children — from replace_contig_with_store_after)
   (UPat(GroupOp.Movement, name="r").after(name="a", allow_any_len=True),
    lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg)
-    if not any(s.op is Ops.STORE and s.src[0]._shape is not None for s in a.src[1:]) else None),
+    if a.src[0]._shape is not None and not any(s.op is Ops.STORE and s.src[0]._shape is not None for s in a.src[1:]) else None),
   (UPat(GroupOp.Movement, name="r").end(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:])),
 ])

 # *****************
 # 0. do some cleanup rewrites, mostly copied from the old stuff

-def fix_assign_hazard(assign:UOp, target:UOp, src:UOp):
+def fix_store_after_hazard(after:UOp, target:UOp, src:UOp):
   # PERMUTE and FLIP reorder indices, SHRINK can have overlapping regions when dest is also shrunk
   unsafe = {Ops.PERMUTE, Ops.FLIP} | ({Ops.SHRINK} if target.op_in_backward_slice_with_self(Ops.SHRINK) else set())
   if any(s.op in unsafe and target.base in s.backward_slice for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS or s.op is Ops.AFTER)):
-    return assign.replace(src=(target, src.contiguous()))
+    return after.replace(src=(after.src[0], target.store(src.contiguous())))

-def normalize_assign_target_chain(assign:UOp, target:UOp, src:UOp):
+def normalize_store_after_target_chain(after:UOp, target:UOp, src:UOp):
   root_target = target
-  while root_target.op in {Ops.ASSIGN, Ops.AFTER}: root_target = root_target.src[0]
+  while root_target.op is Ops.AFTER: root_target = root_target.src[0]
   # when RHS depends on the previous assign result, break with contiguous
   if target in src.toposort(): src = src.contiguous()
-  return assign.replace(src=(root_target, src))
+  return after.replace(src=(root_target, root_target.store(src)))

 def split_reduceop(reduce:UOp, x:UOp):
   if prod(reduce.shape) == 0: return None
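
The AFTER-commuting rule in pm_mops keeps the scheduling dependency while letting movement ops reach the raw buffer underneath; the new guards additionally require the AFTER's first source to have a shape and skip AFTERs carrying a raw shaped STORE. A sketch of the basic commute for RESHAPE, with a hypothetical MiniUOp:

from dataclasses import dataclass

@dataclass(frozen=True)
class MiniUOp:  # hypothetical stand-in, not tinygrad's UOp
  op: str
  src: tuple["MiniUOp", ...] = ()
  arg: object = None

def mop_past_after(a: MiniUOp) -> MiniUOp:
  # AFTER(RESHAPE(buf), dep) -> RESHAPE(AFTER(buf, dep))
  if a.op != "AFTER" or a.src[0].op != "RESHAPE": return a
  r = a.src[0]
  return MiniUOp(r.op, (MiniUOp("AFTER", (r.src[0],) + a.src[1:]),), r.arg)

buf, dep = MiniUOp("BUFFER"), MiniUOp("STORE")
g = MiniUOp("AFTER", (MiniUOp("RESHAPE", (buf,), (2, 2)), dep))
out = mop_past_after(g)
assert out.op == "RESHAPE" and out.src[0].op == "AFTER" and out.src[0].src == (buf, dep)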

@@ -166,21 +166,18 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
   # copy only to different device
   (UPat(Ops.COPY, src=(UPat.var("x"), UPat()), name="copy"), lambda x,copy: x.f(Ops.NOOP) if x.device == copy.device else None),

-  # ** assign rules **
+  # ** assign rules (STORE+AFTER) **

-  # collapse nested ASSIGN to the same buffer (e.g. __iadd__ in __setitem__)
-  (UPat(Ops.ASSIGN, src=(UPat(name="target"), UPat(Ops.ASSIGN, src=(UPat(name="target"), UPat()), name="src"))), lambda target, src: src),
-
-  # move bitcast from assign target to source: a.bitcast(X).assign(src) -> a.assign(src.bitcast(a.dtype))
-  (UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src"))),
-   lambda target, src: target.assign(src.bitcast(target.dtype))),
-
-  # if assign target is itself an ASSIGN/AFTER chain, canonicalize to the original buffer target
-  (UPat(Ops.ASSIGN, src=(UPat({Ops.ASSIGN, Ops.AFTER}, name="target"), UPat(name="src")), allow_any_len=True, name="assign"),
-   normalize_assign_target_chain),
+  # move bitcast from store+after target to source
+  (UPat(Ops.AFTER, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(Ops.STORE, src=(UPat(Ops.BITCAST), UPat(name="src"))))),
+   lambda target, src: target.after(target.store(src.bitcast(target.dtype)))),

   # make source contiguous if it has hazardous movement ops on the dest buffer
-  (UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),
+  (UPat(Ops.AFTER, src=(UPat(), UPat(Ops.STORE, src=(UPat(name="target"), UPat(name="src")))), name="after"), fix_store_after_hazard),
+
+  # normalize target chain: walk through AFTERs to root, insert contiguous if needed
+  (UPat(Ops.AFTER, src=(UPat(Ops.AFTER, name="target"), UPat(Ops.STORE, src=(UPat(), UPat(name="src")))), name="after"),
+   lambda after, target, src: normalize_store_after_target_chain(after, target, src)),

   # ** size 0 **
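
At the Tensor level, the hazard fix_store_after_hazard guards against is classic in-place aliasing: the source reads the destination through an index-reordering op, so it gets a contiguous copy first. A usage sketch against the public tinygrad API (assumes a tinygrad install; exact behavior may vary by version):

from tinygrad import Tensor

# in-place update whose source reads the destination through a PERMUTE:
# without the contiguous() break inserted by fix_store_after_hazard, the
# kernel could read elements of t that it has already overwritten
t = Tensor.arange(4).reshape(2, 2).contiguous().realize()
t.assign(t.permute(1, 0) + 1)
t.realize()
print(t.numpy())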

@@ -374,10 +371,18 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):

   sdtype = x.dtype.ptr(size=size, addrspace=x.arg.addrspace)
   # AFTER: add END to the existing STORE, return buffer with kernel dependency
-  if x.src[0].op is Ops.AFTER:
-    buf = x.src[0].src[0].buf_uop.base
-    stores = [s for s in x.src[0].src[1:] if s.op is Ops.STORE]
-    return buf.after(*[s.end(*rngs) for s in stores]) if stores else buf
+  if (after:=x.src[0]).op is Ops.AFTER:
+    buf = after.src[0].buf_uop.base
+    if not (stores := [s for s in after.src[1:] if s.op is Ops.STORE and s.src[0].op is Ops.INDEX]): return buf
+    # BUFFERIZE(INDEX(...)); store through the underlying global index instead.
+    ended_stores = []
+    store_target = stores[0].src[0]
+    if store_target.src[0].op is Ops.BUFFERIZE and store_target.src[0].src[0].op is Ops.INDEX:
+      store_target = store_target.src[0].src[0]
+    if stores[0].src[1] is not store_target:  # skip self-assign
+      end_rngs = sorted(dedup(tuple(store_target.ranges) + tuple(rngs)), key=lambda x: x.arg)
+      ended_stores.append(store_target.replace(dtype=sdtype).store(stores[0].src[1]).end(*end_rngs))
+    return buf.after(*ended_stores)
   if (assign := x.src[0]).op is Ops.ASSIGN:
     assign_target, assign_src = assign.src[0], assign.src[1]
     assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index"
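
The END ranges for the rewritten STORE are the union of the ranges already attached to the store target and the BUFFERIZE's own ranges, deduped and ordered by range id. A sketch with hypothetical Rng/dedup helpers (dedup mirrors tinygrad's order-preserving helper of the same name):

from dataclasses import dataclass

@dataclass(frozen=True)
class Rng:  # hypothetical stand-in for a RANGE UOp
  arg: int  # range id

def dedup(xs):
  # order-preserving dedup
  seen, out = set(), []
  for x in xs:
    if x not in seen: seen.add(x); out.append(x)
  return out

store_target_ranges = (Rng(2), Rng(0))
bufferize_ranges = (Rng(1), Rng(2))
end_rngs = sorted(dedup(store_target_ranges + bufferize_ranges), key=lambda r: r.arg)
assert [r.arg for r in end_rngs] == [0, 1, 2]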

@@ -308,20 +308,22 @@ class Tensor(OpMixin):
     if is_disk:
       self._buffer().copyin(x._data())
       return self
-    # NOTE: assign_uop is created before AFTER embedding (uses original self.uop),
-    # but AFTER must be embedded before _apply_uop (so subsequent assigns see it)
-    assign_uop = self.uop.assign(x.uop)
+    # STORE+AFTER: STORE is the write effect (void), AFTER wraps the view for correct shape/ranging
+    store_uop = self.uop.store(x.uop)
     base = self.uop.base
-    if base.op in {Ops.BUFFER, Ops.AFTER} and not self.uop.has_buffer_identity():
+    if base.op in {Ops.BUFFER, Ops.AFTER} and self.uop is not base and not self.uop.has_buffer_identity():
+      # view assign: inner AFTER(view, STORE) for correct shape/ranging, outer AFTER(base, inner) for dependency
       original_uop = self.uop
-      assigned_base = base.after(assign_uop)
+      view_after = self.uop.after(store_uop)
+      assigned_base = base.after(view_after)
       _apply_map_to_tensors({base: assigned_base}, name="Embed View Assign", walk=True)
       def replace_view_base(u:UOp) -> UOp:
         return u.replace(src=((assigned_base if u.src[0] is base else replace_view_base(u.src[0])),)+u.src[1:])
       ret = Tensor(replace_view_base(original_uop), device=self.device, requires_grad=self.requires_grad)
-      self.replace(self._apply_uop(lambda *_: assign_uop, x))
+      self.replace(self._apply_uop(lambda *_: replace_view_base(original_uop), x))
       return ret
-    return self.replace(self._apply_uop(lambda *_: assign_uop, x))
+    # simple assign: AFTER wraps self.uop (may be RESHAPE'd buffer) with STORE effect
+    return self.replace(self._apply_uop(lambda *_: self.uop.after(store_uop), x))

   def detach(self) -> Tensor:
     """
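
Both branches are visible from the public API. A usage sketch (assumes a tinygrad install; the printed value follows from the documented semantics, not from running this exact build):

from tinygrad import Tensor

# simple assign: the whole buffer is the target, so self.uop becomes
# AFTER(self.uop, STORE(self.uop, x.uop))
a = Tensor.zeros(4).contiguous().realize()
a.assign(Tensor.ones(4))
a.realize()

# view assign: writing through a shrunk view takes the other branch, which
# embeds AFTER(base, AFTER(view, STORE(view, value))) so later readers of
# the base see the write
b = Tensor.zeros(4).contiguous().realize()
b[1:3] = 5.0
print(b.numpy())  # [0. 5. 5. 0.]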

@@ -1338,7 +1340,8 @@ class Tensor(OpMixin):
           if (t:=tref()) is not None and t is not self and t.uop is not v_uop and t.uop not in v_bw):
         raise RuntimeError("can't setitem on a tensor that already has other uses and requires grad")
       if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
-      if v.uop.op is Ops.ASSIGN: v = v._apply_uop(lambda x: x.src[1])
+      # __iadd__/__isub__ creates AFTER(view, STORE(view, computed)); unwrap to get the computed value
+      if v.uop.op is Ops.AFTER and any(s.op is Ops.STORE for s in v.uop.src[1:]): v = v._apply_uop(lambda x: x.src[1].src[1])
       self.replace(self._getitem(indices, v))
       return
     idx = [indices] if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)) else list(indices)

@@ -1347,14 +1350,14 @@ class Tensor(OpMixin):
       if is_disk: raise RuntimeError("advanced setitem is not supported for DISK tensors")
       if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
       self.assign(self._getitem(indices, v))
-    elif is_disk or self.uop.is_realized or self.uop.base.op in (Ops.AFTER, Ops.BUFFER): # basic setitem, self is realized
+    elif is_disk or self.uop.is_realized or self.uop.base.op is Ops.BUFFER or self.uop._base_buffer_is_realized(): # basic setitem
       view = self[indices]
-      if isinstance(v, Tensor) and v.uop.op is Ops.ASSIGN and v.uop in view.uop.base.src: return
+      if isinstance(v, Tensor) and v.uop.op is Ops.AFTER and v.uop in view.uop.base.src: return
       view.assign(v)
     else: # basic setitem, self is not realized
       if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
-      # __iadd__/__isub__ on unrealized views creates a no-op ASSIGN; unwrap to get the computed value
-      if v.uop.op is Ops.ASSIGN: v = v._apply_uop(lambda x: x.src[1])
+      # __iadd__/__isub__ creates AFTER(view, STORE(view, computed)); unwrap to get the computed value
+      if v.uop.op is Ops.AFTER and any(s.op is Ops.STORE for s in v.uop.src[1:]): v = v._apply_uop(lambda x: x.src[1].src[1])
       self.replace(self._getitem(indices, v))

   def __delitem__(self, indices) -> None:
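
The unwrap above handles the in-place-operator path through __setitem__. A usage sketch against the public tinygrad API (assumes a tinygrad install):

from tinygrad import Tensor

# __iadd__ on a view produces AFTER(view, STORE(view, computed)); __setitem__
# unwraps that to the computed value before re-running _getitem
t = Tensor.zeros(4).contiguous().realize()
t[1:3] += 1.0  # v.uop arrives as AFTER(..., STORE(...)) and is unwrapped
print(t.numpy())  # [0. 1. 1. 0.]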

@@ -697,6 +697,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     if self.op in {Ops.RESHAPE, Ops.MULTI}: return self.src[0].has_buffer_identity()
     return self.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.PARAM}

+  def _base_buffer_is_realized(self) -> bool:
+    """Walk through AFTER chain to find if the underlying buffer is realized (has allocated memory)."""
+    u = self.base
+    while u.op is Ops.AFTER: u = u.src[0]
+    return u.is_realized
+
   @property
   def buffer(self) -> Buffer|MultiBuffer:
     from tinygrad.device import Buffer, MultiBuffer

@@ -1069,7 +1075,6 @@ class UPat(OpMixin):
   def gep(self, i:int|None=None, **kwargs): return UPat(Ops.GEP, None, (self,), (i,) if i is not None else None, **kwargs)
   def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)
   def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, self.match_dtype, (self,)+src, **kwargs)
-  def assign(self, x:UPat, **kwargs): return UPat(Ops.ASSIGN, self.match_dtype, (self,x), **kwargs)
   def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.match_dtype, src=(self,)+src, **kwargs)
   def broadcast(self, **kwargs): return UPat(Ops.VECTORIZE, self.match_dtype, src=self, **kwargs)
   def contiguous(self, *args, **kwargs): return UPat(Ops.CONTIGUOUS, dtype=self.match_dtype, src=(self,)+args, **kwargs)

@@ -1516,7 +1521,7 @@ def render_marg(ctx,x:UOp):
   return f"({','.join(pieces)})" if len(pieces) != 1 else f"({pieces[0]},)"

 sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER, Ops.THREEFRY,
-         Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER, Ops.ASSIGN, Ops.DETACH}
+         Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER, Ops.DETACH}
 pm_pyrender_extra = PatternMatcher([
   (UPat(Ops.CONST, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"),
    lambda x,u,d: f"UOp.unique_const({x.dtype}, {x.arg}, device={repr(d.arg)}, unique={u.arg})"),

@@ -1584,7 +1589,7 @@ def pyrender(ast:UOp) -> str:
   cmap = consumer_map_from_toposort(lst)
   not_rendered = {Ops.CONST, Ops.VCONST, Ops.DEVICE}
   always_rendered = {Ops.PARAM, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.VECTORIZE,
-                     Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.WHERE, Ops.END, Ops.ASSIGN}
+                     Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.WHERE, Ops.END}

   to_render: set[UOp] = {ast}
   for u in lst:

@@ -77,8 +77,9 @@ movement_ops = PatternMatcher([
   (UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.index), lambda: True),
   (UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True),

-  # AFTER on Movement Op or ASSIGN
-  (UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS, Ops.ASSIGN, Ops.BUFFER})),), allow_any_len=True), lambda: True),
+  # AFTER on Movement Op, BUFFER, COPY, or BITCAST
+  (UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS, Ops.BUFFER, Ops.BITCAST, Ops.COPY})),), allow_any_len=True),
+   lambda: True),
 ])

 _tensor_spec = PatternMatcher([

@@ -96,8 +97,8 @@ _tensor_spec = PatternMatcher([
   # KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
   (UPat(Ops.CALL, src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),

-  # ASSIGN has a target and a value. It can also optionally depend on other assigns
-  (UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
+  # ASSIGN is used internally by allreduce for precompiled function bodies
+  (UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2),

   # MSELECT chooses one of the multi buffers
   (UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)),