From 2c3e3559eb08da8fe70cd77012826d12b012e25f Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 9 Feb 2026 09:19:46 -0500 Subject: [PATCH] remove a contiguous in basic setitem (#14640) handled in rangeify --- tinygrad/schedule/rangeify.py | 7 +++++++ tinygrad/tensor.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 94250460e3..00d471e520 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -33,6 +33,10 @@ pm_mops = PatternMatcher([ # ***************** # 0. do some cleanup rewrites, mostly copied from the old stuff +def collapse_nested_assign(assign:UOp, target:UOp, src:UOp): + """nested ASSIGN to the same buffer (e.g. __iadd__ in __setitem__): collapse the redundant outer ASSIGN""" + if src.src[0].base is target.base: return src if src.src[0] is target else assign.replace(src=(target, src.src[1])) + def assign_to_contiguous(assign:UOp, target:UOp, src:UOp): if (t := target.base).op is Ops.BUFFER or (t.op is Ops.MSTACK and all(s.op is Ops.BUFFER for s in t.src)): return None return src.f(Ops.CONTIGUOUS, tag=assign.tag) @@ -132,6 +136,9 @@ earliest_rewrites = mop_cleanup+PatternMatcher([ # ** assign rules ** + # collapse nested ASSIGN to the same buffer (e.g. __iadd__ in __setitem__) + (UPat(Ops.ASSIGN, src=(UPat(name="target"), UPat(Ops.ASSIGN, name="src")), name="assign"), collapse_nested_assign), + # move bitcast from assign target to source: a.bitcast(X).assign(src) -> a.assign(src.bitcast(a.dtype)) (UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src")), name="assign"), lambda assign, target, src: target.assign(src.bitcast(target.dtype)).replace(tag=assign.tag)), diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 79d9f382a6..0dcaa91097 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1291,7 +1291,7 @@ class Tensor(OpMixin): else: # basic setitem self.realize() if not self.uop.is_writable_view(): raise RuntimeError("setitem target must be a writable view backed by a buffer") - v = v.cast(res.dtype)._broadcast_to(_broadcast_shape(res.shape, v.shape)).contiguous() + v = v.cast(res.dtype)._broadcast_to(_broadcast_shape(res.shape, v.shape)) res.assign(v).realize() def __delitem__(self, indices) -> None: