diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index c49e492886..a95c69f719 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -55,6 +55,13 @@ def fix_assign_hazard(assign:UOp, target:UOp, src:UOp): if any(s is target.base for s in h.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS-{Ops.PARAM})): return assign.replace(src=(target, src.contiguous())) +def normalize_assign_target_chain(assign:UOp, target:UOp, src:UOp): + root_target = target + while root_target.op is Ops.ASSIGN: root_target = root_target.src[0] + # when RHS depends on the previous assign result, break with contiguous + if target in src.toposort(): src = src.contiguous() + return assign.replace(src=(root_target, src)) + def split_reduceop(reduce:UOp, x:UOp): if prod(reduce.shape) == 0: return None if not SPLIT_REDUCEOP or not all_int(x.shape) or (prod(x.shape)//prod(reduce.shape))