move assign chain fix to rangeify (#14829)

This commit is contained in:
chenyu
2026-02-17 09:40:34 -05:00
committed by GitHub
parent a2586e4c70
commit f07898c68a
2 changed files with 11 additions and 8 deletions

View File

@@ -55,6 +55,13 @@ def fix_assign_hazard(assign:UOp, target:UOp, src:UOp):
if any(s is target.base for s in h.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS-{Ops.PARAM})):
return assign.replace(src=(target, src.contiguous()))
def normalize_assign_target_chain(assign:UOp, target:UOp, src:UOp):
root_target = target
while root_target.op is Ops.ASSIGN: root_target = root_target.src[0]
# when RHS depends on the previous assign result, break with contiguous
if target in src.toposort(): src = src.contiguous()
return assign.replace(src=(root_target, src))
def split_reduceop(reduce:UOp, x:UOp):
if prod(reduce.shape) == 0: return None
if not SPLIT_REDUCEOP or not all_int(x.shape) or (prod(x.shape)//prod(reduce.shape))<getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return None
@@ -149,6 +156,9 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
(UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src")), name="assign"),
lambda assign, target, src: target.assign(src.bitcast(target.dtype)).replace(tag=assign.tag)),
# if assign target is itself an ASSIGN chain, canonicalize to the original buffer target
(UPat(Ops.ASSIGN, src=(UPat(Ops.ASSIGN, name="target"), UPat(name="src")), allow_any_len=True, name="assign"), normalize_assign_target_chain),
# assign only to buffer, otherwise make it a CONTIGUOUS
(UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.PARAM}, name="target"), UPat(name="src")), name="assign"), assign_to_contiguous),

View File

@@ -315,14 +315,7 @@ class Tensor(OpMixin):
if is_disk:
self._buffer().copyin(x._data())
return self
# chained full-buffer assign should keep writing into the original target buffer
# TODO: move this to rangeify, currently pm_remove_bufferize drops some tags
if self.uop.op is Ops.ASSIGN and (target:=self.uop.src[0]).has_buffer_identity():
if self.uop in x.uop.toposort():
# break assign-in-source cycle lazily through a temporary
result = self._apply_uop(lambda _self, val: target.assign(val.contiguous()), x)
else: result = self._apply_uop(lambda _self, val: target.assign(val), x)
else: result = self._apply_uop(UOp.assign, x)
result = self._apply_uop(UOp.assign, x)
# track view assigns (not full-buffer or assign-chain) so they can be side-realized when the buffer is read
if (buf_uop:=self.uop.base).op is Ops.BUFFER and self.uop.op is not Ops.ASSIGN and not self.uop.has_buffer_identity():
# deduplicate: if the value is already a pending assign for this buffer (e.g. __iadd__ in __setitem__), remove it