set forced_realize for outputs [pr] (#7469)

This commit is contained in:
qazal
2024-11-01 20:03:12 +02:00
committed by GitHub
parent 7c9a1d69f9
commit 6febd20fcf
2 changed files with 2 additions and 2 deletions

View File

@@ -88,8 +88,7 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff
def get_realizes(outs:List[LazyBuffer], ctx) -> Tuple[List[List[UOp]], Dict[Buffer, LazyBuffer], Dict[LazyBuffer, LazyBuffer]]:
"""search the graph for all the LazyBuffers that need to realize"""
# start by just realizing the buffers passed in
realizes: Dict[LazyBuffer, None] = {x:None for x in outs}
realizes: Dict[LazyBuffer, None] = {}
allbufs: Dict[LazyBuffer, None] = {}
simple_pads: Dict[LazyBuffer, None] = {}
children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)

View File

@@ -237,6 +237,7 @@ break_sched = PatternMatcher([(UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b,
@track_rewrites(named=True)
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
if len(outs:=dedup(x.base for x in outs if x.realized is None and x.base.op is not MetaOps.CONST)) == 0: return [], {}
for out in outs: out.forced_realize = True
# create the big graph
ctx = ScheduleContext()
cache: Dict[LazyBuffer, UOp] = {}