mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
remove handle_assign_mops [pr] (#15249)
This commit is contained in:
@@ -97,25 +97,11 @@ def convert_reduce_axis_to_reduce_with_ranges(ctx:IndexingContext, x:UOp):
|
||||
def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp):
|
||||
if x in ctx.range_map or x.src[0].op is Ops.INDEX: return x.src[0]
|
||||
|
||||
def handle_assign_mops(ctx:IndexingContext, assign:UOp, target:UOp, src:UOp):
|
||||
if target.op in GroupOp.Movement and src.op is not Ops.CALL:
|
||||
mops = []
|
||||
while target.op in GroupOp.Movement:
|
||||
mops.append((target.op, target.marg))
|
||||
target = target.src[0]
|
||||
if mops and assign in ctx.range_map:
|
||||
ret = assign.replace(arg=tuple(mops))
|
||||
ctx.range_map[ret] = ctx.range_map[assign]
|
||||
return ret
|
||||
return None
|
||||
|
||||
pm_apply_rangeify = PatternMatcher([
|
||||
# REDUCE_AXIS -> REDUCE
|
||||
(UPat(Ops.REDUCE_AXIS, name="x"), convert_reduce_axis_to_reduce_with_ranges),
|
||||
# PAD -> WHERE
|
||||
(UPat(Ops.PAD, name="x"), convert_pad_to_where_to_keep_behavior_local),
|
||||
# store movement ops in ASSIGN arg
|
||||
(UPat(Ops.ASSIGN, src=(UPat(name="target"), UPat(name="src")), name="assign"), handle_assign_mops),
|
||||
# finally, apply_rangeify
|
||||
(UPat(GroupOp.All, name="x"), create_bufferize_and_index_based_on_ranges),
|
||||
# remove movement op
|
||||
|
||||
@@ -359,14 +359,13 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
|
||||
while assign_src.op is Ops.NOOP: assign_src = assign_src.src[0]
|
||||
|
||||
store_target = assign_target
|
||||
if assign.arg and assign_target.src[0].op is Ops.BUFFERIZE and assign_target.src[0].src[0].op is Ops.INDEX:
|
||||
if assign_target.src[0].op is Ops.BUFFERIZE and assign_target.src[0].src[0].op is Ops.INDEX:
|
||||
# BUFFERIZE(INDEX(...)); store through the underlying global index instead.
|
||||
store_target = assign_target.src[0].src[0]
|
||||
|
||||
end_rngs = sorted(dedup(tuple(store_target.ranges) + tuple(rngs)), key=lambda x: x.arg)
|
||||
ret = store_target.buf_uop.base
|
||||
if assign_src is not store_target: ret = ret.after(store_target.replace(dtype=sdtype).store(assign_src).end(*end_rngs))
|
||||
for op, marg in reversed(assign.arg or ()): ret = ret._mop(op, marg)
|
||||
if assign_src is not assign_target: ret = ret.after(store_target.replace(dtype=sdtype).store(assign_src).end(*end_rngs))
|
||||
return ret
|
||||
|
||||
# NOTE: the DEFINE_LOCAL needs to be disambiguated here
|
||||
|
||||
@@ -313,7 +313,14 @@ class Tensor(OpMixin):
|
||||
assign_uop = self.uop.assign(x.uop)
|
||||
base = self.uop.base
|
||||
if base.op in {Ops.BUFFER, Ops.AFTER} and not self.uop.has_buffer_identity():
|
||||
_apply_map_to_tensors({base: base.after(assign_uop)}, name="Embed View Assign", walk=True)
|
||||
original_uop = self.uop
|
||||
assigned_base = base.after(assign_uop)
|
||||
_apply_map_to_tensors({base: assigned_base}, name="Embed View Assign", walk=True)
|
||||
def replace_view_base(u:UOp) -> UOp:
|
||||
return u.replace(src=((assigned_base if u.src[0] is base else replace_view_base(u.src[0])),)+u.src[1:])
|
||||
ret = Tensor(replace_view_base(original_uop), device=self.device, requires_grad=self.requires_grad)
|
||||
self.replace(self._apply_uop(lambda *_: assign_uop, x))
|
||||
return ret
|
||||
return self.replace(self._apply_uop(lambda *_: assign_uop, x))
|
||||
|
||||
def detach(self) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user