remove handle_assign_mops [pr] (#15249)

This commit is contained in:
chenyu
2026-03-13 01:53:21 -04:00
committed by GitHub
parent c251fc67c5
commit 576e7f985f
3 changed files with 10 additions and 18 deletions

View File

@@ -97,25 +97,11 @@ def convert_reduce_axis_to_reduce_with_ranges(ctx:IndexingContext, x:UOp):
def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp):
  """Return the first source of `x` when `x` can be dropped, else None.

  `x` is replaced by `x.src[0]` when either:
    * `x` has an entry in `ctx.range_map`, or
    * the first source of `x` is an `Ops.INDEX` node.

  In every other case None is returned.
  NOTE(review): presumably None means "no rewrite" to the pattern matcher that calls this -- confirm.
  """
  has_ranges = x in ctx.range_map
  parent = x.src[0]
  # the INDEX check is only evaluated when x has no range_map entry
  if has_ranges or parent.op is Ops.INDEX: return parent
  return None
def handle_assign_mops(ctx:IndexingContext, assign:UOp, target:UOp, src:UOp):
  """Record the movement ops sitting on top of an ASSIGN target in the ASSIGN's `arg`.

  Starting at `target`, follows `src[0]` while the node's op is in `GroupOp.Movement`,
  collecting `(op, marg)` pairs outermost-first. When at least one pair was collected
  and `assign` has an entry in `ctx.range_map`, returns a copy of `assign` whose `arg`
  is the tuple of pairs, and registers that copy in `ctx.range_map` under the same
  value as `assign`.

  Returns None when `target` is not a movement op, when `src` is an `Ops.CALL`,
  or when `assign` has no entry in `ctx.range_map`.
  """
  # guard: only movement-op targets whose assigned value is not a CALL are handled
  if target.op not in GroupOp.Movement or src.op is Ops.CALL: return None

  # peel the movement ops off the target, outermost first
  collected = []
  node = target
  while node.op in GroupOp.Movement:
    collected.append((node.op, node.marg))
    node = node.src[0]

  if not collected or assign not in ctx.range_map: return None

  tagged = assign.replace(arg=tuple(collected))
  # the rewritten ASSIGN takes over the range_map entry of the original one
  ctx.range_map[tagged] = ctx.range_map[assign]
  return tagged
pm_apply_rangeify = PatternMatcher([
# REDUCE_AXIS -> REDUCE
(UPat(Ops.REDUCE_AXIS, name="x"), convert_reduce_axis_to_reduce_with_ranges),
# PAD -> WHERE
(UPat(Ops.PAD, name="x"), convert_pad_to_where_to_keep_behavior_local),
# store movement ops in ASSIGN arg
(UPat(Ops.ASSIGN, src=(UPat(name="target"), UPat(name="src")), name="assign"), handle_assign_mops),
# finally, apply_rangeify
(UPat(GroupOp.All, name="x"), create_bufferize_and_index_based_on_ranges),
# remove movement op

View File

@@ -359,14 +359,13 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
while assign_src.op is Ops.NOOP: assign_src = assign_src.src[0]
store_target = assign_target
if assign.arg and assign_target.src[0].op is Ops.BUFFERIZE and assign_target.src[0].src[0].op is Ops.INDEX:
if assign_target.src[0].op is Ops.BUFFERIZE and assign_target.src[0].src[0].op is Ops.INDEX:
# BUFFERIZE(INDEX(...)); store through the underlying global index instead.
store_target = assign_target.src[0].src[0]
end_rngs = sorted(dedup(tuple(store_target.ranges) + tuple(rngs)), key=lambda x: x.arg)
ret = store_target.buf_uop.base
if assign_src is not store_target: ret = ret.after(store_target.replace(dtype=sdtype).store(assign_src).end(*end_rngs))
for op, marg in reversed(assign.arg or ()): ret = ret._mop(op, marg)
if assign_src is not assign_target: ret = ret.after(store_target.replace(dtype=sdtype).store(assign_src).end(*end_rngs))
return ret
# NOTE: the DEFINE_LOCAL needs to be disambiguated here

View File

@@ -313,7 +313,14 @@ class Tensor(OpMixin):
assign_uop = self.uop.assign(x.uop)
base = self.uop.base
if base.op in {Ops.BUFFER, Ops.AFTER} and not self.uop.has_buffer_identity():
_apply_map_to_tensors({base: base.after(assign_uop)}, name="Embed View Assign", walk=True)
original_uop = self.uop
assigned_base = base.after(assign_uop)
_apply_map_to_tensors({base: assigned_base}, name="Embed View Assign", walk=True)
def replace_view_base(u:UOp) -> UOp:
return u.replace(src=((assigned_base if u.src[0] is base else replace_view_base(u.src[0])),)+u.src[1:])
ret = Tensor(replace_view_base(original_uop), device=self.device, requires_grad=self.requires_grad)
self.replace(self._apply_uop(lambda *_: assign_uop, x))
return ret
return self.replace(self._apply_uop(lambda *_: assign_uop, x))
def detach(self) -> Tensor: