update fix_store_after_hazard (#15309)

actual gate is just not CONTIGUOUS, also don't need to check against full backward_slice
This commit is contained in:
chenyu
2026-03-16 23:55:59 -04:00
committed by GitHub
parent 575b40b93a
commit 1283b57b4e

View File

@@ -8,7 +8,7 @@ from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLI
from tinygrad.helpers import PCONTIG, FLOAT16, OPENPILOT_HACKS, argsort, partition, get_single_element
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
from tinygrad.codegen.opt import Opt
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, IndexingContext, apply_movement_op
from tinygrad.schedule.multi import multi_pm
from tinygrad.schedule.allreduce import create_allreduce_function
@@ -71,8 +71,11 @@ pm_mops = PatternMatcher([
def fix_store_after_hazard(after:UOp, target:UOp, src:UOp):
# PERMUTE and FLIP reorder indices, SHRINK can have overlapping regions when dest is also shrunk
unsafe = {Ops.PERMUTE, Ops.FLIP} | ({Ops.SHRINK} if target.op_in_backward_slice_with_self(Ops.SHRINK) else set())
if any(s.op in unsafe and target.base in s.backward_slice for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS or s.op is Ops.AFTER)):
return after.replace(src=(after.src[0], target.store(src.contiguous())))
base = target.base
reaches_base: dict[UOp, bool] = {}
for s in src.toposort(gate=lambda s: s.op is not Ops.CONTIGUOUS):
reaches_base[s] = s is base or any(reaches_base.get(c) for c in s.src)
if reaches_base[s] and s.op in unsafe: return after.replace(src=(after.src[0], target.store(src.contiguous())))
def normalize_store_after_target_chain(after:UOp, target:UOp, src:UOp):
root_target = target