Construct UOps patterns using UPat (#4821)

* Allow UPat pattern definitions

* Convert pattern matcher tests to UPat constructions

* Convert constant_folder patterns to upat constructions

* Convert assembly patterns to upat constructions

* [run_process_replay] Drop UPat.from_dict
This commit is contained in:
Alec Chen
2024-06-05 03:29:37 -05:00
committed by GitHub
parent e47277d18a
commit 5ac30c29d8
3 changed files with 100 additions and 118 deletions

View File

@@ -1,7 +1,7 @@
import unittest
from tinygrad.dtype import dtypes
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps
from tinygrad.codegen.uops import UOpGraph, UOps, PatternMatcher, UOp
from tinygrad.codegen.uops import UOpGraph, UOps, PatternMatcher, UOp, UPat
class TestPatternMatcher(unittest.TestCase):
def assert_equiv_uops(self, uop1:UOp, uop2:UOp):
@@ -11,21 +11,21 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(uop1.arg, uop2.arg)
def test_simple_match(self):
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.float}, lambda x: x)])
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.int, arg=1)
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), None)
def test_uop(self):
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST}, lambda x: x)])
matcher = PatternMatcher([(UPat(UOps.CONST, name="x"), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.ALU, dtypes.float, (c1, c1), BinaryOps.ADD)
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), None)
def test_uop_set(self):
matcher = PatternMatcher([({"__name__": "x", "uop": {UOps.CONST, UOps.CAST}}, lambda x: x)])
matcher = PatternMatcher([(UPat({UOps.CONST, UOps.CAST}, name="x"), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.bool, arg=False)
c2 = UOp(UOps.CAST, dtypes.int, (c1,))
c3 = UOp(UOps.CONST, dtypes.float, arg=1.0)
@@ -36,9 +36,9 @@ class TestPatternMatcher(unittest.TestCase):
def test_arg(self):
matcher = PatternMatcher([
({"__name__": "x", "uop": UOps.CONST, "arg": 0}, lambda x: x),
({"__name__": "x", "uop": UOps.CONST, "arg": False}, lambda x: x),
({"__name__": "x", "uop": UOps.ALU, "arg": BinaryOps.MAX}, lambda x: x),
(UPat(UOps.CONST, 0, name="x"), lambda x: x),
(UPat(UOps.CONST, False, name="x"), lambda x: x),
(UPat(UOps.ALU, BinaryOps.MAX, name="x"), lambda x: x),
])
c1 = UOp(UOps.CONST, dtypes.float, arg=0.0)
c2 = UOp(UOps.CONST, dtypes.bool, arg=False)
@@ -52,8 +52,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c5), None)
def test_arg_set(self):
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "arg": BinaryOps.MUL,
"vin": ({"uop": UOps.CONST, "arg": {-1, 1}}, {"uop": UOps.CONST, "arg": 2})}, lambda x: x)])
matcher = PatternMatcher([(UPat(UOps.ALU, BinaryOps.MUL, (UPat(UOps.CONST, {-1, 1}), UPat(UOps.CONST, 2)), name="x"), lambda x: x)])
y1 = UOp(UOps.CONST, dtypes.int, arg=1)
y2 = UOp(UOps.CONST, dtypes.int, arg=2)
y3 = UOp(UOps.CONST, dtypes.int, arg=-1)
@@ -65,8 +64,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c3), c3)
def test_dup_name(self):
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "vin": ({"uop": UOps.CONST, "__name__": "y"}, {"__name__": "y"})},
lambda x, y: x)])
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=(UPat(UOps.CONST, name="y"), UPat(UOps.CONST, name="y"))), lambda x, y: x)])
y1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
y2 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c1 = UOp(UOps.ALU, dtypes.float, (y1, y1), BinaryOps.ADD)
@@ -75,14 +73,14 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c2), None)
def test_dtype(self):
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.float32}, lambda x: x)])
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float32), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float64, arg=1.0)
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), None)
def test_dtype_set(self):
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST, "dtype": set([dtypes.float32, dtypes.float64])}, lambda x: x)])
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=set([dtypes.float32, dtypes.float64])), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float64, arg=1.0)
c3 = UOp(UOps.CONST, dtypes.float16, arg=1.0)
@@ -93,13 +91,13 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c4), None)
def test_vin_one(self):
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "vin":({"uop": UOps.CONST}, {"uop": UOps.CONST})}, lambda x: x)])
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=(UPat(UOps.CONST), UPat(UOps.CONST))), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
self.assertEqual(matcher.rewrite(c3), c3)
self.assertEqual(matcher.rewrite(c2), None)
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "vin":({"uop": UOps.CONST}, {"uop": UOps.ALU})}, lambda x: x)])
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=(UPat(UOps.CONST), UPat(UOps.ALU))), lambda x: x)])
c4 = UOp(UOps.ALU, dtypes.float, (c1,c3), BinaryOps.ADD)
c5 = UOp(UOps.ALU, dtypes.float, (c3,c1), BinaryOps.ADD)
self.assertEqual(matcher.rewrite(c3), None)
@@ -107,7 +105,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c5), None)
def test_vin_permutations(self):
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "vin":[{"uop": UOps.CONST}, {"uop": UOps.ALU}]}, lambda x: x)])
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=[UPat(UOps.CONST), UPat(UOps.ALU)]), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
@@ -120,7 +118,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c6), None)
def test_vin_repeat(self):
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "vin":{"uop": UOps.CONST}}, lambda x: x)])
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=UPat(UOps.CONST)), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
@@ -129,7 +127,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c4), None)
def test_allow_len(self):
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.ALU, "vin": ({"uop": UOps.CONST},), "__allow_len__": {3}}, lambda x: x)])
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=(UPat(UOps.CONST),), allow_len={3}), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
c3 = UOp(UOps.CONST, dtypes.float, arg=3.0)
@@ -144,8 +142,8 @@ class TestPatternMatcher(unittest.TestCase):
def test_rewrite_graph_folds(self):
uops = UOpGraph()
uops.add(UOps.CONST, dtypes.float, arg=2.0, simplify=False)
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.float},
lambda x: UOp(UOps.CAST, dtypes.int, (UOp(UOps.ALU, x.dtype, (x, x), BinaryOps.ADD),)))])
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float),
lambda x: UOp(UOps.CAST, dtypes.int, (UOp(UOps.ALU, x.dtype, (x, x), BinaryOps.ADD),)))])
matcher.rewrite_graph(uops)
# TODO: fix this. it's 2 now
# self.assertEqual(len(uops.uops), 1)
@@ -156,7 +154,7 @@ class TestPatternMatcher(unittest.TestCase):
def test_rewrite_graph_adds(self):
uops = UOpGraph()
uops.add(UOps.CONST, dtypes.int, arg=2, simplify=False)
matcher = PatternMatcher([({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.int},
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.int),
lambda x: UOp(UOps.STORE, x.dtype, (UOp(UOps.DEFINE_GLOBAL, x.dtype, tuple(), None), x)))])
matcher.rewrite_graph(uops)
uops.remove_childless(set(x for x in uops if x.uop in {UOps.STORE}))

