mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-29 00:38:10 -05:00
set forced_realize for outputs [pr] (#7469)
This commit is contained in:
@@ -88,8 +88,7 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff
|
||||
|
||||
def get_realizes(outs:List[LazyBuffer], ctx) -> Tuple[List[List[UOp]], Dict[Buffer, LazyBuffer], Dict[LazyBuffer, LazyBuffer]]:
|
||||
"""search the graph for all the LazyBuffers that need to realize"""
|
||||
# start by just realizing the buffers passed in
|
||||
realizes: Dict[LazyBuffer, None] = {x:None for x in outs}
|
||||
realizes: Dict[LazyBuffer, None] = {}
|
||||
allbufs: Dict[LazyBuffer, None] = {}
|
||||
simple_pads: Dict[LazyBuffer, None] = {}
|
||||
children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
|
||||
|
||||
@@ -237,6 +237,7 @@ break_sched = PatternMatcher([(UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b,
|
||||
@track_rewrites(named=True)
|
||||
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
|
||||
if len(outs:=dedup(x.base for x in outs if x.realized is None and x.base.op is not MetaOps.CONST)) == 0: return [], {}
|
||||
for out in outs: out.forced_realize = True
|
||||
# create the big graph
|
||||
ctx = ScheduleContext()
|
||||
cache: Dict[LazyBuffer, UOp] = {}
|
||||
|
||||
Reference in New Issue
Block a user