mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
Add UOps pattern matcher regression tests (#4725)
* add pattern matcher regression tests * Remove test for dtype str after rebasing * Make test uops match type spec * leave const const, add const alu vin test * correct uops * actually correct uops
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import unittest
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import BinaryOps
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps
|
||||
from tinygrad.codegen.uops import UOpGraph, UOps, PatternMatcher, UOp
|
||||
|
||||
class TestPatternMatcher(unittest.TestCase):
|
||||
@@ -17,6 +17,57 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
|
||||
def test_uop(self):
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST}, lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.ALU, dtypes.float, (c1, c1), BinaryOps.ADD)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
|
||||
def test_uop_set(self):
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": {UOps.CONST, UOps.CAST}}, lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.bool, arg=False)
|
||||
c2 = UOp(UOps.CAST, dtypes.int, (c1,))
|
||||
c3 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c4 = UOp(UOps.ALU, dtypes.float, (c3, c3), BinaryOps.ADD)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), c2)
|
||||
self.assertEqual(matcher.rewrite(c4), None)
|
||||
|
||||
def test_arg(self):
|
||||
matcher = PatternMatcher([
|
||||
({"__name__": "x", "uop": UOps.CONST, "arg": 0}, lambda x: x),
|
||||
({"__name__": "x", "uop": UOps.CONST, "arg": False}, lambda x: x),
|
||||
({"__name__": "x", "uop": UOps.ALU, "arg": BinaryOps.MAX}, lambda x: x),
|
||||
])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=0.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.bool, arg=False)
|
||||
c3 = UOp(UOps.ALU, dtypes.float, (c1, c1), arg=BinaryOps.MAX)
|
||||
c4 = UOp(UOps.ALU, dtypes.float, (c1, c1), arg=BinaryOps.MUL)
|
||||
c5 = UOp(UOps.CONST, dtypes.int, arg=-1)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), c2)
|
||||
self.assertEqual(matcher.rewrite(c3), c3)
|
||||
self.assertEqual(matcher.rewrite(c4), None)
|
||||
self.assertEqual(matcher.rewrite(c5), None)
|
||||
|
||||
def test_dup_name(self):
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "vin": ({"uop": UOps.CONST, "__name__": "y"}, {"__name__": "y"})},
|
||||
lambda x, y: x)])
|
||||
y1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
y2 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c1 = UOp(UOps.ALU, dtypes.float, (y1, y1), BinaryOps.ADD)
|
||||
c2 = UOp(UOps.ALU, dtypes.float, (y1, y2), BinaryOps.ADD)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
|
||||
def test_dtype(self):
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.float32}, lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float64, arg=1.0)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
|
||||
def test_dtype_set(self):
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST, "dtype": set([dtypes.float32, dtypes.float64])}, lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
@@ -35,6 +86,12 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
|
||||
self.assertEqual(matcher.rewrite(c3), c3)
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "vin":({"uop": UOps.CONST}, {"uop": UOps.ALU})}, lambda x: x)])
|
||||
c4 = UOp(UOps.ALU, dtypes.float, (c1,c3), BinaryOps.ADD)
|
||||
c5 = UOp(UOps.ALU, dtypes.float, (c3,c1), BinaryOps.ADD)
|
||||
self.assertEqual(matcher.rewrite(c3), None)
|
||||
self.assertEqual(matcher.rewrite(c4), c4)
|
||||
self.assertEqual(matcher.rewrite(c5), None)
|
||||
|
||||
def test_vin_permutations(self):
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "vin":[{"uop": UOps.CONST}, {"uop": UOps.ALU}]}, lambda x: x)])
|
||||
@@ -58,6 +115,18 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c3), c3)
|
||||
self.assertEqual(matcher.rewrite(c4), None)
|
||||
|
||||
def test_allow_len(self):
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "vin": ({"uop": UOps.CONST},), "__allow_len__": {3}}, lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
||||
c3 = UOp(UOps.CONST, dtypes.float, arg=3.0)
|
||||
c5 = UOp(UOps.ALU, dtypes.float, (c1,), UnaryOps.NEG)
|
||||
c6 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
|
||||
c7 = UOp(UOps.ALU, dtypes.float, (c1,c2,c3), TernaryOps.MULACC)
|
||||
self.assertEqual(matcher.rewrite(c5), c5)
|
||||
self.assertEqual(matcher.rewrite(c6), None)
|
||||
self.assertEqual(matcher.rewrite(c7), c7)
|
||||
|
||||
@unittest.skip("no longer supported")
|
||||
def test_rewrite_graph_folds(self):
|
||||
uops = UOpGraph()
|
||||
|
||||
Reference in New Issue
Block a user