mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
update fix_store_after_hazard (#15309)
actual gate is just not CONTIGUOUS, also don't need to check against full backward_slice
This commit is contained in:
@@ -8,7 +8,7 @@ from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLI
|
||||
from tinygrad.helpers import PCONTIG, FLOAT16, OPENPILOT_HACKS, argsort, partition, get_single_element
|
||||
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
|
||||
from tinygrad.codegen.opt import Opt
|
||||
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
|
||||
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, IndexingContext, apply_movement_op
|
||||
from tinygrad.schedule.multi import multi_pm
|
||||
from tinygrad.schedule.allreduce import create_allreduce_function
|
||||
|
||||
@@ -71,8 +71,11 @@ pm_mops = PatternMatcher([
|
||||
def fix_store_after_hazard(after:UOp, target:UOp, src:UOp):
|
||||
# PERMUTE and FLIP reorder indices, SHRINK can have overlapping regions when dest is also shrunk
|
||||
unsafe = {Ops.PERMUTE, Ops.FLIP} | ({Ops.SHRINK} if target.op_in_backward_slice_with_self(Ops.SHRINK) else set())
|
||||
if any(s.op in unsafe and target.base in s.backward_slice for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS or s.op is Ops.AFTER)):
|
||||
return after.replace(src=(after.src[0], target.store(src.contiguous())))
|
||||
base = target.base
|
||||
reaches_base: dict[UOp, bool] = {}
|
||||
for s in src.toposort(gate=lambda s: s.op is not Ops.CONTIGUOUS):
|
||||
reaches_base[s] = s is base or any(reaches_base.get(c) for c in s.src)
|
||||
if reaches_base[s] and s.op in unsafe: return after.replace(src=(after.src[0], target.store(src.contiguous())))
|
||||
|
||||
def normalize_store_after_target_chain(after:UOp, target:UOp, src:UOp):
|
||||
root_target = target
|
||||
|
||||
Reference in New Issue
Block a user