[bug fix] nested commutative pattern _match [run_process_replay] [no_assert] (#5340)

* deep pat test

* lint

* min diff

* min lines

* nothing

* is res extra

* cleanup2

* add res back

* reduce lines

* type anno

---------

Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
kormann
2024-07-09 15:38:39 +02:00
committed by GitHub
parent e815c57039
commit 3d452195e4
2 changed files with 23 additions and 15 deletions

View File

@@ -2,7 +2,7 @@ import unittest
from test.helpers import TestUOps
from tinygrad.dtype import dtypes
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps
from tinygrad.codegen.uops import UOpGraph, UOps, PatternMatcher, UOp, UPat
from tinygrad.codegen.uops import UOpGraph, UOps, PatternMatcher, UOp, UPat, _match
class TestPatternMatcher(TestUOps):
def test_simple_match(self):
@@ -133,6 +133,15 @@ class TestPatternMatcher(TestUOps):
self.assertEqual(matcher.rewrite(c5), None)
self.assertEqual(matcher.rewrite(c6), c6)
def test_deep_src_permutations(self):
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
u1 = (c1 + c2) + c1
u2 = (c2 + c1) + c1
pat = UPat(UOps.ALU, src = (UPat(UOps.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b')))
assert _match(u1, pat, {})
assert _match(u2, pat, {})
@unittest.skip("no longer supported")
def test_rewrite_graph_folds(self):
uops = UOpGraph()

View File

@@ -156,22 +156,22 @@ class UPat:
T = TypeVar("T")
def __unmatch(m1:Union[T, Set[T]], m2:T) -> bool: return m2 not in m1 if isinstance(m1, set) else m2 != m1
def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
if pat.name is not None and store.setdefault(pat.name, uop) is not uop: return False
if pat.arg is not None and __unmatch(pat.arg, uop.arg): return False
if pat.dtype is not None and uop.dtype is not None and __unmatch(pat.dtype, uop.dtype): return False
if pat.op is not None and __unmatch(pat.op, uop.op): return False
if pat.src is None: return True
def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
if pat.name is not None and store.setdefault(pat.name, uop) is not uop: return []
if pat.arg is not None and __unmatch(pat.arg, uop.arg): return []
if pat.dtype is not None and uop.dtype is not None and __unmatch(pat.dtype, uop.dtype): return []
if pat.op is not None and __unmatch(pat.op, uop.op): return []
if pat.src is None: return [store]
# only one if it's a tuple
# try all permutations if it's a list
# repeat if it's a UPat
res: List[Dict[str, UOp]] = []
for vp in itertools.permutations(pat.src) if isinstance(pat.src,list) else ([pat.src] if isinstance(pat.src,tuple) else [(pat.src,)*len(uop.src)]):
if len(uop.src) != len(vp) and (len(uop.src) not in pat.allow_len) and not pat.allow_any_len: return False
new_store = store.copy()
if all(_match(uu, vv, new_store) for uu, vv in zip(uop.src, vp)):
store.update(new_store)
return True
return False
if len(uop.src) != len(vp) and (len(uop.src) not in pat.allow_len) and not pat.allow_any_len: return []
new_stores = [store.copy()]
for uu, vv in zip(uop.src, vp): new_stores = [rstore for nstore in new_stores for rstore in _match(uu, vv, nstore)]
res.extend(new_stores)
return res
class PatternMatcher:
def __init__(self, patterns:List[Tuple[Union[UPat, UOp], Callable]]):
@@ -188,8 +188,7 @@ class PatternMatcher:
def rewrite(self, uop:UOp) -> Optional[UOp]:
for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
store: Dict[str, UOp] = {}
if _match(uop, p, store) and (ret:=fxn(**store)) is not None: return ret # NOTE: if it returns None, we keep trying to match
if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match
return None
def sum_collapse(phi_input, loop, val1, val2):