Revert "s/UPat/Pat (#7506)" [pr] (#7517)

* Revert "s/UPat/Pat (#7506)"

This reverts commit 400011a8c1.

* fix
This commit is contained in:
chenyu
2024-11-03 16:33:02 -05:00
committed by GitHub
parent e641bbc859
commit 7758f7211b
12 changed files with 347 additions and 347 deletions

View File

@@ -2,7 +2,7 @@ import unittest, pickle, types
import numpy as np
from tinygrad import Tensor, TinyJit, Variable, dtypes
from tinygrad.engine.schedule import create_schedule
from tinygrad.ops import PatternMatcher, Pat, UOp
from tinygrad.ops import PatternMatcher, UPat, UOp
class TestPickle(unittest.TestCase):
def test_pickle_code_object(self):
@@ -12,7 +12,7 @@ class TestPickle(unittest.TestCase):
self.assertEqual(fxn(2), 4)
def test_pickle_pattern_matcher(self):
pm = PatternMatcher([(Pat.cvar('x'), lambda x: x*2)])
pm = PatternMatcher([(UPat.cvar('x'), lambda x: x*2)])
sink = UOp.const(dtypes.int, 2)
tt = pm.rewrite(sink)
pm_str = pickle.dumps(pm)

View File

@@ -3,7 +3,7 @@ import unittest, time
from tinygrad import dtypes, Device
from tinygrad.helpers import DEBUG
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, Ops, UOp, KernelInfo
from tinygrad.ops import Pat, PatternMatcher
from tinygrad.ops import UPat, PatternMatcher
from tinygrad.renderer import Renderer
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, sym
@@ -11,10 +11,10 @@ from tinygrad.codegen.linearize import linearize_uop
from tinygrad.shape.shapetracker import ShapeTracker, View
simple_pm = PatternMatcher([
(Pat.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
(Pat.cvar('x') + Pat.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)),
(Pat.cvar('x') * Pat.cvar('y') * Pat.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)),
((Pat.var('x') + Pat.cvar('c1')) + Pat.cvar('c2'), lambda x,c1,c2: x + (c1.arg+c2.arg)),
(UPat.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
(UPat.cvar('x') + UPat.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)),
(UPat.cvar('x') * UPat.cvar('y') * UPat.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)),
((UPat.var('x') + UPat.cvar('c1')) + UPat.cvar('c2'), lambda x,c1,c2: x + (c1.arg+c2.arg)),
])
def to_uops_list(u:List[UOp]) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u)))

View File

@@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.helpers import CI, DEBUG, getenv, Context
from tinygrad.dtype import dtypes, DType
from tinygrad.device import Buffer, Device
from tinygrad.ops import Ops, UOp, Pat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
from tinygrad.ops import Ops, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
from tinygrad.renderer import Program
from tinygrad.engine.schedule import create_schedule, to_si
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
@@ -441,13 +441,13 @@ class TestIndexingOrdering(unittest.TestCase):
stores = [st for st in uops if st.op is Ops.STORE]
assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"
class TestPatHelpers(unittest.TestCase):
class TestUPatHelpers(unittest.TestCase):
def test_location(self):
self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "uopgraph.py")
self.assertEqual(to_si.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "schedule.py")
self.assertEqual(spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "ops.py")
with self.assertRaises(AssertionError): # TODO: location Pat files created in test/*?
test_upat = Pat(Ops.CONST, dtypes.bool)
with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*?
test_upat = UPat(Ops.CONST, dtypes.bool)
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])
if __name__ == '__main__':

View File

@@ -1,7 +1,7 @@
from typing import Dict, List, Optional
import unittest
from tinygrad.dtype import dtypes
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, Ops, Pat, \
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, \
graph_rewrite, contexts, track_rewrites
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json
@@ -25,7 +25,7 @@ class TestViz(unittest.TestCase):
def test_viz_simple(self):
pm = PatternMatcher([
(Pat.var("x")*1, lambda x:x),
(UPat.var("x")*1, lambda x:x),
])
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
uops = helper_test_viz(a*1, pm)
@@ -34,8 +34,8 @@ class TestViz(unittest.TestCase):
def test_rewrite_twice(self):
pm = PatternMatcher([
(Pat.var("x")+Pat.var("x"), lambda x:x*2),
(Pat.var("x", dtypes.int)*2, lambda x:x.alu(BinaryOps.SHL, UOp.const(dtypes.int, 1))),
(UPat.var("x")+UPat.var("x"), lambda x:x*2),
(UPat.var("x", dtypes.int)*2, lambda x:x.alu(BinaryOps.SHL, UOp.const(dtypes.int, 1))),
])
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
uops = helper_test_viz(a+a, pm)
@@ -51,14 +51,14 @@ class TestViz(unittest.TestCase):
ctx[x] = None
return UOp.store(*x.src, x)
pm = PatternMatcher([
(Pat(Ops.LOAD, name="x"), store_load),
(UPat(Ops.LOAD, name="x"), store_load),
])
uops = helper_test_viz(a+b, pm, {})
self.assertEqual(len(uops), 2)
self.assertEqual(uops[-1], graph_rewrite(a+b, pm, {}))
def test_track_rewrites(self):
simple = PatternMatcher([(Pat.var("x")*1, lambda x:x)])
simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)])
@track_rewrites(named=True)
def do_rewrite(x:UOp): return graph_rewrite(x, simple)
ld = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
@@ -74,7 +74,7 @@ class TestViz(unittest.TestCase):
self.assertEqual(len(m.upats), 0)
def test_track_rewrites_with_exception(self):
simple = PatternMatcher([(Pat.var("x")*1, lambda x:x)])
simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)])
@track_rewrites()
def do_rewrite(x:UOp):
x = graph_rewrite(x, simple) # NOTE: viz tracks this

View File

@@ -1,11 +1,11 @@
import unittest, itertools
from tinygrad.dtype import dtypes
from tinygrad.ops import Ops, UOp, BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
from tinygrad.ops import PatternMatcher, Pat
from tinygrad.ops import PatternMatcher, UPat
class TestPatternMatcher(unittest.TestCase):
def test_simple_match(self):
matcher = PatternMatcher([(Pat(Ops.CONST, name="x", dtype=dtypes.float), lambda x: x)])
matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype=dtypes.float), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.int, arg=1)
self.assertEqual(matcher.rewrite(c1), c1)
@@ -16,7 +16,7 @@ class TestPatternMatcher(unittest.TestCase):
#print(x,y,z)
if y is not None: return a+y
matcher = PatternMatcher([
(Pat.var("a")+Pat.any(Pat.var("x"), Pat.var("y"), Pat.var("z")), test),
(UPat.var("a")+UPat.any(UPat.var("x"), UPat.var("y"), UPat.var("z")), test),
])
v1 = UOp.variable("a", 0, 10)
v2 = UOp.variable("b", 0, 10)
@@ -31,7 +31,7 @@ class TestPatternMatcher(unittest.TestCase):
match_cnt += 1
assert len(x.src) == 0
return UOp(Ops.CONST, src=(UOp(Ops.CONST),))
matcher = PatternMatcher([(Pat(Ops.CONST, src=(), name="x"), fxn)])
matcher = PatternMatcher([(UPat(Ops.CONST, src=(), name="x"), fxn)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
# second rewrite shouldn't match anything
c1 = matcher.rewrite(c1)
@@ -43,7 +43,7 @@ class TestPatternMatcher(unittest.TestCase):
ctx.append(True)
assert len(x.src) == 0
return UOp(Ops.CONST, src=(UOp(Ops.CONST),))
matcher = PatternMatcher([(Pat(Ops.CONST, src=(), name="x"), fxn)])
matcher = PatternMatcher([(UPat(Ops.CONST, src=(), name="x"), fxn)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
# second rewrite shouldn't match anything
ctx = []
@@ -52,14 +52,14 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(len(ctx), 1)
def test_uop(self):
matcher = PatternMatcher([(Pat(Ops.CONST, name="x"), lambda x: x)])
matcher = PatternMatcher([(UPat(Ops.CONST, name="x"), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.ALU, dtypes.float, (c1, c1), BinaryOps.ADD)
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), None)
def test_uop_set(self):
matcher = PatternMatcher([(Pat({Ops.CONST, Ops.CAST}, name="x"), lambda x: x)])
matcher = PatternMatcher([(UPat({Ops.CONST, Ops.CAST}, name="x"), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.bool, arg=False)
c2 = UOp(Ops.CAST, dtypes.int, (c1,))
c3 = UOp(Ops.CONST, dtypes.float, arg=1.0)
@@ -70,9 +70,9 @@ class TestPatternMatcher(unittest.TestCase):
def test_arg(self):
matcher = PatternMatcher([
(Pat(Ops.CONST, arg=0, name="x"), lambda x: x),
(Pat(Ops.CONST, arg=False, name="x"), lambda x: x),
(Pat(Ops.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x),
(UPat(Ops.CONST, arg=0, name="x"), lambda x: x),
(UPat(Ops.CONST, arg=False, name="x"), lambda x: x),
(UPat(Ops.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x),
])
c1 = UOp(Ops.CONST, dtypes.float, arg=0.0)
c2 = UOp(Ops.CONST, dtypes.bool, arg=False)
@@ -87,7 +87,7 @@ class TestPatternMatcher(unittest.TestCase):
def test_filter_arg(self):
matcher = PatternMatcher([
(Pat(Ops.ALU, arg=BinaryOps.MUL, src=[Pat(Ops.CONST, name="c"), Pat(Ops.CONST, arg=2)], name="x"),
(UPat(Ops.ALU, arg=BinaryOps.MUL, src=[UPat(Ops.CONST, name="c"), UPat(Ops.CONST, arg=2)], name="x"),
lambda x,c: x if c.arg in {1, -1} else None)
])
y1 = UOp(Ops.CONST, dtypes.int, arg=1)
@@ -105,7 +105,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c5), c5)
def test_dup_name(self):
matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=(Pat(Ops.CONST, name="y"), Pat(Ops.CONST, name="y"))), lambda x, y: x)])
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST, name="y"), UPat(Ops.CONST, name="y"))), lambda x, y: x)])
y1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
y2 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c1 = UOp(Ops.ALU, dtypes.float, (y1, y1), BinaryOps.ADD)
@@ -114,14 +114,14 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c2), c1)
def test_dtype(self):
matcher = PatternMatcher([(Pat(Ops.CONST, name="x", dtype=dtypes.float32), lambda x: x)])
matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype=dtypes.float32), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float64, arg=1.0)
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), None)
def test_dtype_set(self):
matcher = PatternMatcher([(Pat(Ops.CONST, name="x", dtype={dtypes.float32, dtypes.float64}), lambda x: x)])
matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype={dtypes.float32, dtypes.float64}), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float64, arg=1.0)
c3 = UOp(Ops.CONST, dtypes.float16, arg=1.0)
@@ -132,7 +132,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c4), None)
def test_src_one(self):
matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=(Pat(Ops.CONST), Pat(Ops.CONST))), lambda x: x)])
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST), UPat(Ops.CONST))), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
@@ -140,7 +140,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c2), None)
# that CONST/ALU -> ALU/CONST rewrite is now instant
"""
matcher = PatternMatcher([(Pat(UOps.ALU, name="x", src=(Pat(UOps.CONST), Pat(UOps.ALU))), lambda x: x)])
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.ALU))), lambda x: x)])
c4 = UOp(UOps.ALU, dtypes.float, (c1,c3), BinaryOps.ADD)
c5 = UOp(UOps.ALU, dtypes.float, (c3,c1), BinaryOps.ADD)
self.assertEqual(matcher.rewrite(c3), None)
@@ -149,7 +149,7 @@ class TestPatternMatcher(unittest.TestCase):
"""
def test_src_permutations(self):
matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=[Pat(Ops.CONST), Pat(Ops.ALU)]), lambda x: x)])
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=[UPat(Ops.CONST), UPat(Ops.ALU)]), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
@@ -162,7 +162,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c6), None)
def test_src_repeat(self):
matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=Pat(Ops.CONST)), lambda x: x)])
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=UPat(Ops.CONST)), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
@@ -171,7 +171,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c4), None)
def test_allow_len(self):
matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=(Pat(Ops.CONST),), allow_any_len=True, arg=TernaryOps.MULACC), lambda x: x)])
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST),), allow_any_len=True, arg=TernaryOps.MULACC), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
c3 = UOp(Ops.CONST, dtypes.float, arg=3.0)
@@ -188,16 +188,16 @@ class TestPatternMatcher(unittest.TestCase):
u1 = (c1 + c2) + c1
u2 = (c2 + c1) + c1
matcher = PatternMatcher([
(Pat(Ops.ALU, src=[Pat(Ops.ALU, src=[Pat(name='a'), Pat(name='b')]), Pat(name='b')]), lambda a,b: b)
(UPat(Ops.ALU, src=[UPat(Ops.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b')]), lambda a,b: b)
])
self.assertIsNotNone(matcher.rewrite(u1))
self.assertIsNotNone(matcher.rewrite(u2))
def _assert_eq_upat(self, a:Pat, b:Pat):
def _assert_eq_upat(self, a:UPat, b:UPat):
assert (sorted(map(str,a.op)) if a.op else [] == (sorted(map(str,b.op)) if b.op else []))
assert (sorted(a.dtype) if a.dtype else [] == (sorted(b.dtype) if b.dtype else []))
assert (a.name, type(a.src)) == (b.name, type(b.src))
def simple_src(u:Pat):
def simple_src(u:UPat):
if u.src is None: return []
if isinstance(u.src, itertools.repeat): return next(u.src[0])
return u.src[0]