write tests for algebraic UPat

This commit is contained in:
George Hotz
2025-10-05 08:20:13 +08:00
parent 865d5796f8
commit 6538935441
2 changed files with 42 additions and 1 deletions

View File

@@ -211,5 +211,39 @@ class TestPatternMatcher(unittest.TestCase):
return u.src[0]
for a,b in zip(simple_src(a), simple_src(b)): self._assert_eq_upat(a, b)
class TestAlgebraic(unittest.TestCase):
def test_plus_0(self):
pm = PatternMatcher([
(UPat.var("x") + 0, UPat.var("x")), # x+0 -> x
])
expr = UOp.const(dtypes.int, 4)+0
print(expr)
self.assertEqual(pm.rewrite(expr), UOp.const(dtypes.int, 4))
def test_div_mul(self):
pm = PatternMatcher([
((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), UPat.var("x")), # (x*x2)/x2 -> x
])
expr = UOp.const(dtypes.float, 4)/2*2
print(expr)
self.assertEqual(pm.rewrite(expr), UOp.const(dtypes.int, 4))
def test_mul_is_and(self):
pm = PatternMatcher([
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), UPat.var('x') & UPat.var('y')),
])
expr = UOp.const(dtypes.bool, True)*UOp.const(dtypes.bool, True)
print(expr)
self.assertEqual(pm.rewrite(expr), UOp.const(dtypes.bool, True)&UOp.const(dtypes.bool, True))
def test_div_neg_1(self):
pm = PatternMatcher([
(UPat.var("x") // -1, UPat.var("x") * -1), # x//-1 -> x * -1
])
expr = UOp.const(dtypes.float, 4)//-1
print(expr)
self.assertEqual(pm.rewrite(expr), UOp.const(dtypes.int, 4) * -1)
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -860,11 +860,18 @@ def upat_interpret(p:UPat, fxn:Callable) -> Callable:
return None
return universal_match
def fixup_pm_function(fxn) -> Callable:
if isinstance(fxn, UPat):
# TODO: write this
raise NotImplementedError("rhs UPat is not supported")
if isinstance(fxn, tuple): return types.FunctionType(*fxn)
return fxn
class PatternMatcher:
def __init__(self, patterns:Sequence[tuple[UPat, Callable|tuple]], compiled=bool(getenv("UPAT_COMPILE", 1))):
if compiled: from tinygrad.uop.upat import upat_compile
# if this comes from a pickle, we reconstruct the lambda functions here
self.patterns:list[tuple[UPat, Callable]] = [(p,types.FunctionType(*fxn) if isinstance(fxn, tuple) else fxn) for p,fxn in patterns]
self.patterns:list[tuple[UPat, Callable]] = [(p,fixup_pm_function(fxn)) for p,fxn in patterns]
# NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
self.pdict: dict[Ops, list[tuple[UPat, Callable, set]]] = {}
# uop is required, arg is optional