Refactor UOps pattern matcher to UPat instead of dicts (#4791)

This commit is contained in:
Alec Chen
2024-06-01 03:55:51 -05:00
committed by GitHub
parent de8c8abbd8
commit b377db7f0d

View File

@@ -1,9 +1,9 @@
from __future__ import annotations
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union
import functools, itertools, heapq
from collections import defaultdict
from enum import Enum, auto
from dataclasses import dataclass
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, DType
from tinygrad.shape.symbolic import sint, Variable
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
@@ -67,45 +67,62 @@ def uop_alu_resolve(u:UOp) -> sint:
# *** simplification logic ***
def _match(uop:UOp, pattern:Dict[str, Any], store:Dict[str, UOp]) -> bool:
for k,v in pattern.items():
if k == "__name__":
if v in store and store[v] != uop: return False
store[v] = uop
elif k == "arg":
if uop.arg != v: return False
elif k == "dtype":
if isinstance(v, set):
if uop.dtype not in v: return False
elif uop.dtype != v: return False
elif k == "uop":
if isinstance(v, set):
if uop.uop not in v: return False
elif uop.uop != v: return False
elif k == "vin":
# only one if it's a tuple
# try all permutations if it's a list
# repeat if it's a dict
for vp in itertools.permutations(v) if isinstance(v, list) else ([v] if isinstance(v, tuple) else [(v,)*len(uop.vin)]):
if len(uop.vin) != len(vp) and (len(uop.vin) not in pattern.get('__allow_len__', [])): return False
new_store = store.copy()
if all(_match(uu, vv, new_store) for uu, vv in zip(uop.vin, vp)):
for k,v in new_store.items(): store[k] = v
return True
return False
return True
@dataclass(frozen=True)
class UPat:
uop: Optional[Union[UOps, Set[UOps]]] = None
arg: Any = None
vin: Optional[Union[Tuple[UPat, ...], List[UPat], UPat]] = None
name: Optional[str] = None
dtype: Optional[Union[DType, Set[DType]]] = None
allow_len: Set[int] = field(default_factory=set)
@classmethod
def from_dict(cls, pat:Dict[str, Any]) -> UPat:
name, uop, dtype = pat.get("__name__"), pat.get("uop"), pat.get("dtype")
assert isinstance(name, str) or name is None
assert isinstance(uop, (UOps, set)) or uop is None
assert isinstance(dtype, (DType, set)) or dtype is None
vin = pat.get("vin")
if isinstance(vin, list): vin = [UPat.from_dict(x) for x in vin]
elif isinstance(vin, tuple): vin = tuple(UPat.from_dict(x) for x in vin)
elif isinstance(vin, dict): vin = UPat.from_dict(vin)
else: assert vin is None
arg = pat.get("arg")
allow_len = pat.get("__allow_len__", set())
return cls(uop, arg, vin, name, dtype, allow_len)
def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
if pat.name in store and store[pat.name] != uop: return False
if pat.name is not None: store[pat.name] = uop
if pat.arg is not None and uop.arg != pat.arg: return False
if isinstance(pat.dtype, set) and uop.dtype not in pat.dtype: return False
if isinstance(pat.dtype, DType) and uop.dtype != pat.dtype: return False
if isinstance(pat.uop, set) and uop.uop not in pat.uop: return False
if isinstance(pat.uop, UOps) and uop.uop != pat.uop: return False
if pat.vin is None: return True
# only one if it's a tuple
# try all permutations if it's a list
# repeat if it's a dict
for vp in itertools.permutations(pat.vin) if isinstance(pat.vin,list) else ([pat.vin] if isinstance(pat.vin,tuple) else [(pat.vin,)*len(uop.vin)]):
if len(uop.vin) != len(vp) and (len(uop.vin) not in pat.allow_len): return False
new_store = store.copy()
if all(_match(uu, vv, new_store) for uu, vv in zip(uop.vin, vp)):
for k,v in new_store.items(): store[k] = v
return True
return False
class PatternMatcher:
def __init__(self, patterns:List[Tuple[Dict[str, Any], Callable]]):
self.patterns = patterns
self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[Dict[str, Any], Callable]]] = defaultdict(list)
self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable]]] = defaultdict(list)
# uop is required, arg is optional
for p,fxn in self.patterns:
uops = p["uop"]
if isinstance(uops, set):
for uop in uops: self.pdict[(uop, p.get("arg", None))].append((p, fxn))
for pd,fxn in self.patterns:
p = UPat.from_dict(pd)
assert p.uop is not None
if isinstance(p.uop, set):
for uop in p.uop: self.pdict[(uop, p.arg)].append((p, fxn))
else:
self.pdict[(uops, p.get("arg", None))].append((p, fxn))
self.pdict[(p.uop, p.arg)].append((p, fxn))
def rewrite(self, uop:UOp) -> Optional[UOp]:
for p,fxn in itertools.chain(self.pdict[(uop.uop, uop.arg)], self.pdict[(uop.uop, None)]):