View File

@@ -77,21 +77,6 @@ class UPat:
dtype: Optional[Union[DType, Set[DType]]] = None
allow_len: Set[int] = field(default_factory=set)
@classmethod
def from_dict(cls, pat:Dict[str, Any]) -> UPat:
name, uop, dtype = pat.get("__name__"), pat.get("uop"), pat.get("dtype")
assert isinstance(name, str) or name is None
assert isinstance(uop, (UOps, set)) or uop is None
assert isinstance(dtype, (DType, set)) or dtype is None
vin = pat.get("vin")
if isinstance(vin, list): vin = [UPat.from_dict(x) for x in vin]
elif isinstance(vin, tuple): vin = tuple(UPat.from_dict(x) for x in vin)
elif isinstance(vin, dict): vin = UPat.from_dict(vin)
else: assert vin is None
arg = pat.get("arg")
allow_len = pat.get("__allow_len__", set())
return cls(uop, arg, vin, name, dtype, allow_len)
def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
if pat.name in store and store[pat.name] != uop: return False
if pat.name is not None: store[pat.name] = uop
@@ -120,12 +105,11 @@ def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
return False
class PatternMatcher:
def __init__(self, patterns:List[Tuple[Dict[str, Any], Callable]]):
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
self.patterns = patterns
self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable]]] = defaultdict(list)
# uop is required, arg is optional
for pd,fxn in self.patterns:
p = UPat.from_dict(pd)
for p,fxn in self.patterns:
assert p.uop is not None
if isinstance(p.uop, set):
for uop in p.uop: self.pdict[(uop, p.arg)].append((p, fxn))
@@ -157,100 +141,100 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst):
# this is symbolic 2.0
constant_folder = PatternMatcher([
# arange loop folding (early)
({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": (
{"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin":
[{"__name__": "idx"}, {"uop": UOps.ALU, "arg": BinaryOps.MUL,
"vin": [{"__name__": "mval", "uop": UOps.CONST}, {"uop": UOps.RANGE, "vin": ({"__name__": "loop_start"}, {"__name__": "loop_end"})}]}]},
{"__name__": "compval", "uop": UOps.CONST})}, {"__name__": "multconst", "uop": UOps.CONST}, {"uop": UOps.CONST, "arg": 0})}, loop_collapse),
(UPat(UOps.ALU, TernaryOps.WHERE, vin=(UPat(UOps.ALU, BinaryOps.CMPLT, vin=(
UPat(UOps.ALU, BinaryOps.ADD, vin=
[UPat(name="idx"), UPat(UOps.ALU, BinaryOps.MUL,
vin=[UPat(UOps.CONST, name="mval"), UPat(UOps.RANGE, vin=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))), loop_collapse),
# sum collapse to mul (with possible GEP)
({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.DEFINE_ACC, "vin": ({"uop": UOps.RANGE, "__name__": "loop"},)},
{"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse),
({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.GEP,
"vin": ({"uop": UOps.DEFINE_ACC, "vin":({"uop": UOps.RANGE, "__name__": "loop"},)},)},
{"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse),
(UPat(UOps.PHI, vin=(UPat(UOps.DEFINE_ACC, name="phi_input", vin=(UPat(UOps.RANGE, name="loop"),)),
UPat(UOps.ALU, BinaryOps.ADD, vin=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
(UPat(UOps.PHI, vin=(UPat(UOps.GEP, name="phi_input",
vin=(UPat(UOps.DEFINE_ACC, vin=(UPat(UOps.RANGE, name="loop"),)),)),
UPat(UOps.ALU, BinaryOps.ADD, vin=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
# deal with UNMUL
({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"uop": UOps.CONST, "__name__": "c1"},
{"uop": UOps.UNMUL, "vin": [{"uop": UOps.CONST, "__name__": "c2"}, {"__name__": "v"}]}]},
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c1"),
UPat(UOps.UNMUL, vin=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]),
lambda c1,c2,v: v if c1.arg == c2.arg else None),
({"uop": UOps.UNMUL, "vin": ({"uop": UOps.CONST, "__name__": "zero", "arg": 0}, {})}, lambda zero: zero),
({"__name__": "root", "uop": UOps.CAST, "vin": ({"uop": UOps.UNMUL, "__name__": "unmul"},)},
(UPat(UOps.UNMUL, vin=(UPat(UOps.CONST, name="zero", arg=0), UPat())), lambda zero: zero),
(UPat(UOps.CAST, name="root", vin=(UPat(UOps.UNMUL, name="unmul"),)),
lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.vin[0].cast(root.dtype), unmul.vin[1]))),
# max on special can go away (TODO: special should be variable, same thing applies)
({"uop": UOps.ALU, "arg": BinaryOps.MAX, "vin": [{"__name__": "c", "uop": UOps.CONST}, {"__name__": "s", "uop": UOps.SPECIAL}]},
(UPat(UOps.ALU, BinaryOps.MAX, [UPat(UOps.CONST, name="c"), UPat(UOps.SPECIAL, name="s")]),
lambda c,s: c if (s.arg[2]-1) <= c.arg else None),
# const rules
({"__name__": "root", "uop": UOps.GEP, "vin": ({"__name__": "c", "uop": UOps.CONST},)}, lambda root, c: UOp.const(root.dtype, c.arg)),
({"__name__": "root", "uop": UOps.CAST, "vin": {"__name__": "c", "uop": UOps.CONST}}, lambda root, c: UOp.const(root.dtype, c.arg)),
(UPat(UOps.GEP, name="root", vin=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)),
(UPat(UOps.CAST, name="root", vin=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
# a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
({"uop": UOps.PHI, "vin": ({"uop": UOps.DEFINE_ACC, "__name__": "acc"}, {"__name__": "acc"})}, lambda acc: UOp.const(acc.dtype, acc.arg[0])),
({"uop": UOps.PHI, "vin": ({"uop": UOps.DEFINE_ACC, "vin": tuple()}, {"__name__": "x"})}, lambda x: x),
({"uop": UOps.PHI, "vin": ({"uop": UOps.CONST}, {"__name__": "x"})}, lambda x: x),
(UPat(UOps.PHI, vin=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.const(acc.dtype, acc.arg[0])),
(UPat(UOps.PHI, vin=(UPat(UOps.DEFINE_ACC, vin=tuple()), UPat(name="x"))), lambda x: x),
(UPat(UOps.PHI, vin=(UPat(UOps.CONST), UPat(name="x"))), lambda x: x),
# a DEFINE_ACC without inputs is a const, and a GEP on a const is the const
({"__name__": "root", "uop": UOps.DEFINE_ACC, "vin": tuple()}, lambda root: UOp.const(root.dtype, root.arg[0])),
({"__name__": "root", "uop": UOps.GEP, "vin": ({"__name__": "x", "uop": UOps.CONST},)}, lambda root,x: UOp.const(root.dtype, x.arg)),
(UPat(UOps.DEFINE_ACC, name="root", vin=tuple()), lambda root: UOp.const(root.dtype, root.arg[0])),
(UPat(UOps.GEP, name="root", vin=(UPat(UOps.CONST, name="x"),)), lambda root,x: UOp.const(root.dtype, x.arg)),
# max -2147483648
({"uop": UOps.ALU, "arg": BinaryOps.MAX, "dtype": dtypes.int, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": -2147483648}]}, lambda x: x),
(UPat(UOps.ALU, BinaryOps.MAX, dtype=dtypes.int, vin=[UPat(name="x"), UPat(UOps.CONST, -2147483648)]), lambda x: x),
# -(-x) -> x
({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"__name__": "x"},)})}, lambda x: x),
(UPat(UOps.ALU, UnaryOps.NEG, (UPat(UOps.ALU, UnaryOps.NEG, (UPat(name="x"),)))), lambda x: x),
# x+-y -> x-y
({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "x"}, {"__name__": "my", "uop": UOps.ALU, "arg": UnaryOps.NEG})},
(UPat(UOps.ALU, BinaryOps.ADD, (UPat(name="x"), UPat(UOps.ALU, UnaryOps.NEG, name="my"))),
lambda x, my: x-my.vin[0]),
# -1*x -> -x
({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": -1}]}, lambda x: -x),
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(name="x"), UPat(UOps.CONST, -1)]), lambda x: -x),
# bool < False is always false, True < bool is always false
({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({}, {"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": False})}, lambda x: x),
({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": True}, {})},
(UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(), UPat(UOps.CONST, False, name="x", dtype=dtypes.bool))), lambda x: x),
(UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(UOps.CONST, True, name="x", dtype=dtypes.bool), UPat())),
lambda x: UOp.const(dtypes.bool, False)),
# a conditional with the same results either way is a noop, also fold const conditionals
({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({}, {"__name__": "val"}, {"__name__": "val"})}, lambda val: val),
({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({"__name__": "gate", "uop": UOps.CONST}, {"__name__": "c0"}, {"__name__": "c1"})},
(UPat(UOps.ALU, TernaryOps.WHERE, (UPat(), UPat(name="val"), UPat(name="val"))), lambda val: val),
(UPat(UOps.ALU, TernaryOps.WHERE, (UPat(UOps.CONST, name="gate"), UPat(name="c0"), UPat(name="c1"))),
lambda gate, c0, c1: c0 if gate.arg else c1),
# ** constant folding **
({"__name__": "root", "uop": UOps.ALU, "vin": {"uop": UOps.CONST}},
(UPat(UOps.ALU, name="root", vin=UPat(UOps.CONST)),
lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.vin]))),
# ** self folding **
({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": 0}]}, lambda x: x), # x+0 -> x or 0+x -> x
({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": 1}]}, lambda x: x), # x*1 -> x or 1*x -> x
({"uop": UOps.ALU, "arg": BinaryOps.SUB, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": 0})}, lambda x: x), # x-0 -> x
({"uop": UOps.ALU, "arg": BinaryOps.DIV, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": 1})}, lambda x: x), # x/1 -> x
({"uop": UOps.ALU, "arg": BinaryOps.DIV, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": -1})}, lambda x: -x), # x/-1 -> -x
(UPat(UOps.ALU, BinaryOps.ADD, [UPat(name="x"), UPat(UOps.CONST, 0)]), lambda x: x), # x+0 -> x or 0+x -> x
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(name="x"), UPat(UOps.CONST, 1)]), lambda x: x), # x*1 -> x or 1*x -> x
(UPat(UOps.ALU, BinaryOps.SUB, (UPat(name="x"), UPat(UOps.CONST, 0))), lambda x: x), # x-0 -> x
(UPat(UOps.ALU, BinaryOps.DIV, (UPat(name="x"), UPat(UOps.CONST, 1))), lambda x: x), # x/1 -> x
(UPat(UOps.ALU, BinaryOps.DIV, (UPat(name="x"), UPat(UOps.CONST, -1))), lambda x: -x), # x/-1 -> -x
# ** zero folding **
({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{}, {"__name__": "c", "uop": UOps.CONST, "arg": 0}]}, lambda c: c), # x*0 -> 0 or 0*x -> 0
({"uop": UOps.ALU, "arg": BinaryOps.SUB, "vin": ({"__name__": "x"}, {"__name__": "x"})}, lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(), UPat(UOps.CONST, 0, name="c")]), lambda c: c), # x*0 -> 0 or 0*x -> 0
(UPat(UOps.ALU, BinaryOps.SUB, (UPat(name="x"), UPat(name="x"))), lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
# ** load/store folding **
({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"},
{"uop": UOps.LOAD, "vin": ({"__name__": "buf"}, {"__name__": "idx"})})}, lambda buf, idx: UOp(UOps.NOOP)),
(UPat(UOps.STORE, vin=(UPat(name="buf"), UPat(name="idx"),
UPat(UOps.LOAD, vin=(UPat(name="buf"), UPat(name="idx"))))), lambda buf, idx: UOp(UOps.NOOP)),
# ** two stage add/sub folding **
({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"uop": UOps.ALU, "arg": BinaryOps.ADD,
"vin": [{"__name__": "x"}, {"__name__": "c1", "uop": UOps.CONST}]}, {"__name__": "c2", "uop": UOps.CONST}]},
(UPat(UOps.ALU, BinaryOps.ADD, [UPat(UOps.ALU, BinaryOps.ADD,
[UPat(name="x"), UPat(UOps.CONST, name="c1")]), UPat(UOps.CONST, name="c2")]),
lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"uop": UOps.ALU, "arg": BinaryOps.SUB,
"vin": ({"__name__": "x"}, {"__name__": "c1", "uop": UOps.CONST})}, {"__name__": "c2", "uop": UOps.CONST}]},
(UPat(UOps.ALU, BinaryOps.ADD, [UPat(UOps.ALU, BinaryOps.SUB,
(UPat(name="x"), UPat(UOps.CONST, name="c1"))), UPat(UOps.CONST, name="c2")]),
lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.SUB, x.dtype, [c2.arg, c1.arg]))),
# TODO: can do the inverse of this (flip alt/load) when we fix double ops
({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.ALU, "arg": TernaryOps.WHERE,
"vin": ({"__name__": "gate"}, {"__name__": "alt"}, {"uop": UOps.LOAD, "vin": ({"__name__": "buf"}, {"__name__": "idx"})})})},
(UPat(UOps.STORE, vin=(UPat(name="buf"), UPat(name="idx"), UPat(UOps.ALU, TernaryOps.WHERE,
(UPat(name="gate"), UPat(name="alt"), UPat(UOps.LOAD, vin=(UPat(name="buf"), UPat(name="idx"))))))),
lambda buf, idx, gate, alt: UOp(UOps.STORE, None, (buf, idx, alt, gate))),
# store float4/float2 directly (remove CAST/GEP)
({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.CAST, "vin":
tuple({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i} for i in range(4))})},
(UPat(UOps.STORE, vin=(UPat(name="buf"), UPat(name="idx"), UPat(UOps.CAST, vin=
tuple(UPat(UOps.GEP, i, vin=(UPat(name="val"),)) for i in range(4))))),
lambda buf,idx,val: UOp(UOps.STORE, None, (buf, idx, val))),
({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.CAST, "vin":
tuple({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i} for i in range(2))})},
(UPat(UOps.STORE, vin=(UPat(name="buf"), UPat(name="idx"), UPat(UOps.CAST, vin=
tuple(UPat(UOps.GEP, i, vin=(UPat(name="val"),)) for i in range(2))))),
lambda buf,idx,val: UOp(UOps.STORE, None, (buf, idx, val))),
# CAST-PHI-GEP -> PHI-CAST
({"__name__": "root", "uop": UOps.CAST, "vin":
tuple({"uop": UOps.PHI, "vin": ({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i}, {"__name__": f"v{i}"})} for i in range(4))},
(UPat(UOps.CAST, name="root", vin=
tuple(UPat(UOps.PHI, vin=(UPat(UOps.GEP, i, vin=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1, v2, v3))))),
({"__name__": "root", "uop": UOps.CAST, "vin":
tuple({"uop": UOps.PHI, "vin": ({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i}, {"__name__": f"v{i}"})} for i in range(2))},
(UPat(UOps.CAST, name="root", vin=
tuple(UPat(UOps.PHI, vin=(UPat(UOps.GEP, i, vin=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1))))),
# NEG/CMPLT -> CMPLT
({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"__name__": "x"},)},
{"__name__": "c", "uop": UOps.CONST, "dtype": dtypes.int})},
(UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(UOps.ALU, UnaryOps.NEG, (UPat(name="x"),)),
UPat(UOps.CONST, name="c", dtype=dtypes.int))),
lambda c,x: UOp(UOps.ALU, dtypes.bool, (UOp.const(c.dtype, -c.arg), x), BinaryOps.CMPLT)),
# cast NOOP (NOTE: it's str to deal with PtrDType)
({"__name__": "root", "uop": UOps.CAST}, lambda root: root.vin[0] if str(root.dtype) == str(root.vin[0].dtype) else None),
(UPat(UOps.CAST, name="root"), lambda root: root.vin[0] if str(root.dtype) == str(root.vin[0].dtype) else None),
])
# *** uop graph ***

View File

@@ -5,7 +5,7 @@ from tinygrad.helpers import DEBUG
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
from tinygrad.codegen.uops import UOpGraph, PatternMatcher
from tinygrad.codegen.uops import UOpGraph, PatternMatcher, UPat
from tinygrad.renderer import Renderer, TensorCore
def render_val(x, dtype):
@@ -229,46 +229,46 @@ class PTXRenderer(Renderer):
return self.render_kernel(kernel, name, bufs, c.items())
ptx_matcher = PatternMatcher([
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.MUL, "dtype": set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
"vin": [{"__name__": "const", "uop": UOps.CONST, "arg": set([2**i for i in range(64)])}, {"__name__": "mul"}]},
(UPat(UOps.ALU, BinaryOps.MUL, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
vin=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="mul")]),
lambda root, mul, const: UOp(UOps.ALU, root.dtype, (mul, UOp.const(root.dtype, int(math.log2(const.arg)))), BinaryOps.SHL)),
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.DIV, "dtype": set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
"vin": [{"__name__": "const", "uop": UOps.CONST, "arg": set([2**i for i in range(64)])}, {"__name__": "div"}]},
(UPat(UOps.ALU, BinaryOps.DIV, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
vin=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="div")]),
lambda root, div, const: UOp(UOps.ALU, root.dtype, (div, UOp.const(root.dtype, int(math.log2(const.arg)))), BinaryOps.SHR)),
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPNE, "vin": ({"dtype": dtypes.bool},{})},
(UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"),
lambda root: UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR)),
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})},
(UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(name="x", dtype=dtypes.bool),UPat(name="y")), "root"),
lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.ADD,
"vin": [{"__name__": "non_muls"}, {"__name__": "muls", "uop": UOps.ALU, "arg": BinaryOps.MUL}]},
(UPat(UOps.ALU, BinaryOps.ADD,
[UPat(name="non_muls"), UPat(UOps.ALU, BinaryOps.MUL, name="muls")], "root"),
lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.vin + (non_muls,), TernaryOps.MULACC)),
*[({"__name__": "x", "uop": UOps.ALU, "dtype": dtypes.half, "arg": op},
*[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"),
lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),)))
for op in PTXRenderer.asm_for_op.keys() if op not in PTXRenderer.supports_half],
({"__name__": "root", "uop": UOps.LOAD, "dtype": dtypes.bool,
"vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})},
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool,
vin=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))),
lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool, "vin": ({},{})},
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, vin=(UPat(),UPat())),
lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))),
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool}, {})},
(UPat(UOps.STORE, name="root", vin=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool), UPat())),
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})},
(UPat(UOps.STORE, name="root", vin=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool))),
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{},{"__name__": "g", "dtype": dtypes.int})},
(UPat(UOps.STORE, name="root", vin=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))),
lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)),
# ptr_ar (load/store)
({"__name__": "root", "uop": {UOps.LOAD, UOps.STORE}, "__allow_len__":[2,3,4,5], "vin": ({"uop":{UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}},
{"uop": UOps.ALU, "arg": BinaryOps.ADD,"vin":[{"__name__": "alu"}, {"__name__": "const", "uop":UOps.CONST}]})},
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, vin=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
UPat(UOps.ALU, BinaryOps.ADD, vin=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
lambda root, alu, const: UOp(root.uop, root.dtype,
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.vin[0].dtype.itemsize)+root.vin[0].cast(dtypes.int64),
UOp.const(const.dtype, root.vin[0].dtype.itemsize)*const)+root.vin[2:])),
({"__name__": "root", "uop": {UOps.LOAD, UOps.STORE}, "__allow_len__":[2,3,4,5], "vin": ({"uop":{UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}},
{"__name__": "const", "uop":UOps.CONST})},
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, vin=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
UPat(UOps.CONST, name="const"))),
lambda root, const: UOp(root.uop, root.dtype, (root.vin[0].cast(dtypes.int64),
UOp.const(dtypes.int64, const.arg * root.vin[0].dtype.itemsize),
)+root.vin[2:])),
({"__name__": "root", "uop": {UOps.LOAD, UOps.STORE}, "__allow_len__":[2,3,4,5], "vin": ({"uop":{UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}},
{"__name__": "alu"})}, # no const here
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, vin=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
UPat(name="alu"))), # no const here
lambda root, alu: UOp(root.uop, root.dtype,
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.vin[0].dtype.itemsize)+root.vin[0].cast(dtypes.int64),
UOp.const(dtypes.int64, 0))+root.vin[2:])),