expand UOps with rewrite rules (#5501)

* expand UOps with rewrite rules [run_process_replay]

* progress

* much closer

* close, way fewer bugs

* bunch of expander tests

* fix contract

* ops tests pass

* fix barrier

* mostly passing

* bitcast in expanded ops

* support more expand merges

* all tests pass maybe

* fix empty EXPAND

* fix LIN fuzzing

* add ALL_SAME assert

* all same

* all same work

* raise CompileError

* pass fuzz linearizer

* revert whitespace

* fix nv tensor core test

* fix mypy

* bug fix

* fuzzer passes

* put tests back

* expand arg to idx
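
For context on the diff below: this PR replaces the dedicated expand_nodes graph pass (plus the contractor/reducer matchers) with ordinary rewrite rules. An EXPAND uop carries one source per choice of its ((axis, size), ...) arg, and the expander rules push each op inside the EXPAND so every expanded instance operates on a single choice. A minimal standalone sketch of that idea (the Expand class and push_add below are illustrative toys, not tinygrad APIs):

from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class Expand:
  srcs: Tuple[int, ...]              # one value per choice of the expand axes
  axes: Tuple[Tuple[int, int], ...]  # ((axis, size), ...), last axis varying fastest

def push_add(e: Expand, k: int) -> Expand:
  # mirrors the ALU case of do_expand: apply the op once per expanded choice
  return Expand(tuple(s + k for s in e.srcs), e.axes)

e1 = Expand((0, 1, 2, 3), ((1, 4),))
assert push_add(e1, 3).srcs == (3, 4, 5, 6)  # cf. test_expand_add_broadcast in the test diff
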
Author: George Hotz
Date: 2024-07-17 10:17:50 -07:00
Committed by: GitHub
Parent: 158221b36b
Commit: 1242b302fa
2 changed files with 127 additions and 160 deletions


@@ -270,15 +270,14 @@ class TestUOpGraph(TestUOps):
self.assertEqual(endranges[-1].src[0], ranges[0])
def expander_rewrite(sink):
#from tinygrad.codegen.uopgraph import expander, constant_folder
#together = PatternMatcher(expander.patterns + constant_folder.patterns)
#return graph_rewrite(sink, together)
out = UOpGraph(UOp(UOps.SINK, None, (sink,)))
out.linearize()
return out.uops[-1]
from tinygrad.codegen.uopgraph import expander, constant_folder
together = PatternMatcher(expander.patterns + constant_folder.patterns)
return graph_rewrite(sink, together)
#out = UOpGraph(UOp(UOps.SINK, None, (sink,)))
#out.linearize()
#return out.uops[-1]
class TestExpander(unittest.TestCase):
@unittest.skip
def test_expand_add_broadcast(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
sink = expander_rewrite(e1+3)
@@ -292,7 +291,6 @@ class TestExpander(unittest.TestCase):
assert sink.op is UOps.VECTORIZE and len(sink.src) == 4
self.assertListEqual([x.arg for x in sink.src], [0,1,2,3])
@unittest.skip
def test_contract_axis_1(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (1,))
@@ -302,7 +300,6 @@ class TestExpander(unittest.TestCase):
self.assertListEqual([x.arg for x in sink.src[0].src], [0,4,8,12])
self.assertListEqual([x.arg for x in sink.src[3].src], [3,7,11,15])
@unittest.skip
def test_contract_axis_2(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (2,))
@@ -312,7 +309,6 @@ class TestExpander(unittest.TestCase):
self.assertListEqual([x.arg for x in sink.src[0].src], [0,1,2,3])
self.assertListEqual([x.arg for x in sink.src[3].src], [12,13,14,15])
@unittest.skip
def test_contract_mid(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(8)), ((1,2),(2,2),(3,2)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), (2,))
@@ -324,7 +320,6 @@ class TestExpander(unittest.TestCase):
self.assertListEqual([x.arg for x in sink.src[2].src], [4,6])
self.assertListEqual([x.arg for x in sink.src[3].src], [5,7])
@unittest.skip
def test_expand_same_axis(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4*x) for x in range(4)), ((1,4),))
@@ -332,7 +327,6 @@ class TestExpander(unittest.TestCase):
assert sink.op is UOps.EXPAND and len(sink.src) == 4
self.assertListEqual([x.arg for x in sink.src], [0,5,10,15])
@unittest.skip
def test_expand_different_axis(self, flip=False):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4*x) for x in range(4)), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
@@ -357,7 +351,6 @@ class TestExpander(unittest.TestCase):
assert sink.op is UOps.CONST
self.assertEqual(sink.arg, 3*4)
@unittest.skip
def test_double_expand(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((2,4),))
@@ -367,7 +360,6 @@ class TestExpander(unittest.TestCase):
assert sink.arg == ((1, 2), (2, 4))
self.assertListEqual([x.arg for x in sink.src], [0,1,2,3,4,5,6,7])
@unittest.skip
def test_double_expand_reverse(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,4),))
@@ -377,7 +369,6 @@ class TestExpander(unittest.TestCase):
assert sink.arg == ((1, 4), (2, 2))
self.assertListEqual([x.arg for x in sink.src], [0, 4, 1, 5, 2, 6, 3, 7])
@unittest.skip
def test_double_expand_middle(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,2),(3,2)))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,2),(3,2)))


@@ -5,7 +5,7 @@ from collections import defaultdict
from tinygrad.dtype import dtypes, DType, PtrDType, ImageDType
from tinygrad.shape.symbolic import Variable
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps, exec_alu
from tinygrad.helpers import DEBUG, getenv, flatten, all_same, dedup, TRANSCENDENTAL, CI
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, prod, CI
from tinygrad.codegen.uops import UOp, UOps, END_FOR_UOP, type_verify
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
@@ -68,122 +68,7 @@ class PatternMatcher:
if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match
return None
def expand_nodes(parents:Set[UOp], expands:List[UOp], base:UOp) -> List[UOp]:
# just in case, dedup expands
expands = dedup(expands)
# get children and define_accs
children = defaultdict(list)
define_accs = []
for p in parents:
if p.op is UOps.PHI:
wmma_reduce_axes = flatten([x.arg[7] for x in p.parents if x.op is UOps.WMMA])
parent_expands_for_acc = [x.arg[0][0] for x in p.parents if x in expands and x.arg[0][0] not in wmma_reduce_axes]
define_accs.append((p.src[0], parent_expands_for_acc))
for x in p.src:
children[x].append(p)
# get nodes on the path from root to the expand node
on_path: Dict[UOp, None] = {}
search = expands[:]
while len(search):
t = search.pop(0)
for cc in children[t]:
if cc in on_path: continue
on_path[cc] = None
search.append(cc)
# toposort the nodes on the path
# TODO: library!
in_degree: DefaultDict[UOp, int] = defaultdict(int)
for n in on_path:
for x in children[n]:
in_degree[x] += 1
toposort: List[UOp] = []
search2 = [p for p in on_path if in_degree[p] == 0]
seen: Set[UOp] = set()
while len(search2):
n = search2.pop(0)
if n in seen: continue
toposort.append(n)
for x in children[n]:
in_degree[x] -= 1
if in_degree[x] == 0:
search2.append(x)
# get replacements by index
replacements: Dict[int, List[int]] = {}
for r in expands:
if r.arg[0][0] in replacements: assert len(replacements[r.arg[0][0]]) == len(r.src)
else: replacements[r.arg[0][0]] = list(range(0, len(r.src)))
# get nodes on the path from root to the expand node
rps = list(itertools.product(*replacements.values()))
acc_number = 0
replaces: List[Dict[UOp, UOp]] = []
acc_cache: Dict[Tuple[Tuple[UOp, int, int], ...], UOp] = {}
for rp in rps:
rpk = dict(zip(replacements.keys(), rp))
replace = {r:r.src[rpk[r.arg[0][0]]] for r in expands}
for d, acc_parents in define_accs:
acc_index = tuple((d,x,rpk[x]) for x in acc_parents)
if acc_index in acc_cache:
replace[d] = acc_cache[acc_index]
else:
replace[d] = acc_cache[acc_index] = UOp(d.op, d.dtype, d.src, d.arg + (acc_number,))
acc_number += 1
replaces.append(replace)
for cc in toposort:
if cc.op is UOps.BARRIER:
super_replace = UOp(cc.op, cc.dtype, sum([tuple(replace.get(x, x) for x in cc.src) for replace in replaces], ()), cc.arg)
for replace in replaces:
replace[cc] = super_replace
else:
for replace in replaces:
tcc = replace.get(cc, cc) # NOTE: handle expands that are already replaced
replace[cc] = UOp(tcc.op, tcc.dtype, tuple(replace.get(x, x) for x in tcc.src), tcc.arg)
return [x.get(base, base) for x in replaces]
# ***** reduce+image+contract handling *****
def expand_wmma(wmma):
expands = [x for x in wmma.parents if x.op is UOps.EXPAND and (x.arg[0][0] in wmma.arg[-1] or x.arg[0][0] in wmma.arg[-2])]
if len(expands) == 0: return None
new_uops = expand_nodes(wmma.sparents, expands, wmma)
# TODO: assert that these are all the same. they have to be
return new_uops[0]
acc_number = 0
def replace_reduce(root):
global acc_number
expands = [x for x in root.src[1:] if x.op is UOps.EXPAND]
# add other expands for float4. TODO: should be a faster way
expand_args = [x.arg[0][0] for x in expands]
new_expands = [x for x in root.parents if x.op is UOps.EXPAND and x.arg[0][0] in expand_args]
expands = dedup(expands + new_expands)
if len(expands):
new_uops = expand_nodes(root.parents, expands, root.src[0])
else:
new_uops = [root.src[0]]
const = UOp.const(root.dtype.scalar(), dtypes.as_const(0, root.dtype.scalar()) if root.arg is ReduceOps.SUM else dtypes.min(root.dtype.scalar()))
acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(x for x in root.src[1:] if x not in expands), (acc_number,))
acc_number += 1
ret = acc
for xx in new_uops: ret = UOp.alu({ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, root.arg)], ret, xx)
return UOp(UOps.PHI, ret.dtype, (acc, ret))
def replace_contract(root:UOp):
parents, dtype = root.parents, cast(DType, root.dtype)
expands: List[UOp] = [x for x in parents if x.op is UOps.EXPAND and x.arg[0][0] in root.arg]
assert all_same(expand_lens := [dtype.count] + [len(x.src) for x in expands]), expand_lens
ret = expand_nodes(parents, expands, root.src[0])
if len(ret) == 1: ret = ret*dtype.count # TODO: why is this needed?
return UOp(UOps.VECTORIZE, dtype, tuple(ret))
# ***** image handling *****
def fix_image_idx(ls:UOp):
if ls.src[1].dtype is None or ls.src[1].dtype.count != 1: return None
@@ -203,25 +88,6 @@ def fix_image_idx(ls:UOp):
return ret
return UOp(ls.op, ls.dtype, (ls.src[0], image_idx) + ls.src[2:], ls.arg)
def cast_reduce(cst):
if cst.dtype.scalar() == cst.dtype: return None # not for normal CAST. TODO: the merging one shouldn't be CAST
if not all_same([(x.arg, x.src[1:]) for x in cst.src]): return None
fst_red = cst.src[0]
red = UOp(UOps.VECTORIZE, cst.dtype, tuple(x.src[0] for x in cst.src))
return UOp(UOps.REDUCE, red.dtype, (red,) + fst_red.src[1:], fst_red.arg)
contractor = PatternMatcher([
# contracts
(UOp(UOps.CONTRACT).name("root"), replace_contract),
])
reducer = PatternMatcher([
(UOp(UOps.REDUCE).name("root"), replace_reduce),
(UOp(UOps.WMMA).name("wmma"), expand_wmma),
# image indexing. TODO: why can't this just go after the float stuff?
(UPat({UOps.LOAD, UOps.STORE}, name="ls"), fix_image_idx),
])
# ***** float4 handling *****
def float4_expand_load(load, buf, ex, idx=UOp.const(dtypes.int, 0), idx2=None):
@@ -483,6 +349,123 @@ constant_folder = PatternMatcher([
constant_folder_w_f4 = PatternMatcher(constant_folder.patterns + float4_folding.patterns)
# *** uop expander ***
def _expand_arg_to_idx(arg:Tuple[Tuple[int, int], ...], rpk:Dict[int, int]):
idx, mul = 0, 1
for axis,m in arg[::-1]:
idx += rpk[axis] * mul
mul *= m
return idx
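# e.g. arg=((1,4),(2,2)) with rpk={1:3, 2:1}: axes are walked innermost-first, so idx = 1*1 + 3*2 = 7,
# a row-major index (last axis fastest) matching the itertools.product ordering used below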
def do_expand(root:UOp):
if root.op is UOps.REDUCE:
if root.src[0].op is not UOps.EXPAND: return None
reduce_expand_args = flatten([x.arg for x in root.src[1:] if x.op is UOps.EXPAND])
expand_args = tuple(x for x in root.src[0].arg if x not in reduce_expand_args)
if len(expand_args) == 0: return None
dont_expand_args = tuple(x for x in root.src[0].arg if x in reduce_expand_args)
else:
expands = [x for x in root.src if x.op is UOps.EXPAND]
if len(expands) == 0: return None
expand_args = tuple(sorted(dedup(flatten([x.arg for x in expands]))))
if root.op is UOps.WMMA:
dont_expand_args = tuple(x for x in expand_args if x[0] in root.arg[-1] or x[0] in root.arg[-2])
expand_args = tuple(x for x in expand_args if x not in dont_expand_args)
else:
dont_expand_args = ()
new_srcs = []
for choices in itertools.product(*[range(x[1]) for x in expand_args]):
rpk = dict(zip([x[0] for x in expand_args], choices))
new_src = []
for src in root.src:
if src.op is UOps.EXPAND:
lnew_src = []
for lchoices in itertools.product(*[range(x[1]) for x in dont_expand_args]):
lrpk = {**rpk, **dict(zip([x[0] for x in dont_expand_args], lchoices))}
lnew_src.append(src.src[_expand_arg_to_idx(src.arg, lrpk)])
if len(dont_expand_args):
if root.op is UOps.WMMA:
new_src.append(lnew_src[0]) # TODO: is this always right?
else:
new_src.append(UOp(UOps.EXPAND, root.dtype, tuple(lnew_src), dont_expand_args))
else:
assert len(lnew_src) == 1
new_src.append(lnew_src[0])
else:
new_src.append(src)
new_srcs.append(UOp(root.op, root.dtype, tuple(new_src), root.arg))
if root.op is UOps.EXPAND:
expand_args, old_expand_args = tuple(sorted(root.arg+expand_args)), expand_args
assert len(expand_args) == (len(old_expand_args) + len(root.arg))
new_new_srcs = []
for choices in itertools.product(*[range(x[1]) for x in expand_args]):
rpk = dict(zip([x[0] for x in expand_args], choices))
new_new_srcs.append(new_srcs[_expand_arg_to_idx(old_expand_args, rpk)].src[_expand_arg_to_idx(root.arg, rpk)])
new_srcs = new_new_srcs
assert prod([x[1] for x in expand_args]) == len(new_srcs)
return UOp(UOps.EXPAND, root.dtype, tuple(new_srcs), expand_args)
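# e.g. an ALU ADD of EXPAND((a0..a3), ((1,4),)) and EXPAND((b0..b3), ((2,4),)) rewrites to an EXPAND of
# 16 ADD(a_i, b_j) uops with arg ((1,4),(2,4)); a src that lacks an axis simply repeats along it
# (cf. test_expand_different_axis in the test diff)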
acc_number = 0
def do_reduce_with_expand(root):
global acc_number
expands = [x for x in root.src[1:] if x.op is UOps.EXPAND]
expands_reduce = [x for x in expands if root.src[0].op is UOps.EXPAND and all(y in root.src[0].arg for y in x.arg)]
expands_non_reduce = [x for x in expands if x not in expands_reduce]
const = UOp.const(root.dtype.scalar(), dtypes.as_const(0, root.dtype.scalar()) if root.arg is ReduceOps.SUM else dtypes.min(root.dtype.scalar()))
ret = acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(x for x in root.src[1:] if x.op is not UOps.EXPAND), (acc_number,))
acc_number += 1
if len(expands_reduce):
assert root.src[0].op is UOps.EXPAND
expand_reduce_args = dedup(flatten([x.arg for x in expands_reduce]))
assert prod([y[1] for y in expand_reduce_args]) == len(root.src[0].src)
for xx in root.src[0].src:
ret = UOp.alu({ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, root.arg)], ret, xx)
else:
ret = UOp.alu({ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, root.arg)], ret, root.src[0])
ret = UOp(UOps.PHI, ret.dtype, (acc, ret))
if len(expands_non_reduce): ret = ret * prod([sz for _,sz in flatten([x.arg for x in expands_non_reduce])])
return ret
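# e.g. a range unrolled to axis 1: REDUCE(SUM, EXPAND((x0,x1,x2,x3), ((1,4),)), <range EXPAND over axis 1>)
# lowers to DEFINE_ACC -> acc+x0+x1+x2+x3 -> PHI; expanded reduce axes the value does not depend on
# just scale the result (ret * number of repeats)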
def do_contract(con:UOp):
ex = con.src[0]
assert con.dtype is not None
# CONTRACT without EXPAND repeats the element VECTORIZED
if ex.op is not UOps.EXPAND: return UOp(UOps.VECTORIZE, con.dtype, con.src*con.dtype.count)
# simple CONTRACT and EXPAND cancel out
if len(ex.arg) == 1 and len(con.arg) == 1 and ex.arg[0][0] in con.arg: return UOp(UOps.VECTORIZE, con.dtype, ex.src)
# complex CONTRACT may only remove one axis from EXPAND
assert len(con.arg) == 1, "contract arg one is all that's supported"
try:
split_index = [x[0] for x in ex.arg].index(con.arg[0])
except ValueError:
# CONTRACT without EXPAND (still) repeats the element VECTORIZED
return UOp(UOps.VECTORIZE, con.dtype, con.src*con.dtype.count)
assert con.dtype.count == ex.arg[split_index][1], "contract arg must match"
number_after = prod([x[1] for x in ex.arg[split_index+1:]])
to_join = [ex.src[i:i+number_after] for i in range(0, len(ex.src), number_after)]
srcs = []
for i in range(0, len(to_join), con.dtype.count):
srcs += [UOp(UOps.VECTORIZE, con.dtype, tuple(src)) for src in zip(*to_join[i:i+con.dtype.count])]
return UOp(UOps.EXPAND, con.dtype, tuple(srcs), tuple(x for x in ex.arg if x[0] != con.arg[0]))
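# e.g. ex.arg=((1,2),(2,2)), ex.src=(s0,s1,s2,s3), con.arg=(2,), dtype.count=2: split_index=1 and
# number_after=1, giving EXPAND((VECTORIZE(s0,s1), VECTORIZE(s2,s3)), ((1,2),))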
expander = PatternMatcher([
(UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE,
UOps.VECTORIZE, UOps.REDUCE, UOps.EXPAND, UOps.IF}, name="root"), do_expand),
(UOp(UOps.REDUCE).name("root"), do_reduce_with_expand),
(UOp(UOps.CONTRACT).name("con"), do_contract),
# remove EXPANDs from SINK
(UOp(UOps.SINK).name("root"),
lambda root: UOp(UOps.SINK, root.dtype, a, root.arg)
if len(a:=tuple(flatten(x.src if x.op is UOps.EXPAND else (x,) for x in root.src))) != len(root.src) else None),
# BARRIERs aren't actually expanded
(UOp(UOps.BARRIER, src=(UOp(UOps.EXPAND).name("ex"),)), lambda ex: UOp(UOps.EXPAND, None, (UOp(UOps.BARRIER, None, ex.src),)*len(ex.src), ex.arg)),
# image indexing (needs to be here)
(UPat({UOps.LOAD, UOps.STORE}, name="ls"), fix_image_idx),
# empty EXPAND is NOOP
(UOp(UOps.EXPAND, src=(UOp.var('x'),), arg=()), lambda x: x),
])
# *** uop graph ***
def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], in_degree:Dict[UOp, int]):
@@ -565,16 +548,9 @@ class UOpGraph:
sink = graph_rewrite(sink, self.folder)
if extra_pm: sink = graph_rewrite(sink, PatternMatcher(self.folder.patterns+extra_pm.patterns))
# expand
UOpGraph.cnt += 1
if UOpGraph.cnt != getenv("DEBUG_EXPAND", 0):
# do contracts/reduces
sink = graph_rewrite(sink, contractor)
sink = graph_rewrite(sink, reducer)
# do upcasts (after reduce unrolls and rewrites)
expands = list(sorted(x for x in sink.sparents if x.op is UOps.EXPAND))
new_nodes = expand_nodes(sink.sparents, expands, sink)
sink = UOp(UOps.SINK, None, tuple(flatten([x.src for x in new_nodes]))) # merge the sinks
if UOpGraph.cnt != getenv("DEBUG_EXPAND", 0): sink = graph_rewrite(sink, expander)
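# expansion is now a single graph_rewrite over the expander PatternMatcher between the two
# constant-folding passes; DEBUG_EXPAND=<n> skips it for the n-th UOpGraph as a debugging escape hatch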
# do graph rewrite (2)
sink = graph_rewrite(sink, self.folder)