mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
[bug fix] nested commutative pattern _match [run_process_replay] [no_assert] (#5340)
* deep pat test * lint * min diff * min lines * nothing * is res extra * cleanup2 * add res back * reduce lines * type anno --------- Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
@@ -2,7 +2,7 @@ import unittest
|
||||
from test.helpers import TestUOps
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps
|
||||
from tinygrad.codegen.uops import UOpGraph, UOps, PatternMatcher, UOp, UPat
|
||||
from tinygrad.codegen.uops import UOpGraph, UOps, PatternMatcher, UOp, UPat, _match
|
||||
|
||||
class TestPatternMatcher(TestUOps):
|
||||
def test_simple_match(self):
|
||||
@@ -133,6 +133,15 @@ class TestPatternMatcher(TestUOps):
|
||||
self.assertEqual(matcher.rewrite(c5), None)
|
||||
self.assertEqual(matcher.rewrite(c6), c6)
|
||||
|
||||
def test_deep_src_permutations(self):
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
||||
u1 = (c1 + c2) + c1
|
||||
u2 = (c2 + c1) + c1
|
||||
pat = UPat(UOps.ALU, src = (UPat(UOps.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b')))
|
||||
assert _match(u1, pat, {})
|
||||
assert _match(u2, pat, {})
|
||||
|
||||
@unittest.skip("no longer supported")
|
||||
def test_rewrite_graph_folds(self):
|
||||
uops = UOpGraph()
|
||||
|
||||
@@ -156,22 +156,22 @@ class UPat:
|
||||
T = TypeVar("T")
|
||||
def __unmatch(m1:Union[T, Set[T]], m2:T) -> bool: return m2 not in m1 if isinstance(m1, set) else m2 != m1
|
||||
|
||||
def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
|
||||
if pat.name is not None and store.setdefault(pat.name, uop) is not uop: return False
|
||||
if pat.arg is not None and __unmatch(pat.arg, uop.arg): return False
|
||||
if pat.dtype is not None and uop.dtype is not None and __unmatch(pat.dtype, uop.dtype): return False
|
||||
if pat.op is not None and __unmatch(pat.op, uop.op): return False
|
||||
if pat.src is None: return True
|
||||
def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
|
||||
if pat.name is not None and store.setdefault(pat.name, uop) is not uop: return []
|
||||
if pat.arg is not None and __unmatch(pat.arg, uop.arg): return []
|
||||
if pat.dtype is not None and uop.dtype is not None and __unmatch(pat.dtype, uop.dtype): return []
|
||||
if pat.op is not None and __unmatch(pat.op, uop.op): return []
|
||||
if pat.src is None: return [store]
|
||||
# only one if it's a tuple
|
||||
# try all permutations if it's a list
|
||||
# repeat if it's a UPat
|
||||
res: List[Dict[str, UOp]] = []
|
||||
for vp in itertools.permutations(pat.src) if isinstance(pat.src,list) else ([pat.src] if isinstance(pat.src,tuple) else [(pat.src,)*len(uop.src)]):
|
||||
if len(uop.src) != len(vp) and (len(uop.src) not in pat.allow_len) and not pat.allow_any_len: return False
|
||||
new_store = store.copy()
|
||||
if all(_match(uu, vv, new_store) for uu, vv in zip(uop.src, vp)):
|
||||
store.update(new_store)
|
||||
return True
|
||||
return False
|
||||
if len(uop.src) != len(vp) and (len(uop.src) not in pat.allow_len) and not pat.allow_any_len: return []
|
||||
new_stores = [store.copy()]
|
||||
for uu, vv in zip(uop.src, vp): new_stores = [rstore for nstore in new_stores for rstore in _match(uu, vv, nstore)]
|
||||
res.extend(new_stores)
|
||||
return res
|
||||
|
||||
class PatternMatcher:
|
||||
def __init__(self, patterns:List[Tuple[Union[UPat, UOp], Callable]]):
|
||||
@@ -188,8 +188,7 @@ class PatternMatcher:
|
||||
|
||||
def rewrite(self, uop:UOp) -> Optional[UOp]:
|
||||
for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
|
||||
store: Dict[str, UOp] = {}
|
||||
if _match(uop, p, store) and (ret:=fxn(**store)) is not None: return ret # NOTE: if it returns None, we keep trying to match
|
||||
if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match
|
||||
return None
|
||||
|
||||
def sum_collapse(phi_input, loop, val1, val2):
|
||||
|
||||
Reference in New Issue
Block a user