diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index 1ceeca13c7..fc2495ed48 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -8,7 +8,7 @@ from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLI
 from tinygrad.helpers import PCONTIG, FLOAT16, OPENPILOT_HACKS, argsort, partition, get_single_element
 from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
 from tinygrad.codegen.opt import Opt
-from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
+from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, IndexingContext, apply_movement_op
 from tinygrad.schedule.multi import multi_pm
 from tinygrad.schedule.allreduce import create_allreduce_function
 
@@ -71,8 +71,11 @@ pm_mops = PatternMatcher([
 def fix_store_after_hazard(after:UOp, target:UOp, src:UOp):
   # PERMUTE and FLIP reorder indices, SHRINK can have overlapping regions when dest is also shrunk
   unsafe = {Ops.PERMUTE, Ops.FLIP} | ({Ops.SHRINK} if target.op_in_backward_slice_with_self(Ops.SHRINK) else set())
-  if any(s.op in unsafe and target.base in s.backward_slice for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS or s.op is Ops.AFTER)):
-    return after.replace(src=(after.src[0], target.store(src.contiguous())))
+  base = target.base
+  reaches_base: dict[UOp, bool] = {}
+  for s in src.toposort(gate=lambda s: s.op is not Ops.CONTIGUOUS):
+    reaches_base[s] = s is base or any(reaches_base.get(c) for c in s.src)
+    if reaches_base[s] and s.op in unsafe: return after.replace(src=(after.src[0], target.store(src.contiguous())))
 
 def normalize_store_after_target_chain(after:UOp, target:UOp, src:UOp):
   root_target = target