mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
Refactor UOps pattern matcher to UPat instead of dicts (#4791)
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
from __future__ import annotations
|
||||
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable
|
||||
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union
|
||||
import functools, itertools, heapq
|
||||
from collections import defaultdict
|
||||
from enum import Enum, auto
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import dtypes, DType
|
||||
from tinygrad.shape.symbolic import sint, Variable
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
|
||||
@@ -67,45 +67,62 @@ def uop_alu_resolve(u:UOp) -> sint:
|
||||
|
||||
# *** simplification logic ***
|
||||
|
||||
def _match(uop:UOp, pattern:Dict[str, Any], store:Dict[str, UOp]) -> bool:
|
||||
for k,v in pattern.items():
|
||||
if k == "__name__":
|
||||
if v in store and store[v] != uop: return False
|
||||
store[v] = uop
|
||||
elif k == "arg":
|
||||
if uop.arg != v: return False
|
||||
elif k == "dtype":
|
||||
if isinstance(v, set):
|
||||
if uop.dtype not in v: return False
|
||||
elif uop.dtype != v: return False
|
||||
elif k == "uop":
|
||||
if isinstance(v, set):
|
||||
if uop.uop not in v: return False
|
||||
elif uop.uop != v: return False
|
||||
elif k == "vin":
|
||||
# only one if it's a tuple
|
||||
# try all permutations if it's a list
|
||||
# repeat if it's a dict
|
||||
for vp in itertools.permutations(v) if isinstance(v, list) else ([v] if isinstance(v, tuple) else [(v,)*len(uop.vin)]):
|
||||
if len(uop.vin) != len(vp) and (len(uop.vin) not in pattern.get('__allow_len__', [])): return False
|
||||
new_store = store.copy()
|
||||
if all(_match(uu, vv, new_store) for uu, vv in zip(uop.vin, vp)):
|
||||
for k,v in new_store.items(): store[k] = v
|
||||
return True
|
||||
return False
|
||||
return True
|
||||
@dataclass(frozen=True)
|
||||
class UPat:
|
||||
uop: Optional[Union[UOps, Set[UOps]]] = None
|
||||
arg: Any = None
|
||||
vin: Optional[Union[Tuple[UPat, ...], List[UPat], UPat]] = None
|
||||
name: Optional[str] = None
|
||||
dtype: Optional[Union[DType, Set[DType]]] = None
|
||||
allow_len: Set[int] = field(default_factory=set)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, pat:Dict[str, Any]) -> UPat:
|
||||
name, uop, dtype = pat.get("__name__"), pat.get("uop"), pat.get("dtype")
|
||||
assert isinstance(name, str) or name is None
|
||||
assert isinstance(uop, (UOps, set)) or uop is None
|
||||
assert isinstance(dtype, (DType, set)) or dtype is None
|
||||
vin = pat.get("vin")
|
||||
if isinstance(vin, list): vin = [UPat.from_dict(x) for x in vin]
|
||||
elif isinstance(vin, tuple): vin = tuple(UPat.from_dict(x) for x in vin)
|
||||
elif isinstance(vin, dict): vin = UPat.from_dict(vin)
|
||||
else: assert vin is None
|
||||
arg = pat.get("arg")
|
||||
allow_len = pat.get("__allow_len__", set())
|
||||
return cls(uop, arg, vin, name, dtype, allow_len)
|
||||
|
||||
def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
|
||||
if pat.name in store and store[pat.name] != uop: return False
|
||||
if pat.name is not None: store[pat.name] = uop
|
||||
if pat.arg is not None and uop.arg != pat.arg: return False
|
||||
if isinstance(pat.dtype, set) and uop.dtype not in pat.dtype: return False
|
||||
if isinstance(pat.dtype, DType) and uop.dtype != pat.dtype: return False
|
||||
if isinstance(pat.uop, set) and uop.uop not in pat.uop: return False
|
||||
if isinstance(pat.uop, UOps) and uop.uop != pat.uop: return False
|
||||
if pat.vin is None: return True
|
||||
# only one if it's a tuple
|
||||
# try all permutations if it's a list
|
||||
# repeat if it's a dict
|
||||
for vp in itertools.permutations(pat.vin) if isinstance(pat.vin,list) else ([pat.vin] if isinstance(pat.vin,tuple) else [(pat.vin,)*len(uop.vin)]):
|
||||
if len(uop.vin) != len(vp) and (len(uop.vin) not in pat.allow_len): return False
|
||||
new_store = store.copy()
|
||||
if all(_match(uu, vv, new_store) for uu, vv in zip(uop.vin, vp)):
|
||||
for k,v in new_store.items(): store[k] = v
|
||||
return True
|
||||
return False
|
||||
|
||||
class PatternMatcher:
|
||||
def __init__(self, patterns:List[Tuple[Dict[str, Any], Callable]]):
|
||||
self.patterns = patterns
|
||||
self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[Dict[str, Any], Callable]]] = defaultdict(list)
|
||||
self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable]]] = defaultdict(list)
|
||||
# uop is required, arg is optional
|
||||
for p,fxn in self.patterns:
|
||||
uops = p["uop"]
|
||||
if isinstance(uops, set):
|
||||
for uop in uops: self.pdict[(uop, p.get("arg", None))].append((p, fxn))
|
||||
for pd,fxn in self.patterns:
|
||||
p = UPat.from_dict(pd)
|
||||
assert p.uop is not None
|
||||
if isinstance(p.uop, set):
|
||||
for uop in p.uop: self.pdict[(uop, p.arg)].append((p, fxn))
|
||||
else:
|
||||
self.pdict[(uops, p.get("arg", None))].append((p, fxn))
|
||||
self.pdict[(p.uop, p.arg)].append((p, fxn))
|
||||
|
||||
def rewrite(self, uop:UOp) -> Optional[UOp]:
|
||||
for p,fxn in itertools.chain(self.pdict[(uop.uop, uop.arg)], self.pdict[(uop.uop, None)]):
|
||||
|
||||
Reference in New Issue
Block a user