mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 22:08:08 -05:00
improve match stats + custom early reject [run_process_replay] (#6260)
* improve match stats [run_process_replay] * custom_early_reject
This commit is contained in:
@@ -441,7 +441,7 @@ expander = PatternMatcher([
|
||||
(NOp(UOps.STORE, name="root"), create_gate),
|
||||
# do expansion
|
||||
(UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE,
|
||||
UOps.VECTORIZE, UOps.REDUCE, UOps.EXPAND, UOps.IF}, name="root"), do_expand),
|
||||
UOps.VECTORIZE, UOps.REDUCE, UOps.EXPAND, UOps.IF}, name="root", custom_early_reject=set([(UOps.EXPAND, None)])), do_expand),
|
||||
(NOp(UOps.CONTRACT, name="con"), do_contract),
|
||||
# remove EXPANDs from SINK
|
||||
(NOp(UOps.SINK, name="root"),
|
||||
|
||||
@@ -221,12 +221,14 @@ class NOp(UOp):
|
||||
|
||||
class UPat:
|
||||
def __init__(self, op:Optional[Union[UOps, Set[UOps]]]=None, arg:Any=None, src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None,
|
||||
name:Optional[str]=None, dtype:Optional[Union[DType, Set[DType]]]=None, allow_any_len:bool=False, location=None):
|
||||
name:Optional[str]=None, dtype:Optional[Union[DType, Set[DType]]]=None, allow_any_len:bool=False, location=None,
|
||||
custom_early_reject:Optional[Set[Tuple[UOps, Any]]]=None):
|
||||
self.op: Optional[Tuple[UOps, ...]] = None if op is None else (tuple(op) if isinstance(op, set) else (op,))
|
||||
self.dtype: Optional[Tuple[DType, ...]] = None if dtype is None else (tuple(dtype) if isinstance(dtype, set) else (dtype,))
|
||||
self.arg, self.name = arg, name
|
||||
self.in_src = src
|
||||
self.src: Any = None
|
||||
self.custom_early_reject = custom_early_reject
|
||||
# try all permutations if it's a list
|
||||
if isinstance(src, list): self.src = list(itertools.permutations(src)) if not all_same(src) else [src]
|
||||
# only one if it's a tuple
|
||||
@@ -239,6 +241,7 @@ class UPat:
|
||||
|
||||
@functools.cached_property
|
||||
def early_reject(self) -> Set[Tuple[UOps, Any]]:
|
||||
if self.custom_early_reject is not None: return self.custom_early_reject
|
||||
upat_match = [self.in_src] if isinstance(self.in_src, UPat) else ([] if self.in_src is None else self.src[0])
|
||||
return set((pp.op[0], pp.arg) for pp in upat_match if pp.op is not None and len(pp.op) == 1)
|
||||
|
||||
@@ -291,7 +294,7 @@ class TrackedPattenMatcher(PatternMatcher):
|
||||
def __init__(self, patterns:List[Tuple[Union[UPat, NOp], Callable]]):
|
||||
super().__init__(patterns)
|
||||
for p,_ in self.patterns:
|
||||
if p not in match_stats: match_stats[p] = [0,0,0.0]
|
||||
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
|
||||
|
||||
def rewrite(self, uop:UOp) -> Optional[UOp]:
|
||||
ret = None
|
||||
@@ -305,6 +308,7 @@ class TrackedPattenMatcher(PatternMatcher):
|
||||
if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None:
|
||||
match_stats[p][0] += 1
|
||||
match_stats[p][2] += (et:=time.perf_counter()-st)
|
||||
match_stats[p][3] += et
|
||||
if TRACK_MATCH_STATS >= 2: print(f"{et*1e6:7.2f} us -- ", p.printable())
|
||||
return ret # NOTE: if it returns None, we keep trying to match
|
||||
match_stats[p][2] += time.perf_counter()-st
|
||||
@@ -315,11 +319,12 @@ if TRACK_MATCH_STATS:
|
||||
import atexit
|
||||
@atexit.register
|
||||
def print_match_stats():
|
||||
ret = [0,0,0.0]
|
||||
ret = [0,0,0.0,0.0]
|
||||
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]):
|
||||
print(f"{v[0]:6d} / {v[1]:7d} -- {v[2]*1000.:9.2f} ms -- {k.location[0].split('tinygrad/')[-1]:>20s}:{k.location[1]:<3d}", k.printable())
|
||||
loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}"
|
||||
print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {v[2]*1000.:9.2f} ms -- {loc_str:15s}", k.printable())
|
||||
ret = [x+y for x,y in zip(ret, v)]
|
||||
print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[2]*1000.:9.2f} ms -- TOTAL")
|
||||
print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {ret[2]*1000.:9.2f} ms -- TOTAL")
|
||||
|
||||
# *** simple graph rewrite engine ***
|
||||
|
||||
|
||||
Reference in New Issue
Block a user