mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
move assign chain fix to rangeify (#14829)
This commit is contained in:
@@ -55,6 +55,13 @@ def fix_assign_hazard(assign:UOp, target:UOp, src:UOp):
|
||||
if any(s is target.base for s in h.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS-{Ops.PARAM})):
|
||||
return assign.replace(src=(target, src.contiguous()))
|
||||
|
||||
def normalize_assign_target_chain(assign:UOp, target:UOp, src:UOp):
  """Retarget an ASSIGN whose target is itself an ASSIGN.

  Follows `src[0]` from `target` through every nested ASSIGN until a non-ASSIGN
  UOp is reached, and returns `assign` with that UOp as its target.

  Args:
    assign: the outer ASSIGN being rewritten.
    target: the current target of `assign` (an ASSIGN when this rule fires).
    src: the value being assigned.

  Returns:
    `assign.replace(...)` with sources `(innermost_target, value)`, where `value` is
    `src.contiguous()` if `target` appears in `src.toposort()`, else `src` unchanged.
  """
  # NOTE(review): the result always carries exactly two sources; confirm that no
  # additional sources on `assign` need to be preserved.
  innermost = target
  while innermost.op is Ops.ASSIGN:
    innermost = innermost.src[0]
  # when the RHS depends on the previous assign result, break the dependency with contiguous
  value = src.contiguous() if target in src.toposort() else src
  return assign.replace(src=(innermost, value))
|
||||
|
||||
def split_reduceop(reduce:UOp, x:UOp):
|
||||
if prod(reduce.shape) == 0: return None
|
||||
if not SPLIT_REDUCEOP or not all_int(x.shape) or (prod(x.shape)//prod(reduce.shape))<getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return None
|
||||
@@ -149,6 +156,9 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
|
||||
(UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src")), name="assign"),
|
||||
lambda assign, target, src: target.assign(src.bitcast(target.dtype)).replace(tag=assign.tag)),
|
||||
|
||||
# if assign target is itself an ASSIGN chain, canonicalize to the original buffer target
|
||||
(UPat(Ops.ASSIGN, src=(UPat(Ops.ASSIGN, name="target"), UPat(name="src")), allow_any_len=True, name="assign"), normalize_assign_target_chain),
|
||||
|
||||
# assign only to buffer, otherwise make it a CONTIGUOUS
|
||||
(UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.PARAM}, name="target"), UPat(name="src")), name="assign"), assign_to_contiguous),
|
||||
|
||||
|
||||
@@ -315,14 +315,7 @@ class Tensor(OpMixin):
|
||||
if is_disk:
|
||||
self._buffer().copyin(x._data())
|
||||
return self
|
||||
# chained full-buffer assign should keep writing into the original target buffer
|
||||
# TODO: move this to rangeify, currently pm_remove_bufferize drops some tags
|
||||
if self.uop.op is Ops.ASSIGN and (target:=self.uop.src[0]).has_buffer_identity():
|
||||
if self.uop in x.uop.toposort():
|
||||
# break assign-in-source cycle lazily through a temporary
|
||||
result = self._apply_uop(lambda _self, val: target.assign(val.contiguous()), x)
|
||||
else: result = self._apply_uop(lambda _self, val: target.assign(val), x)
|
||||
else: result = self._apply_uop(UOp.assign, x)
|
||||
result = self._apply_uop(UOp.assign, x)
|
||||
# track view assigns (not full-buffer or assign-chain) so they can be side-realized when the buffer is read
|
||||
if (buf_uop:=self.uop.base).op is Ops.BUFFER and self.uop.op is not Ops.ASSIGN and not self.uop.has_buffer_identity():
|
||||
# deduplicate: if the value is already a pending assign for this buffer (e.g. __iadd__ in __setitem__), remove it
|
||||
|
||||
Reference in New Issue
Block a user