mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Construct UOps patterns using UPat (#4821)
* Allow UPat pattern definitions * Convert pattern matcher tests to UPat constructions * Convert constant_folder patterns to upat constructions * Convert assembly patterns to upat constructions * [run_process_replay] Drop UPat.from_dict
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import unittest
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps
|
||||
from tinygrad.codegen.uops import UOpGraph, UOps, PatternMatcher, UOp
|
||||
from tinygrad.codegen.uops import UOpGraph, UOps, PatternMatcher, UOp, UPat
|
||||
|
||||
class TestPatternMatcher(unittest.TestCase):
|
||||
def assert_equiv_uops(self, uop1:UOp, uop2:UOp):
|
||||
@@ -11,21 +11,21 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(uop1.arg, uop2.arg)
|
||||
|
||||
def test_simple_match(self):
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.float}, lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float), lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.int, arg=1)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
|
||||
def test_uop(self):
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST}, lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(UOps.CONST, name="x"), lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.ALU, dtypes.float, (c1, c1), BinaryOps.ADD)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
|
||||
def test_uop_set(self):
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": {UOps.CONST, UOps.CAST}}, lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat({UOps.CONST, UOps.CAST}, name="x"), lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.bool, arg=False)
|
||||
c2 = UOp(UOps.CAST, dtypes.int, (c1,))
|
||||
c3 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
@@ -36,9 +36,9 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
|
||||
def test_arg(self):
|
||||
matcher = PatternMatcher([
|
||||
({"__name__": "x", "uop": UOps.CONST, "arg": 0}, lambda x: x),
|
||||
({"__name__": "x", "uop": UOps.CONST, "arg": False}, lambda x: x),
|
||||
({"__name__": "x", "uop": UOps.ALU, "arg": BinaryOps.MAX}, lambda x: x),
|
||||
(UPat(UOps.CONST, 0, name="x"), lambda x: x),
|
||||
(UPat(UOps.CONST, False, name="x"), lambda x: x),
|
||||
(UPat(UOps.ALU, BinaryOps.MAX, name="x"), lambda x: x),
|
||||
])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=0.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.bool, arg=False)
|
||||
@@ -52,8 +52,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c5), None)
|
||||
|
||||
def test_arg_set(self):
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "arg": BinaryOps.MUL,
|
||||
"vin": ({"uop": UOps.CONST, "arg": {-1, 1}}, {"uop": UOps.CONST, "arg": 2})}, lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(UOps.ALU, BinaryOps.MUL, (UPat(UOps.CONST, {-1, 1}), UPat(UOps.CONST, 2)), name="x"), lambda x: x)])
|
||||
y1 = UOp(UOps.CONST, dtypes.int, arg=1)
|
||||
y2 = UOp(UOps.CONST, dtypes.int, arg=2)
|
||||
y3 = UOp(UOps.CONST, dtypes.int, arg=-1)
|
||||
@@ -65,8 +64,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c3), c3)
|
||||
|
||||
def test_dup_name(self):
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "vin": ({"uop": UOps.CONST, "__name__": "y"}, {"__name__": "y"})},
|
||||
lambda x, y: x)])
|
||||
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=(UPat(UOps.CONST, name="y"), UPat(UOps.CONST, name="y"))), lambda x, y: x)])
|
||||
y1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
y2 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c1 = UOp(UOps.ALU, dtypes.float, (y1, y1), BinaryOps.ADD)
|
||||
@@ -75,14 +73,14 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
|
||||
def test_dtype(self):
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.float32}, lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float32), lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float64, arg=1.0)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
|
||||
def test_dtype_set(self):
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST, "dtype": set([dtypes.float32, dtypes.float64])}, lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=set([dtypes.float32, dtypes.float64])), lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float64, arg=1.0)
|
||||
c3 = UOp(UOps.CONST, dtypes.float16, arg=1.0)
|
||||
@@ -93,13 +91,13 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c4), None)
|
||||
|
||||
def test_vin_one(self):
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "vin":({"uop": UOps.CONST}, {"uop": UOps.CONST})}, lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=(UPat(UOps.CONST), UPat(UOps.CONST))), lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
||||
c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
|
||||
self.assertEqual(matcher.rewrite(c3), c3)
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "vin":({"uop": UOps.CONST}, {"uop": UOps.ALU})}, lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=(UPat(UOps.CONST), UPat(UOps.ALU))), lambda x: x)])
|
||||
c4 = UOp(UOps.ALU, dtypes.float, (c1,c3), BinaryOps.ADD)
|
||||
c5 = UOp(UOps.ALU, dtypes.float, (c3,c1), BinaryOps.ADD)
|
||||
self.assertEqual(matcher.rewrite(c3), None)
|
||||
@@ -107,7 +105,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c5), None)
|
||||
|
||||
def test_vin_permutations(self):
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "vin":[{"uop": UOps.CONST}, {"uop": UOps.ALU}]}, lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=[UPat(UOps.CONST), UPat(UOps.ALU)]), lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
||||
c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
|
||||
@@ -120,7 +118,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c6), None)
|
||||
|
||||
def test_vin_repeat(self):
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "vin":{"uop": UOps.CONST}}, lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=UPat(UOps.CONST)), lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
||||
c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
|
||||
@@ -129,7 +127,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c4), None)
|
||||
|
||||
def test_allow_len(self):
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "vin": ({"uop": UOps.CONST},), "__allow_len__": {3}}, lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=(UPat(UOps.CONST),), allow_len={3}), lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
||||
c3 = UOp(UOps.CONST, dtypes.float, arg=3.0)
|
||||
@@ -144,8 +142,8 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
def test_rewrite_graph_folds(self):
|
||||
uops = UOpGraph()
|
||||
uops.add(UOps.CONST, dtypes.float, arg=2.0, simplify=False)
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.float},
|
||||
lambda x: UOp(UOps.CAST, dtypes.int, (UOp(UOps.ALU, x.dtype, (x, x), BinaryOps.ADD),)))])
|
||||
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float),
|
||||
lambda x: UOp(UOps.CAST, dtypes.int, (UOp(UOps.ALU, x.dtype, (x, x), BinaryOps.ADD),)))])
|
||||
matcher.rewrite_graph(uops)
|
||||
# TODO: fix this. it's 2 now
|
||||
# self.assertEqual(len(uops.uops), 1)
|
||||
@@ -156,7 +154,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
def test_rewrite_graph_adds(self):
|
||||
uops = UOpGraph()
|
||||
uops.add(UOps.CONST, dtypes.int, arg=2, simplify=False)
|
||||
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.int},
|
||||
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.int),
|
||||
lambda x: UOp(UOps.STORE, x.dtype, (UOp(UOps.DEFINE_GLOBAL, x.dtype, tuple(), None), x)))])
|
||||
matcher.rewrite_graph(uops)
|
||||
uops.remove_childless(set(x for x in uops if x.uop in {UOps.STORE}))
|
||||
|
||||
@@ -77,21 +77,6 @@ class UPat:
|
||||
dtype: Optional[Union[DType, Set[DType]]] = None
|
||||
allow_len: Set[int] = field(default_factory=set)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, pat:Dict[str, Any]) -> UPat:
|
||||
name, uop, dtype = pat.get("__name__"), pat.get("uop"), pat.get("dtype")
|
||||
assert isinstance(name, str) or name is None
|
||||
assert isinstance(uop, (UOps, set)) or uop is None
|
||||
assert isinstance(dtype, (DType, set)) or dtype is None
|
||||
vin = pat.get("vin")
|
||||
if isinstance(vin, list): vin = [UPat.from_dict(x) for x in vin]
|
||||
elif isinstance(vin, tuple): vin = tuple(UPat.from_dict(x) for x in vin)
|
||||
elif isinstance(vin, dict): vin = UPat.from_dict(vin)
|
||||
else: assert vin is None
|
||||
arg = pat.get("arg")
|
||||
allow_len = pat.get("__allow_len__", set())
|
||||
return cls(uop, arg, vin, name, dtype, allow_len)
|
||||
|
||||
def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
|
||||
if pat.name in store and store[pat.name] != uop: return False
|
||||
if pat.name is not None: store[pat.name] = uop
|
||||
@@ -120,12 +105,11 @@ def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
|
||||
return False
|
||||
|
||||
class PatternMatcher:
|
||||
def __init__(self, patterns:List[Tuple[Dict[str, Any], Callable]]):
|
||||
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
|
||||
self.patterns = patterns
|
||||
self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable]]] = defaultdict(list)
|
||||
# uop is required, arg is optional
|
||||
for pd,fxn in self.patterns:
|
||||
p = UPat.from_dict(pd)
|
||||
for p,fxn in self.patterns:
|
||||
assert p.uop is not None
|
||||
if isinstance(p.uop, set):
|
||||
for uop in p.uop: self.pdict[(uop, p.arg)].append((p, fxn))
|
||||
@@ -157,100 +141,100 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst):
|
||||
# this is symbolic 2.0
|
||||
constant_folder = PatternMatcher([
|
||||
# arange loop folding (early)
|
||||
({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": (
|
||||
{"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin":
|
||||
[{"__name__": "idx"}, {"uop": UOps.ALU, "arg": BinaryOps.MUL,
|
||||
"vin": [{"__name__": "mval", "uop": UOps.CONST}, {"uop": UOps.RANGE, "vin": ({"__name__": "loop_start"}, {"__name__": "loop_end"})}]}]},
|
||||
{"__name__": "compval", "uop": UOps.CONST})}, {"__name__": "multconst", "uop": UOps.CONST}, {"uop": UOps.CONST, "arg": 0})}, loop_collapse),
|
||||
(UPat(UOps.ALU, TernaryOps.WHERE, vin=(UPat(UOps.ALU, BinaryOps.CMPLT, vin=(
|
||||
UPat(UOps.ALU, BinaryOps.ADD, vin=
|
||||
[UPat(name="idx"), UPat(UOps.ALU, BinaryOps.MUL,
|
||||
vin=[UPat(UOps.CONST, name="mval"), UPat(UOps.RANGE, vin=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
|
||||
UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))), loop_collapse),
|
||||
# sum collapse to mul (with possible GEP)
|
||||
({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.DEFINE_ACC, "vin": ({"uop": UOps.RANGE, "__name__": "loop"},)},
|
||||
{"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse),
|
||||
({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.GEP,
|
||||
"vin": ({"uop": UOps.DEFINE_ACC, "vin":({"uop": UOps.RANGE, "__name__": "loop"},)},)},
|
||||
{"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse),
|
||||
(UPat(UOps.PHI, vin=(UPat(UOps.DEFINE_ACC, name="phi_input", vin=(UPat(UOps.RANGE, name="loop"),)),
|
||||
UPat(UOps.ALU, BinaryOps.ADD, vin=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
|
||||
(UPat(UOps.PHI, vin=(UPat(UOps.GEP, name="phi_input",
|
||||
vin=(UPat(UOps.DEFINE_ACC, vin=(UPat(UOps.RANGE, name="loop"),)),)),
|
||||
UPat(UOps.ALU, BinaryOps.ADD, vin=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
|
||||
# deal with UNMUL
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"uop": UOps.CONST, "__name__": "c1"},
|
||||
{"uop": UOps.UNMUL, "vin": [{"uop": UOps.CONST, "__name__": "c2"}, {"__name__": "v"}]}]},
|
||||
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c1"),
|
||||
UPat(UOps.UNMUL, vin=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]),
|
||||
lambda c1,c2,v: v if c1.arg == c2.arg else None),
|
||||
({"uop": UOps.UNMUL, "vin": ({"uop": UOps.CONST, "__name__": "zero", "arg": 0}, {})}, lambda zero: zero),
|
||||
({"__name__": "root", "uop": UOps.CAST, "vin": ({"uop": UOps.UNMUL, "__name__": "unmul"},)},
|
||||
(UPat(UOps.UNMUL, vin=(UPat(UOps.CONST, name="zero", arg=0), UPat())), lambda zero: zero),
|
||||
(UPat(UOps.CAST, name="root", vin=(UPat(UOps.UNMUL, name="unmul"),)),
|
||||
lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.vin[0].cast(root.dtype), unmul.vin[1]))),
|
||||
# max on special can go away (TODO: special should be variable, same thing applies)
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.MAX, "vin": [{"__name__": "c", "uop": UOps.CONST}, {"__name__": "s", "uop": UOps.SPECIAL}]},
|
||||
(UPat(UOps.ALU, BinaryOps.MAX, [UPat(UOps.CONST, name="c"), UPat(UOps.SPECIAL, name="s")]),
|
||||
lambda c,s: c if (s.arg[2]-1) <= c.arg else None),
|
||||
# const rules
|
||||
({"__name__": "root", "uop": UOps.GEP, "vin": ({"__name__": "c", "uop": UOps.CONST},)}, lambda root, c: UOp.const(root.dtype, c.arg)),
|
||||
({"__name__": "root", "uop": UOps.CAST, "vin": {"__name__": "c", "uop": UOps.CONST}}, lambda root, c: UOp.const(root.dtype, c.arg)),
|
||||
(UPat(UOps.GEP, name="root", vin=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)),
|
||||
(UPat(UOps.CAST, name="root", vin=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
|
||||
# a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
|
||||
({"uop": UOps.PHI, "vin": ({"uop": UOps.DEFINE_ACC, "__name__": "acc"}, {"__name__": "acc"})}, lambda acc: UOp.const(acc.dtype, acc.arg[0])),
|
||||
({"uop": UOps.PHI, "vin": ({"uop": UOps.DEFINE_ACC, "vin": tuple()}, {"__name__": "x"})}, lambda x: x),
|
||||
({"uop": UOps.PHI, "vin": ({"uop": UOps.CONST}, {"__name__": "x"})}, lambda x: x),
|
||||
(UPat(UOps.PHI, vin=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.const(acc.dtype, acc.arg[0])),
|
||||
(UPat(UOps.PHI, vin=(UPat(UOps.DEFINE_ACC, vin=tuple()), UPat(name="x"))), lambda x: x),
|
||||
(UPat(UOps.PHI, vin=(UPat(UOps.CONST), UPat(name="x"))), lambda x: x),
|
||||
# a DEFINE_ACC without inputs is a const + GEP on a const is the const
|
||||
({"__name__": "root", "uop": UOps.DEFINE_ACC, "vin": tuple()}, lambda root: UOp.const(root.dtype, root.arg[0])),
|
||||
({"__name__": "root", "uop": UOps.GEP, "vin": ({"__name__": "x", "uop": UOps.CONST},)}, lambda root,x: UOp.const(root.dtype, x.arg)),
|
||||
(UPat(UOps.DEFINE_ACC, name="root", vin=tuple()), lambda root: UOp.const(root.dtype, root.arg[0])),
|
||||
(UPat(UOps.GEP, name="root", vin=(UPat(UOps.CONST, name="x"),)), lambda root,x: UOp.const(root.dtype, x.arg)),
|
||||
# max -2147483648
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.MAX, "dtype": dtypes.int, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": -2147483648}]}, lambda x: x),
|
||||
(UPat(UOps.ALU, BinaryOps.MAX, dtype=dtypes.int, vin=[UPat(name="x"), UPat(UOps.CONST, -2147483648)]), lambda x: x),
|
||||
# -(-x) -> x
|
||||
({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"__name__": "x"},)})}, lambda x: x),
|
||||
(UPat(UOps.ALU, UnaryOps.NEG, (UPat(UOps.ALU, UnaryOps.NEG, (UPat(name="x"),)))), lambda x: x),
|
||||
# x+-y -> x-y
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "x"}, {"__name__": "my", "uop": UOps.ALU, "arg": UnaryOps.NEG})},
|
||||
(UPat(UOps.ALU, BinaryOps.ADD, (UPat(name="x"), UPat(UOps.ALU, UnaryOps.NEG, name="my"))),
|
||||
lambda x, my: x-my.vin[0]),
|
||||
# -1*x -> -x
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": -1}]}, lambda x: -x),
|
||||
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(name="x"), UPat(UOps.CONST, -1)]), lambda x: -x),
|
||||
# bool < False is always false, True < bool is always false
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({}, {"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": False})}, lambda x: x),
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": True}, {})},
|
||||
(UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(), UPat(UOps.CONST, False, name="x", dtype=dtypes.bool))), lambda x: x),
|
||||
(UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(UOps.CONST, True, name="x", dtype=dtypes.bool), UPat())),
|
||||
lambda x: UOp.const(dtypes.bool, False)),
|
||||
# a conditional with the same results either way is a noop, also fold const conditionals
|
||||
({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({}, {"__name__": "val"}, {"__name__": "val"})}, lambda val: val),
|
||||
({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({"__name__": "gate", "uop": UOps.CONST}, {"__name__": "c0"}, {"__name__": "c1"})},
|
||||
(UPat(UOps.ALU, TernaryOps.WHERE, (UPat(), UPat(name="val"), UPat(name="val"))), lambda val: val),
|
||||
(UPat(UOps.ALU, TernaryOps.WHERE, (UPat(UOps.CONST, name="gate"), UPat(name="c0"), UPat(name="c1"))),
|
||||
lambda gate, c0, c1: c0 if gate.arg else c1),
|
||||
# ** constant folding **
|
||||
({"__name__": "root", "uop": UOps.ALU, "vin": {"uop": UOps.CONST}},
|
||||
(UPat(UOps.ALU, name="root", vin=UPat(UOps.CONST)),
|
||||
lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.vin]))),
|
||||
# ** self folding **
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": 0}]}, lambda x: x), # x+0 -> x or 0+x -> x
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": 1}]}, lambda x: x), # x*1 -> x or 1*x -> x
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.SUB, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": 0})}, lambda x: x), # x-0 -> x
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.DIV, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": 1})}, lambda x: x), # x/1 -> x
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.DIV, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": -1})}, lambda x: -x), # x/-1 -> -x
|
||||
(UPat(UOps.ALU, BinaryOps.ADD, [UPat(name="x"), UPat(UOps.CONST, 0)]), lambda x: x), # x+0 -> x or 0+x -> x
|
||||
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(name="x"), UPat(UOps.CONST, 1)]), lambda x: x), # x*1 -> x or 1*x -> x
|
||||
(UPat(UOps.ALU, BinaryOps.SUB, (UPat(name="x"), UPat(UOps.CONST, 0))), lambda x: x), # x-0 -> x
|
||||
(UPat(UOps.ALU, BinaryOps.DIV, (UPat(name="x"), UPat(UOps.CONST, 1))), lambda x: x), # x/1 -> x
|
||||
(UPat(UOps.ALU, BinaryOps.DIV, (UPat(name="x"), UPat(UOps.CONST, -1))), lambda x: -x), # x/-1 -> -x
|
||||
# ** zero folding **
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{}, {"__name__": "c", "uop": UOps.CONST, "arg": 0}]}, lambda c: c), # x*0 -> 0 or 0*x -> 0
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.SUB, "vin": ({"__name__": "x"}, {"__name__": "x"})}, lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
|
||||
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(), UPat(UOps.CONST, 0, name="c")]), lambda c: c), # x*0 -> 0 or 0*x -> 0
|
||||
(UPat(UOps.ALU, BinaryOps.SUB, (UPat(name="x"), UPat(name="x"))), lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
|
||||
# ** load/store folding **
|
||||
({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"},
|
||||
{"uop": UOps.LOAD, "vin": ({"__name__": "buf"}, {"__name__": "idx"})})}, lambda buf, idx: UOp(UOps.NOOP)),
|
||||
(UPat(UOps.STORE, vin=(UPat(name="buf"), UPat(name="idx"),
|
||||
UPat(UOps.LOAD, vin=(UPat(name="buf"), UPat(name="idx"))))), lambda buf, idx: UOp(UOps.NOOP)),
|
||||
# ** two stage add/sub folding **
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"uop": UOps.ALU, "arg": BinaryOps.ADD,
|
||||
"vin": [{"__name__": "x"}, {"__name__": "c1", "uop": UOps.CONST}]}, {"__name__": "c2", "uop": UOps.CONST}]},
|
||||
(UPat(UOps.ALU, BinaryOps.ADD, [UPat(UOps.ALU, BinaryOps.ADD,
|
||||
[UPat(name="x"), UPat(UOps.CONST, name="c1")]), UPat(UOps.CONST, name="c2")]),
|
||||
lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"uop": UOps.ALU, "arg": BinaryOps.SUB,
|
||||
"vin": ({"__name__": "x"}, {"__name__": "c1", "uop": UOps.CONST})}, {"__name__": "c2", "uop": UOps.CONST}]},
|
||||
(UPat(UOps.ALU, BinaryOps.ADD, [UPat(UOps.ALU, BinaryOps.SUB,
|
||||
(UPat(name="x"), UPat(UOps.CONST, name="c1"))), UPat(UOps.CONST, name="c2")]),
|
||||
lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.SUB, x.dtype, [c2.arg, c1.arg]))),
|
||||
# TODO: can do the invert of this (flip alt/load) when we fix double ops
|
||||
({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.ALU, "arg": TernaryOps.WHERE,
|
||||
"vin": ({"__name__": "gate"}, {"__name__": "alt"}, {"uop": UOps.LOAD, "vin": ({"__name__": "buf"}, {"__name__": "idx"})})})},
|
||||
(UPat(UOps.STORE, vin=(UPat(name="buf"), UPat(name="idx"), UPat(UOps.ALU, TernaryOps.WHERE,
|
||||
(UPat(name="gate"), UPat(name="alt"), UPat(UOps.LOAD, vin=(UPat(name="buf"), UPat(name="idx"))))))),
|
||||
lambda buf, idx, gate, alt: UOp(UOps.STORE, None, (buf, idx, alt, gate))),
|
||||
# store float4/float2 directly (remove CAST/GEP)
|
||||
({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.CAST, "vin":
|
||||
tuple({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i} for i in range(4))})},
|
||||
(UPat(UOps.STORE, vin=(UPat(name="buf"), UPat(name="idx"), UPat(UOps.CAST, vin=
|
||||
tuple(UPat(UOps.GEP, i, vin=(UPat(name="val"),)) for i in range(4))))),
|
||||
lambda buf,idx,val: UOp(UOps.STORE, None, (buf, idx, val))),
|
||||
({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.CAST, "vin":
|
||||
tuple({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i} for i in range(2))})},
|
||||
(UPat(UOps.STORE, vin=(UPat(name="buf"), UPat(name="idx"), UPat(UOps.CAST, vin=
|
||||
tuple(UPat(UOps.GEP, i, vin=(UPat(name="val"),)) for i in range(2))))),
|
||||
lambda buf,idx,val: UOp(UOps.STORE, None, (buf, idx, val))),
|
||||
# CAST-PHI-GEP -> PHI-CAST
|
||||
({"__name__": "root", "uop": UOps.CAST, "vin":
|
||||
tuple({"uop": UOps.PHI, "vin": ({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i}, {"__name__": f"v{i}"})} for i in range(4))},
|
||||
(UPat(UOps.CAST, name="root", vin=
|
||||
tuple(UPat(UOps.PHI, vin=(UPat(UOps.GEP, i, vin=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
|
||||
lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1, v2, v3))))),
|
||||
({"__name__": "root", "uop": UOps.CAST, "vin":
|
||||
tuple({"uop": UOps.PHI, "vin": ({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i}, {"__name__": f"v{i}"})} for i in range(2))},
|
||||
(UPat(UOps.CAST, name="root", vin=
|
||||
tuple(UPat(UOps.PHI, vin=(UPat(UOps.GEP, i, vin=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
|
||||
lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1))))),
|
||||
# NEG/CMPLT -> CMPLT
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"__name__": "x"},)},
|
||||
{"__name__": "c", "uop": UOps.CONST, "dtype": dtypes.int})},
|
||||
(UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(UOps.ALU, UnaryOps.NEG, (UPat(name="x"),)),
|
||||
UPat(UOps.CONST, name="c", dtype=dtypes.int))),
|
||||
lambda c,x: UOp(UOps.ALU, dtypes.bool, (UOp.const(c.dtype, -c.arg), x), BinaryOps.CMPLT)),
|
||||
# cast NOOP (NOTE: it's str to deal with PtrDType)
|
||||
({"__name__": "root", "uop": UOps.CAST}, lambda root: root.vin[0] if str(root.dtype) == str(root.vin[0].dtype) else None),
|
||||
(UPat(UOps.CAST, name="root"), lambda root: root.vin[0] if str(root.dtype) == str(root.vin[0].dtype) else None),
|
||||
])
|
||||
|
||||
# *** uop graph ***
|
||||
|
||||
@@ -5,7 +5,7 @@ from tinygrad.helpers import DEBUG
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
|
||||
from tinygrad.codegen.uops import UOpGraph, PatternMatcher
|
||||
from tinygrad.codegen.uops import UOpGraph, PatternMatcher, UPat
|
||||
from tinygrad.renderer import Renderer, TensorCore
|
||||
|
||||
def render_val(x, dtype):
|
||||
@@ -229,46 +229,46 @@ class PTXRenderer(Renderer):
|
||||
return self.render_kernel(kernel, name, bufs, c.items())
|
||||
|
||||
ptx_matcher = PatternMatcher([
|
||||
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.MUL, "dtype": set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
|
||||
"vin": [{"__name__": "const", "uop": UOps.CONST, "arg": set([2**i for i in range(64)])}, {"__name__": "mul"}]},
|
||||
(UPat(UOps.ALU, BinaryOps.MUL, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
|
||||
vin=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="mul")]),
|
||||
lambda root, mul, const: UOp(UOps.ALU, root.dtype, (mul, UOp.const(root.dtype, int(math.log2(const.arg)))), BinaryOps.SHL)),
|
||||
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.DIV, "dtype": set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
|
||||
"vin": [{"__name__": "const", "uop": UOps.CONST, "arg": set([2**i for i in range(64)])}, {"__name__": "div"}]},
|
||||
(UPat(UOps.ALU, BinaryOps.DIV, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
|
||||
vin=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="div")]),
|
||||
lambda root, div, const: UOp(UOps.ALU, root.dtype, (div, UOp.const(root.dtype, int(math.log2(const.arg)))), BinaryOps.SHR)),
|
||||
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPNE, "vin": ({"dtype": dtypes.bool},{})},
|
||||
(UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"),
|
||||
lambda root: UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR)),
|
||||
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})},
|
||||
(UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(name="x", dtype=dtypes.bool),UPat(name="y")), "root"),
|
||||
lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
|
||||
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.ADD,
|
||||
"vin": [{"__name__": "non_muls"}, {"__name__": "muls", "uop": UOps.ALU, "arg": BinaryOps.MUL}]},
|
||||
(UPat(UOps.ALU, BinaryOps.ADD,
|
||||
[UPat(name="non_muls"), UPat(UOps.ALU, BinaryOps.MUL, name="muls")], "root"),
|
||||
lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.vin + (non_muls,), TernaryOps.MULACC)),
|
||||
*[({"__name__": "x", "uop": UOps.ALU, "dtype": dtypes.half, "arg": op},
|
||||
*[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"),
|
||||
lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),)))
|
||||
for op in PTXRenderer.asm_for_op.keys() if op not in PTXRenderer.supports_half],
|
||||
({"__name__": "root", "uop": UOps.LOAD, "dtype": dtypes.bool,
|
||||
"vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})},
|
||||
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool,
|
||||
vin=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))),
|
||||
lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
|
||||
({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool, "vin": ({},{})},
|
||||
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, vin=(UPat(),UPat())),
|
||||
lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))),
|
||||
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool}, {})},
|
||||
(UPat(UOps.STORE, name="root", vin=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool), UPat())),
|
||||
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
|
||||
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})},
|
||||
(UPat(UOps.STORE, name="root", vin=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool))),
|
||||
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
|
||||
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{},{"__name__": "g", "dtype": dtypes.int})},
|
||||
(UPat(UOps.STORE, name="root", vin=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))),
|
||||
lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)),
|
||||
# ptr_ar (load/store)
|
||||
({"__name__": "root", "uop": {UOps.LOAD, UOps.STORE}, "__allow_len__":[2,3,4,5], "vin": ({"uop":{UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}},
|
||||
{"uop": UOps.ALU, "arg": BinaryOps.ADD,"vin":[{"__name__": "alu"}, {"__name__": "const", "uop":UOps.CONST}]})},
|
||||
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, vin=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
|
||||
UPat(UOps.ALU, BinaryOps.ADD, vin=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
|
||||
lambda root, alu, const: UOp(root.uop, root.dtype,
|
||||
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.vin[0].dtype.itemsize)+root.vin[0].cast(dtypes.int64),
|
||||
UOp.const(const.dtype, root.vin[0].dtype.itemsize)*const)+root.vin[2:])),
|
||||
({"__name__": "root", "uop": {UOps.LOAD, UOps.STORE}, "__allow_len__":[2,3,4,5], "vin": ({"uop":{UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}},
|
||||
{"__name__": "const", "uop":UOps.CONST})},
|
||||
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, vin=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
|
||||
UPat(UOps.CONST, name="const"))),
|
||||
lambda root, const: UOp(root.uop, root.dtype, (root.vin[0].cast(dtypes.int64),
|
||||
UOp.const(dtypes.int64, const.arg * root.vin[0].dtype.itemsize),
|
||||
)+root.vin[2:])),
|
||||
({"__name__": "root", "uop": {UOps.LOAD, UOps.STORE}, "__allow_len__":[2,3,4,5], "vin": ({"uop":{UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}},
|
||||
{"__name__": "alu"})}, # no const here
|
||||
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, vin=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
|
||||
UPat(name="alu"))), # no const here
|
||||
lambda root, alu: UOp(root.uop, root.dtype,
|
||||
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.vin[0].dtype.itemsize)+root.vin[0].cast(dtypes.int64),
|
||||
UOp.const(dtypes.int64, 0))+root.vin[2:])),
|
||||
|
||||
Reference in New Issue
Block a user