From c3168952f0feb559c5fc1e4a8d7c1d8179a8facf Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 21 Aug 2024 11:57:26 -0700 Subject: [PATCH] wip: tracking pattern matcher [run_process_replay] (#6225) * wip: tracking pattern matcher * better * proper dedup * timing * early reject * mergable match stats * TrackedPattenMatcher * fix TrackedPattenMatcher * cleanups * clean that too * remove early_reject * Revert "remove early_reject" This reverts commit dc2aef14b8f5da58f5ec9566daf252513cac394c. * total * sort by time * match_stats cleanup --- test/test_pattern_matcher.py | 21 ---------- tinygrad/ops.py | 80 +++++++++++++++++++++++++++++------- 2 files changed, 66 insertions(+), 35 deletions(-) diff --git a/test/test_pattern_matcher.py b/test/test_pattern_matcher.py index 845801f68d..0a2df96ce3 100644 --- a/test/test_pattern_matcher.py +++ b/test/test_pattern_matcher.py @@ -2,7 +2,6 @@ import unittest, itertools from test.helpers import TestUOps from tinygrad.dtype import dtypes from tinygrad.ops import UOps, UOp, PatternMatcher, UPat, BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401 -from tinygrad.codegen.uopgraph import constant_folder class TestPatternMatcher(TestUOps): def test_simple_match(self): @@ -161,25 +160,5 @@ class TestPatternMatcher(TestUOps): return u.src[0] for a,b in zip(simple_src(a), simple_src(b)): self._assert_eq_upat(a, b) - def test_upat_str(self): - dtypes._float2 = dtypes.float.vec(2) - dtypes._float4 = dtypes.float.vec(4) - dtypes._float8 = dtypes.float.vec(8) - dtypes._float16 = dtypes.float.vec(16) - dtypes._half2 = dtypes.half.vec(2) - dtypes._half4 = dtypes.half.vec(4) - dtypes._half8 = dtypes.half.vec(8) - dtypes._half16 = dtypes.half.vec(16) - upat = UPat(UOps.CONST, name="x", dtype=dtypes.float) - assert str(upat) == str(eval(str(upat))) - evpat:UPat = eval(repr(UPat(src = [UPat(name='a'), UPat(name='b')]))) - assert len(evpat.src) == 2 - for i in range(20): upat = UPat(UOps.ALU, name="x", src=[upat, upat], arg=BinaryOps.ADD) - assert len(str(upat)) < 10_000 - assert str(eval(str(upat))) == str(upat) - for rules in constant_folder.pdict.values(): - for pat,_ in rules: - self._assert_eq_upat(pat, eval(str(pat))) - if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index a0d6edaf7b..f7cd4589f3 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1,11 +1,11 @@ from __future__ import annotations +from typing import Any, DefaultDict, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Sequence +import sys, time, math, operator, ctypes, struct, functools, hashlib, itertools from collections import defaultdict -from typing import Any, DefaultDict, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING -import math, operator, ctypes, struct, functools, hashlib, itertools from enum import Enum, auto -from dataclasses import dataclass +from dataclasses import dataclass, field from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType -from tinygrad.helpers import pretty_print, prod +from tinygrad.helpers import pretty_print, prod, getenv from tinygrad.shape.symbolic import Variable, sint if TYPE_CHECKING: from tinygrad.shape.shapetracker import ShapeTracker @@ -193,24 +193,34 @@ class KernelInfo: # ***** pattern matcher ***** +def get_location() -> Tuple[str, int]: + frm = sys._getframe(1) + # no matchers in ops.py, find the real frame + while (frm.f_code.co_filename.endswith("/ops.py") or frm.f_code.co_filename == '') and frm.f_back is not None: frm = frm.f_back + return frm.f_code.co_filename, frm.f_lineno + @dataclass(frozen=True, repr=False) # reuse repr from UOp class NOp(UOp): name: Optional[str] = None src: Tuple[NOp, ...] = tuple() allow_any_len: bool = False + location: Tuple[str, int] = field(default_factory=get_location) + @staticmethod def var(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.NOOP, dtype=dtype, name=name) @staticmethod def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.CONST, dtype=dtype, name=name) def const(self:Union[UOp, DType, None], b:ConstType|Variable): return NOp((x:=UOp.const(self, b)).op, x.dtype, x.src, x.arg) - def compile(self:NOp, name:Optional[str]=None) -> UPat: - return UPat(name=self.name, dtype=self.dtype) if self.op is UOps.NOOP else UPat(self.op, self.arg, (list if self.commutative() - else tuple)(src.compile() for src in self.src) or None, self.name or name, self.dtype, self.allow_any_len) + @functools.cached_property + def upat(self:NOp) -> UPat: + return UPat(name=self.name, dtype=self.dtype, location=self.location) if self.op is UOps.NOOP else \ + UPat(self.op, self.arg, (list if self.commutative() else tuple)(src.upat for src in self.src) or None, self.name, + self.dtype, self.allow_any_len, location=self.location) class UPat: def __init__(self, op:Optional[Union[UOps, Set[UOps]]]=None, arg:Any=None, src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, - name:Optional[str]=None, dtype:Optional[Union[DType, Set[DType]]]=None, allow_any_len:bool=False): + name:Optional[str]=None, dtype:Optional[Union[DType, Set[DType]]]=None, allow_any_len:bool=False, location=None): self.op: Optional[Tuple[UOps, ...]] = None if op is None else (tuple(op) if isinstance(op, set) else (op,)) self.dtype: Optional[Tuple[DType, ...]] = None if dtype is None else (tuple(dtype) if isinstance(dtype, set) else (dtype,)) self.arg, self.name = arg, name @@ -223,6 +233,12 @@ class UPat: elif isinstance(src, UPat): self.src = [itertools.repeat(src)] self.allowed_len: int = 0 if allow_any_len or isinstance(src, UPat) or src is None else len(src) + self.location = location or get_location() + + @functools.cached_property + def early_reject(self): + # TODO: this can be improved to support some allowed_len == 0 patterns + return set((pp.op[0], pp.arg) for pp in self.src[0] if pp.op is not None and len(pp.op) == 1) if self.allowed_len else set() def __repr__(self): def rep(x): @@ -246,23 +262,59 @@ def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> List[Dict[str, UOp]]: return res class PatternMatcher: - def __init__(self, patterns:List[Tuple[Union[UPat, NOp], Callable]]): - self.patterns = patterns - self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable]]] = defaultdict(list) + def __init__(self, patterns:Sequence[Tuple[Union[UPat, NOp], Callable]]): + self.patterns = [(p.upat if isinstance(p, NOp) else p, fxn) for p,fxn in patterns] + self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable, Set]]] = defaultdict(list) # uop is required, arg is optional for p,fxn in self.patterns: - if isinstance(p, NOp): p = p.compile() assert p.op is not None - for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn)) + for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn, p.early_reject)) @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns) def rewrite(self, uop:UOp) -> Optional[UOp]: - for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]): + ler = set([(u.op, u.arg) for u in uop.src] + [(u.op, None) for u in uop.src]) + for p,fxn,early_reject in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]): + if not early_reject.issubset(ler): continue if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match return None +if getenv("TRACK_MATCH_STATS", 0): + match_stats = dict() + class TrackedPattenMatcher(PatternMatcher): + def __init__(self, patterns:List[Tuple[Union[UPat, NOp], Callable]]): + super().__init__(patterns) + for p,_ in self.patterns: + if p not in match_stats: match_stats[p] = [0,0,0.0] + + def rewrite(self, uop:UOp) -> Optional[UOp]: + ret = None + ler = set([(u.op, u.arg) for u in uop.src] + [(u.op, None) for u in uop.src]) + for p,fxn,early_reject in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]): + st = time.perf_counter() + if not early_reject.issubset(ler): + match_stats[p][2] += time.perf_counter()-st + continue + match_stats[p][1] += 1 + if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None: + match_stats[p][0] += 1 + match_stats[p][2] += time.perf_counter()-st + return ret # NOTE: if it returns None, we keep trying to match + match_stats[p][2] += time.perf_counter()-st + return None + PatternMatcher = TrackedPattenMatcher # type: ignore + import atexit + @functools.lru_cache(None) + def lines(fn): return open(fn).readlines() + @atexit.register + def print_match_stats(): + ret = [0,0,0.0] + for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]): + print(f"{v[0]:6d} / {v[1]:7d} -- {v[2]*1000.:9.2f} ms -- {k.location}", lines(k.location[0])[k.location[1]-1].strip()) + ret = [x+y for x,y in zip(ret, v)] + print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[2]*1000.:9.2f} ms -- TOTAL") + def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp: nodes: Dict[Tuple, UOp] = {} replace: Dict[UOp, UOp] = {}