From 7c3ba3fa8a65d7803163a4acfe10ff51b97fa476 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 23 Aug 2024 15:28:57 -0700 Subject: [PATCH] improve match stats + custom early reject [run_process_replay] (#6260) * improve match stats [run_process_replay] * custom_early_reject --- tinygrad/codegen/uopgraph.py | 2 +- tinygrad/ops.py | 15 ++++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 621524f7b8..d856b4c81a 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -441,7 +441,7 @@ expander = PatternMatcher([ (NOp(UOps.STORE, name="root"), create_gate), # do expansion (UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE, - UOps.VECTORIZE, UOps.REDUCE, UOps.EXPAND, UOps.IF}, name="root"), do_expand), + UOps.VECTORIZE, UOps.REDUCE, UOps.EXPAND, UOps.IF}, name="root", custom_early_reject=set([(UOps.EXPAND, None)])), do_expand), (NOp(UOps.CONTRACT, name="con"), do_contract), # remove EXPANDs from SINK (NOp(UOps.SINK, name="root"), diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 5680f97652..38e231d772 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -221,12 +221,14 @@ class NOp(UOp): class UPat: def __init__(self, op:Optional[Union[UOps, Set[UOps]]]=None, arg:Any=None, src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, - name:Optional[str]=None, dtype:Optional[Union[DType, Set[DType]]]=None, allow_any_len:bool=False, location=None): + name:Optional[str]=None, dtype:Optional[Union[DType, Set[DType]]]=None, allow_any_len:bool=False, location=None, + custom_early_reject:Optional[Set[Tuple[UOps, Any]]]=None): self.op: Optional[Tuple[UOps, ...]] = None if op is None else (tuple(op) if isinstance(op, set) else (op,)) self.dtype: Optional[Tuple[DType, ...]] = None if dtype is None else (tuple(dtype) if isinstance(dtype, set) else (dtype,)) self.arg, self.name = arg, name self.in_src = src self.src: Any = None + self.custom_early_reject = custom_early_reject # try all permutations if it's a list if isinstance(src, list): self.src = list(itertools.permutations(src)) if not all_same(src) else [src] # only one if it's a tuple @@ -239,6 +241,7 @@ class UPat: @functools.cached_property def early_reject(self) -> Set[Tuple[UOps, Any]]: + if self.custom_early_reject is not None: return self.custom_early_reject upat_match = [self.in_src] if isinstance(self.in_src, UPat) else ([] if self.in_src is None else self.src[0]) return set((pp.op[0], pp.arg) for pp in upat_match if pp.op is not None and len(pp.op) == 1) @@ -291,7 +294,7 @@ class TrackedPattenMatcher(PatternMatcher): def __init__(self, patterns:List[Tuple[Union[UPat, NOp], Callable]]): super().__init__(patterns) for p,_ in self.patterns: - if p not in match_stats: match_stats[p] = [0,0,0.0] + if p not in match_stats: match_stats[p] = [0,0,0.0,0.0] def rewrite(self, uop:UOp) -> Optional[UOp]: ret = None @@ -305,6 +308,7 @@ class TrackedPattenMatcher(PatternMatcher): if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None: match_stats[p][0] += 1 match_stats[p][2] += (et:=time.perf_counter()-st) + match_stats[p][3] += et if TRACK_MATCH_STATS >= 2: print(f"{et*1e6:7.2f} us -- ", p.printable()) return ret # NOTE: if it returns None, we keep trying to match match_stats[p][2] += time.perf_counter()-st @@ -315,11 +319,12 @@ if TRACK_MATCH_STATS: import atexit @atexit.register def print_match_stats(): - ret = [0,0,0.0] + ret = [0,0,0.0,0.0] for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]): - print(f"{v[0]:6d} / {v[1]:7d} -- {v[2]*1000.:9.2f} ms -- {k.location[0].split('tinygrad/')[-1]:>20s}:{k.location[1]:<3d}", k.printable()) + loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}" + print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {v[2]*1000.:9.2f} ms -- {loc_str:15s}", k.printable()) ret = [x+y for x,y in zip(ret, v)] - print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[2]*1000.:9.2f} ms -- TOTAL") + print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {ret[2]*1000.:9.2f} ms -- TOTAL") # *** simple graph rewrite engine ***