remove a contiguous in basic setitem (#14640)

handled in rangeify
This commit is contained in:
chenyu
2026-02-09 09:19:46 -05:00
committed by GitHub
parent 6c0c8e2ac3
commit 2c3e3559eb
2 changed files with 8 additions and 1 deletions

View File

@@ -33,6 +33,10 @@ pm_mops = PatternMatcher([
# *****************
# 0. do some cleanup rewrites, mostly copied from the old stuff
def collapse_nested_assign(assign:UOp, target:UOp, src:UOp):
"""nested ASSIGN to the same buffer (e.g. __iadd__ in __setitem__): collapse the redundant outer ASSIGN"""
if src.src[0].base is target.base: return src if src.src[0] is target else assign.replace(src=(target, src.src[1]))
def assign_to_contiguous(assign:UOp, target:UOp, src:UOp):
if (t := target.base).op is Ops.BUFFER or (t.op is Ops.MSTACK and all(s.op is Ops.BUFFER for s in t.src)): return None
return src.f(Ops.CONTIGUOUS, tag=assign.tag)
@@ -132,6 +136,9 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
# ** assign rules **
# collapse nested ASSIGN to the same buffer (e.g. __iadd__ in __setitem__)
(UPat(Ops.ASSIGN, src=(UPat(name="target"), UPat(Ops.ASSIGN, name="src")), name="assign"), collapse_nested_assign),
# move bitcast from assign target to source: a.bitcast(X).assign(src) -> a.assign(src.bitcast(a.dtype))
(UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src")), name="assign"),
lambda assign, target, src: target.assign(src.bitcast(target.dtype)).replace(tag=assign.tag)),

View File

@@ -1291,7 +1291,7 @@ class Tensor(OpMixin):
else: # basic setitem
self.realize()
if not self.uop.is_writable_view(): raise RuntimeError("setitem target must be a writable view backed by a buffer")
v = v.cast(res.dtype)._broadcast_to(_broadcast_shape(res.shape, v.shape)).contiguous()
v = v.cast(res.dtype)._broadcast_to(_broadcast_shape(res.shape, v.shape))
res.assign(v).realize()
def __delitem__(self, indices) -> None: