mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
* Revert "s/UPat/Pat (#7506)"
This reverts commit 400011a8c1.
* fix
This commit is contained in:
@@ -2,7 +2,7 @@ import unittest, pickle, types
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, TinyJit, Variable, dtypes
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.ops import PatternMatcher, Pat, UOp
|
||||
from tinygrad.ops import PatternMatcher, UPat, UOp
|
||||
|
||||
class TestPickle(unittest.TestCase):
|
||||
def test_pickle_code_object(self):
|
||||
@@ -12,7 +12,7 @@ class TestPickle(unittest.TestCase):
|
||||
self.assertEqual(fxn(2), 4)
|
||||
|
||||
def test_pickle_pattern_matcher(self):
|
||||
pm = PatternMatcher([(Pat.cvar('x'), lambda x: x*2)])
|
||||
pm = PatternMatcher([(UPat.cvar('x'), lambda x: x*2)])
|
||||
sink = UOp.const(dtypes.int, 2)
|
||||
tt = pm.rewrite(sink)
|
||||
pm_str = pickle.dumps(pm)
|
||||
|
||||
@@ -3,7 +3,7 @@ import unittest, time
|
||||
from tinygrad import dtypes, Device
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, Ops, UOp, KernelInfo
|
||||
from tinygrad.ops import Pat, PatternMatcher
|
||||
from tinygrad.ops import UPat, PatternMatcher
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, sym
|
||||
@@ -11,10 +11,10 @@ from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
|
||||
simple_pm = PatternMatcher([
|
||||
(Pat.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
|
||||
(Pat.cvar('x') + Pat.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)),
|
||||
(Pat.cvar('x') * Pat.cvar('y') * Pat.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)),
|
||||
((Pat.var('x') + Pat.cvar('c1')) + Pat.cvar('c2'), lambda x,c1,c2: x + (c1.arg+c2.arg)),
|
||||
(UPat.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
|
||||
(UPat.cvar('x') + UPat.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)),
|
||||
(UPat.cvar('x') * UPat.cvar('y') * UPat.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)),
|
||||
((UPat.var('x') + UPat.cvar('c1')) + UPat.cvar('c2'), lambda x,c1,c2: x + (c1.arg+c2.arg)),
|
||||
])
|
||||
|
||||
def to_uops_list(u:List[UOp]) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u)))
|
||||
|
||||
@@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.helpers import CI, DEBUG, getenv, Context
|
||||
from tinygrad.dtype import dtypes, DType
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.ops import Ops, UOp, Pat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
|
||||
from tinygrad.ops import Ops, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
|
||||
from tinygrad.renderer import Program
|
||||
from tinygrad.engine.schedule import create_schedule, to_si
|
||||
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
|
||||
@@ -441,13 +441,13 @@ class TestIndexingOrdering(unittest.TestCase):
|
||||
stores = [st for st in uops if st.op is Ops.STORE]
|
||||
assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"
|
||||
|
||||
class TestPatHelpers(unittest.TestCase):
|
||||
class TestUPatHelpers(unittest.TestCase):
|
||||
def test_location(self):
|
||||
self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "uopgraph.py")
|
||||
self.assertEqual(to_si.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "schedule.py")
|
||||
self.assertEqual(spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "ops.py")
|
||||
with self.assertRaises(AssertionError): # TODO: location Pat files created in test/*?
|
||||
test_upat = Pat(Ops.CONST, dtypes.bool)
|
||||
with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*?
|
||||
test_upat = UPat(Ops.CONST, dtypes.bool)
|
||||
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Dict, List, Optional
|
||||
import unittest
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, Ops, Pat, \
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, \
|
||||
graph_rewrite, contexts, track_rewrites
|
||||
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json
|
||||
|
||||
@@ -25,7 +25,7 @@ class TestViz(unittest.TestCase):
|
||||
|
||||
def test_viz_simple(self):
|
||||
pm = PatternMatcher([
|
||||
(Pat.var("x")*1, lambda x:x),
|
||||
(UPat.var("x")*1, lambda x:x),
|
||||
])
|
||||
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
|
||||
uops = helper_test_viz(a*1, pm)
|
||||
@@ -34,8 +34,8 @@ class TestViz(unittest.TestCase):
|
||||
|
||||
def test_rewrite_twice(self):
|
||||
pm = PatternMatcher([
|
||||
(Pat.var("x")+Pat.var("x"), lambda x:x*2),
|
||||
(Pat.var("x", dtypes.int)*2, lambda x:x.alu(BinaryOps.SHL, UOp.const(dtypes.int, 1))),
|
||||
(UPat.var("x")+UPat.var("x"), lambda x:x*2),
|
||||
(UPat.var("x", dtypes.int)*2, lambda x:x.alu(BinaryOps.SHL, UOp.const(dtypes.int, 1))),
|
||||
])
|
||||
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
|
||||
uops = helper_test_viz(a+a, pm)
|
||||
@@ -51,14 +51,14 @@ class TestViz(unittest.TestCase):
|
||||
ctx[x] = None
|
||||
return UOp.store(*x.src, x)
|
||||
pm = PatternMatcher([
|
||||
(Pat(Ops.LOAD, name="x"), store_load),
|
||||
(UPat(Ops.LOAD, name="x"), store_load),
|
||||
])
|
||||
uops = helper_test_viz(a+b, pm, {})
|
||||
self.assertEqual(len(uops), 2)
|
||||
self.assertEqual(uops[-1], graph_rewrite(a+b, pm, {}))
|
||||
|
||||
def test_track_rewrites(self):
|
||||
simple = PatternMatcher([(Pat.var("x")*1, lambda x:x)])
|
||||
simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)])
|
||||
@track_rewrites(named=True)
|
||||
def do_rewrite(x:UOp): return graph_rewrite(x, simple)
|
||||
ld = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
|
||||
@@ -74,7 +74,7 @@ class TestViz(unittest.TestCase):
|
||||
self.assertEqual(len(m.upats), 0)
|
||||
|
||||
def test_track_rewrites_with_exception(self):
|
||||
simple = PatternMatcher([(Pat.var("x")*1, lambda x:x)])
|
||||
simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)])
|
||||
@track_rewrites()
|
||||
def do_rewrite(x:UOp):
|
||||
x = graph_rewrite(x, simple) # NOTE: viz tracks this
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import unittest, itertools
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import Ops, UOp, BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
|
||||
from tinygrad.ops import PatternMatcher, Pat
|
||||
from tinygrad.ops import PatternMatcher, UPat
|
||||
|
||||
class TestPatternMatcher(unittest.TestCase):
|
||||
def test_simple_match(self):
|
||||
matcher = PatternMatcher([(Pat(Ops.CONST, name="x", dtype=dtypes.float), lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype=dtypes.float), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.int, arg=1)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
@@ -16,7 +16,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
#print(x,y,z)
|
||||
if y is not None: return a+y
|
||||
matcher = PatternMatcher([
|
||||
(Pat.var("a")+Pat.any(Pat.var("x"), Pat.var("y"), Pat.var("z")), test),
|
||||
(UPat.var("a")+UPat.any(UPat.var("x"), UPat.var("y"), UPat.var("z")), test),
|
||||
])
|
||||
v1 = UOp.variable("a", 0, 10)
|
||||
v2 = UOp.variable("b", 0, 10)
|
||||
@@ -31,7 +31,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
match_cnt += 1
|
||||
assert len(x.src) == 0
|
||||
return UOp(Ops.CONST, src=(UOp(Ops.CONST),))
|
||||
matcher = PatternMatcher([(Pat(Ops.CONST, src=(), name="x"), fxn)])
|
||||
matcher = PatternMatcher([(UPat(Ops.CONST, src=(), name="x"), fxn)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
# second rewrite shouldn't match anything
|
||||
c1 = matcher.rewrite(c1)
|
||||
@@ -43,7 +43,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
ctx.append(True)
|
||||
assert len(x.src) == 0
|
||||
return UOp(Ops.CONST, src=(UOp(Ops.CONST),))
|
||||
matcher = PatternMatcher([(Pat(Ops.CONST, src=(), name="x"), fxn)])
|
||||
matcher = PatternMatcher([(UPat(Ops.CONST, src=(), name="x"), fxn)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
# second rewrite shouldn't match anything
|
||||
ctx = []
|
||||
@@ -52,14 +52,14 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(len(ctx), 1)
|
||||
|
||||
def test_uop(self):
|
||||
matcher = PatternMatcher([(Pat(Ops.CONST, name="x"), lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(Ops.CONST, name="x"), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.ALU, dtypes.float, (c1, c1), BinaryOps.ADD)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
|
||||
def test_uop_set(self):
|
||||
matcher = PatternMatcher([(Pat({Ops.CONST, Ops.CAST}, name="x"), lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat({Ops.CONST, Ops.CAST}, name="x"), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.bool, arg=False)
|
||||
c2 = UOp(Ops.CAST, dtypes.int, (c1,))
|
||||
c3 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
@@ -70,9 +70,9 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
|
||||
def test_arg(self):
|
||||
matcher = PatternMatcher([
|
||||
(Pat(Ops.CONST, arg=0, name="x"), lambda x: x),
|
||||
(Pat(Ops.CONST, arg=False, name="x"), lambda x: x),
|
||||
(Pat(Ops.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x),
|
||||
(UPat(Ops.CONST, arg=0, name="x"), lambda x: x),
|
||||
(UPat(Ops.CONST, arg=False, name="x"), lambda x: x),
|
||||
(UPat(Ops.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x),
|
||||
])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=0.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.bool, arg=False)
|
||||
@@ -87,7 +87,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
|
||||
def test_filter_arg(self):
|
||||
matcher = PatternMatcher([
|
||||
(Pat(Ops.ALU, arg=BinaryOps.MUL, src=[Pat(Ops.CONST, name="c"), Pat(Ops.CONST, arg=2)], name="x"),
|
||||
(UPat(Ops.ALU, arg=BinaryOps.MUL, src=[UPat(Ops.CONST, name="c"), UPat(Ops.CONST, arg=2)], name="x"),
|
||||
lambda x,c: x if c.arg in {1, -1} else None)
|
||||
])
|
||||
y1 = UOp(Ops.CONST, dtypes.int, arg=1)
|
||||
@@ -105,7 +105,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c5), c5)
|
||||
|
||||
def test_dup_name(self):
|
||||
matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=(Pat(Ops.CONST, name="y"), Pat(Ops.CONST, name="y"))), lambda x, y: x)])
|
||||
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST, name="y"), UPat(Ops.CONST, name="y"))), lambda x, y: x)])
|
||||
y1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
y2 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c1 = UOp(Ops.ALU, dtypes.float, (y1, y1), BinaryOps.ADD)
|
||||
@@ -114,14 +114,14 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c2), c1)
|
||||
|
||||
def test_dtype(self):
|
||||
matcher = PatternMatcher([(Pat(Ops.CONST, name="x", dtype=dtypes.float32), lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype=dtypes.float32), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.float64, arg=1.0)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
|
||||
def test_dtype_set(self):
|
||||
matcher = PatternMatcher([(Pat(Ops.CONST, name="x", dtype={dtypes.float32, dtypes.float64}), lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype={dtypes.float32, dtypes.float64}), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.float64, arg=1.0)
|
||||
c3 = UOp(Ops.CONST, dtypes.float16, arg=1.0)
|
||||
@@ -132,7 +132,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c4), None)
|
||||
|
||||
def test_src_one(self):
|
||||
matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=(Pat(Ops.CONST), Pat(Ops.CONST))), lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST), UPat(Ops.CONST))), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
|
||||
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
|
||||
@@ -140,7 +140,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
# that CONST/ALU -> ALU/CONST rewrite is now instant
|
||||
"""
|
||||
matcher = PatternMatcher([(Pat(UOps.ALU, name="x", src=(Pat(UOps.CONST), Pat(UOps.ALU))), lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.ALU))), lambda x: x)])
|
||||
c4 = UOp(UOps.ALU, dtypes.float, (c1,c3), BinaryOps.ADD)
|
||||
c5 = UOp(UOps.ALU, dtypes.float, (c3,c1), BinaryOps.ADD)
|
||||
self.assertEqual(matcher.rewrite(c3), None)
|
||||
@@ -149,7 +149,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
"""
|
||||
|
||||
def test_src_permutations(self):
|
||||
matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=[Pat(Ops.CONST), Pat(Ops.ALU)]), lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=[UPat(Ops.CONST), UPat(Ops.ALU)]), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
|
||||
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
|
||||
@@ -162,7 +162,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c6), None)
|
||||
|
||||
def test_src_repeat(self):
|
||||
matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=Pat(Ops.CONST)), lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=UPat(Ops.CONST)), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
|
||||
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
|
||||
@@ -171,7 +171,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c4), None)
|
||||
|
||||
def test_allow_len(self):
|
||||
matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=(Pat(Ops.CONST),), allow_any_len=True, arg=TernaryOps.MULACC), lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST),), allow_any_len=True, arg=TernaryOps.MULACC), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
|
||||
c3 = UOp(Ops.CONST, dtypes.float, arg=3.0)
|
||||
@@ -188,16 +188,16 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
u1 = (c1 + c2) + c1
|
||||
u2 = (c2 + c1) + c1
|
||||
matcher = PatternMatcher([
|
||||
(Pat(Ops.ALU, src=[Pat(Ops.ALU, src=[Pat(name='a'), Pat(name='b')]), Pat(name='b')]), lambda a,b: b)
|
||||
(UPat(Ops.ALU, src=[UPat(Ops.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b')]), lambda a,b: b)
|
||||
])
|
||||
self.assertIsNotNone(matcher.rewrite(u1))
|
||||
self.assertIsNotNone(matcher.rewrite(u2))
|
||||
|
||||
def _assert_eq_upat(self, a:Pat, b:Pat):
|
||||
def _assert_eq_upat(self, a:UPat, b:UPat):
|
||||
assert (sorted(map(str,a.op)) if a.op else [] == (sorted(map(str,b.op)) if b.op else []))
|
||||
assert (sorted(a.dtype) if a.dtype else [] == (sorted(b.dtype) if b.dtype else []))
|
||||
assert (a.name, type(a.src)) == (b.name, type(b.src))
|
||||
def simple_src(u:Pat):
|
||||
def simple_src(u:UPat):
|
||||
if u.src is None: return []
|
||||
if isinstance(u.src, itertools.repeat): return next(u.src[0])
|
||||
return u.src[0]
|
||||
|
||||
Reference in New Issue
Block a user