improve match stats + custom early reject [run_process_replay] (#6260)

* improve match stats [run_process_replay]

* custom_early_reject
This commit is contained in:
George Hotz
2024-08-23 15:28:57 -07:00
committed by GitHub
parent 0b0a8829fb
commit 7c3ba3fa8a
2 changed files with 11 additions and 6 deletions

View File

@@ -441,7 +441,7 @@ expander = PatternMatcher([
(NOp(UOps.STORE, name="root"), create_gate),
# do expansion
(UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE,
UOps.VECTORIZE, UOps.REDUCE, UOps.EXPAND, UOps.IF}, name="root"), do_expand),
UOps.VECTORIZE, UOps.REDUCE, UOps.EXPAND, UOps.IF}, name="root", custom_early_reject=set([(UOps.EXPAND, None)])), do_expand),
(NOp(UOps.CONTRACT, name="con"), do_contract),
# remove EXPANDs from SINK
(NOp(UOps.SINK, name="root"),

View File

@@ -221,12 +221,14 @@ class NOp(UOp):
class UPat:
def __init__(self, op:Optional[Union[UOps, Set[UOps]]]=None, arg:Any=None, src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None,
name:Optional[str]=None, dtype:Optional[Union[DType, Set[DType]]]=None, allow_any_len:bool=False, location=None):
name:Optional[str]=None, dtype:Optional[Union[DType, Set[DType]]]=None, allow_any_len:bool=False, location=None,
custom_early_reject:Optional[Set[Tuple[UOps, Any]]]=None):
self.op: Optional[Tuple[UOps, ...]] = None if op is None else (tuple(op) if isinstance(op, set) else (op,))
self.dtype: Optional[Tuple[DType, ...]] = None if dtype is None else (tuple(dtype) if isinstance(dtype, set) else (dtype,))
self.arg, self.name = arg, name
self.in_src = src
self.src: Any = None
self.custom_early_reject = custom_early_reject
# try all permutations if it's a list
if isinstance(src, list): self.src = list(itertools.permutations(src)) if not all_same(src) else [src]
# only one if it's a tuple
@@ -239,6 +241,7 @@ class UPat:
@functools.cached_property
def early_reject(self) -> Set[Tuple[UOps, Any]]:
if self.custom_early_reject is not None: return self.custom_early_reject
upat_match = [self.in_src] if isinstance(self.in_src, UPat) else ([] if self.in_src is None else self.src[0])
return set((pp.op[0], pp.arg) for pp in upat_match if pp.op is not None and len(pp.op) == 1)
@@ -291,7 +294,7 @@ class TrackedPattenMatcher(PatternMatcher):
def __init__(self, patterns:List[Tuple[Union[UPat, NOp], Callable]]):
super().__init__(patterns)
for p,_ in self.patterns:
if p not in match_stats: match_stats[p] = [0,0,0.0]
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
def rewrite(self, uop:UOp) -> Optional[UOp]:
ret = None
@@ -305,6 +308,7 @@ class TrackedPattenMatcher(PatternMatcher):
if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None:
match_stats[p][0] += 1
match_stats[p][2] += (et:=time.perf_counter()-st)
match_stats[p][3] += et
if TRACK_MATCH_STATS >= 2: print(f"{et*1e6:7.2f} us -- ", p.printable())
return ret # NOTE: if it returns None, we keep trying to match
match_stats[p][2] += time.perf_counter()-st
@@ -315,11 +319,12 @@ if TRACK_MATCH_STATS:
import atexit
@atexit.register
def print_match_stats():
ret = [0,0,0.0]
ret = [0,0,0.0,0.0]
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]):
print(f"{v[0]:6d} / {v[1]:7d} -- {v[2]*1000.:9.2f} ms -- {k.location[0].split('tinygrad/')[-1]:>20s}:{k.location[1]:<3d}", k.printable())
loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}"
print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {v[2]*1000.:9.2f} ms -- {loc_str:15s}", k.printable())
ret = [x+y for x,y in zip(ret, v)]
print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[2]*1000.:9.2f} ms -- TOTAL")
print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {ret[2]*1000.:9.2f} ms -- TOTAL")
# *** simple graph rewrite engine ***