From 37cc87ea75a2c167b1797cde4bf874448b4ac53a Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sat, 3 Aug 2024 19:20:11 +0800 Subject: [PATCH] save lines in the scheduler [run_process_replay] (#5890) --- tinygrad/engine/schedule.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index b034ad793c..282ed2bd7c 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -298,14 +298,12 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]): top_reduce = reduceop.base.srcs[0].base if len(children[top_reduce]) == 1: del realizes[top_reduce] - def _can_fold_reduce(r:LazyBuffer, group:Dict[LazyBuffer, None]) -> bool: - if DEBUG_ARANGE:=(getenv("DEBUG_ARANGE")): print(f"checking {r} {group=}") - if any(tr.forced_realize or tr in outs for tr in group): return False - if DEBUG_ARANGE: print(colored(f"folding {r}", "green")) - return True for r in reduce_of_const: - if _can_fold_reduce(r, group:={tr:None for tr,rop in reduce_for_op.items() if rop is r}): - for tr in group: del realizes[tr] + group = {tr:None for tr,rop in reduce_for_op.items() if rop is r} + if DEBUG_ARANGE:=(getenv("DEBUG_ARANGE")): print(f"checking {r} {group=}") + if any(tr.forced_realize or tr in outs for tr in group): continue + if DEBUG_ARANGE: print(colored(f"folding {r}", "green")) + for tr in group: del realizes[tr] output_groups: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list) for buf in realizes: