Add UOps pattern matcher regression tests (#4725)

* add pattern matcher regression tests

* Remove test for dtype str after rebasing

* Make test uops match type spec

* leave const const, add const alu vin test

* correct uops

* actually correct uops
This commit is contained in:
Alec Chen
2024-05-30 09:12:20 -05:00
committed by GitHub
parent c2945be0a3
commit e89bc42cc7

View File

@@ -1,6 +1,6 @@
import unittest
from tinygrad.dtype import dtypes
from tinygrad.ops import BinaryOps
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps
from tinygrad.codegen.uops import UOpGraph, UOps, PatternMatcher, UOp
class TestPatternMatcher(unittest.TestCase):
@@ -17,6 +17,57 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), None)
def test_uop(self):
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST}, lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.ALU, dtypes.float, (c1, c1), BinaryOps.ADD)
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), None)
def test_uop_set(self):
matcher = PatternMatcher([({"__name__": "x", "uop": {UOps.CONST, UOps.CAST}}, lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.bool, arg=False)
c2 = UOp(UOps.CAST, dtypes.int, (c1,))
c3 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c4 = UOp(UOps.ALU, dtypes.float, (c3, c3), BinaryOps.ADD)
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), c2)
self.assertEqual(matcher.rewrite(c4), None)
def test_arg(self):
matcher = PatternMatcher([
({"__name__": "x", "uop": UOps.CONST, "arg": 0}, lambda x: x),
({"__name__": "x", "uop": UOps.CONST, "arg": False}, lambda x: x),
({"__name__": "x", "uop": UOps.ALU, "arg": BinaryOps.MAX}, lambda x: x),
])
c1 = UOp(UOps.CONST, dtypes.float, arg=0.0)
c2 = UOp(UOps.CONST, dtypes.bool, arg=False)
c3 = UOp(UOps.ALU, dtypes.float, (c1, c1), arg=BinaryOps.MAX)
c4 = UOp(UOps.ALU, dtypes.float, (c1, c1), arg=BinaryOps.MUL)
c5 = UOp(UOps.CONST, dtypes.int, arg=-1)
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), c2)
self.assertEqual(matcher.rewrite(c3), c3)
self.assertEqual(matcher.rewrite(c4), None)
self.assertEqual(matcher.rewrite(c5), None)
def test_dup_name(self):
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "vin": ({"uop": UOps.CONST, "__name__": "y"}, {"__name__": "y"})},
lambda x, y: x)])
y1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
y2 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c1 = UOp(UOps.ALU, dtypes.float, (y1, y1), BinaryOps.ADD)
c2 = UOp(UOps.ALU, dtypes.float, (y1, y2), BinaryOps.ADD)
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), None)
def test_dtype(self):
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.float32}, lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float64, arg=1.0)
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), None)
def test_dtype_set(self):
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST, "dtype": set([dtypes.float32, dtypes.float64])}, lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
@@ -35,6 +86,12 @@ class TestPatternMatcher(unittest.TestCase):
c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
self.assertEqual(matcher.rewrite(c3), c3)
self.assertEqual(matcher.rewrite(c2), None)
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "vin":({"uop": UOps.CONST}, {"uop": UOps.ALU})}, lambda x: x)])
c4 = UOp(UOps.ALU, dtypes.float, (c1,c3), BinaryOps.ADD)
c5 = UOp(UOps.ALU, dtypes.float, (c3,c1), BinaryOps.ADD)
self.assertEqual(matcher.rewrite(c3), None)
self.assertEqual(matcher.rewrite(c4), c4)
self.assertEqual(matcher.rewrite(c5), None)
def test_vin_permutations(self):
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "vin":[{"uop": UOps.CONST}, {"uop": UOps.ALU}]}, lambda x: x)])
@@ -58,6 +115,18 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c3), c3)
self.assertEqual(matcher.rewrite(c4), None)
def test_allow_len(self):
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "vin": ({"uop": UOps.CONST},), "__allow_len__": {3}}, lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
c3 = UOp(UOps.CONST, dtypes.float, arg=3.0)
c5 = UOp(UOps.ALU, dtypes.float, (c1,), UnaryOps.NEG)
c6 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
c7 = UOp(UOps.ALU, dtypes.float, (c1,c2,c3), TernaryOps.MULACC)
self.assertEqual(matcher.rewrite(c5), c5)
self.assertEqual(matcher.rewrite(c6), None)
self.assertEqual(matcher.rewrite(c7), c7)
@unittest.skip("no longer supported")
def test_rewrite_graph_folds(self):
uops = UOpGraph()