mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
write tests for algebraic UPat
This commit is contained in:
@@ -211,5 +211,39 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
return u.src[0]
|
||||
for a,b in zip(simple_src(a), simple_src(b)): self._assert_eq_upat(a, b)
|
||||
|
||||
class TestAlgebraic(unittest.TestCase):
|
||||
def test_plus_0(self):
|
||||
pm = PatternMatcher([
|
||||
(UPat.var("x") + 0, UPat.var("x")), # x+0 -> x
|
||||
])
|
||||
expr = UOp.const(dtypes.int, 4)+0
|
||||
print(expr)
|
||||
self.assertEqual(pm.rewrite(expr), UOp.const(dtypes.int, 4))
|
||||
|
||||
def test_div_mul(self):
|
||||
pm = PatternMatcher([
|
||||
((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), UPat.var("x")), # (x*x2)/x2 -> x
|
||||
])
|
||||
expr = UOp.const(dtypes.float, 4)/2*2
|
||||
print(expr)
|
||||
self.assertEqual(pm.rewrite(expr), UOp.const(dtypes.int, 4))
|
||||
|
||||
def test_mul_is_and(self):
|
||||
pm = PatternMatcher([
|
||||
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), UPat.var('x') & UPat.var('y')),
|
||||
])
|
||||
expr = UOp.const(dtypes.bool, True)*UOp.const(dtypes.bool, True)
|
||||
print(expr)
|
||||
self.assertEqual(pm.rewrite(expr), UOp.const(dtypes.bool, True)&UOp.const(dtypes.bool, True))
|
||||
|
||||
def test_div_neg_1(self):
|
||||
pm = PatternMatcher([
|
||||
(UPat.var("x") // -1, UPat.var("x") * -1), # x//-1 -> x * -1
|
||||
])
|
||||
expr = UOp.const(dtypes.float, 4)//-1
|
||||
print(expr)
|
||||
self.assertEqual(pm.rewrite(expr), UOp.const(dtypes.int, 4) * -1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -860,11 +860,18 @@ def upat_interpret(p:UPat, fxn:Callable) -> Callable:
|
||||
return None
|
||||
return universal_match
|
||||
|
||||
def fixup_pm_function(fxn) -> Callable:
|
||||
if isinstance(fxn, UPat):
|
||||
# TODO: write this
|
||||
raise NotImplementedError("rhs UPat is not supported")
|
||||
if isinstance(fxn, tuple): return types.FunctionType(*fxn)
|
||||
return fxn
|
||||
|
||||
class PatternMatcher:
|
||||
def __init__(self, patterns:Sequence[tuple[UPat, Callable|tuple]], compiled=bool(getenv("UPAT_COMPILE", 1))):
|
||||
if compiled: from tinygrad.uop.upat import upat_compile
|
||||
# if this comes from a pickle, we reconstruct the lambda functions here
|
||||
self.patterns:list[tuple[UPat, Callable]] = [(p,types.FunctionType(*fxn) if isinstance(fxn, tuple) else fxn) for p,fxn in patterns]
|
||||
self.patterns:list[tuple[UPat, Callable]] = [(p,fixup_pm_function(fxn)) for p,fxn in patterns]
|
||||
# NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
|
||||
self.pdict: dict[Ops, list[tuple[UPat, Callable, set]]] = {}
|
||||
# uop is required, arg is optional
|
||||
|
||||
Reference in New Issue
Block a user