mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
expand UOps with rewrite rules (#5501)
* expand UOps with rewrite rules [run_process_replay]
* progress
* much closer
* close, way less bugs
* bunch of expander tests
* fix contract
* ops tests pass
* fix barrier
* mostly passing
* bitcast in expanded ops
* support more expand merges
* all tests pass maybe
* fix empty EXPAND
* fix LIN fuzzing
* add ALL_SAME assert
* all same
* all same work
* raise CompileError
* pass fuzz linearizer
* revert whitespace
* fix nv tensor core test
* fix mypy
* bug fix
* fuzzer passes
* put tests back
* expand arg to idx
This commit is contained in:
@@ -270,15 +270,14 @@ class TestUOpGraph(TestUOps):
|
||||
self.assertEqual(endranges[-1].src[0], ranges[0])
|
||||
|
||||
def expander_rewrite(sink):
  # Test helper: run the expander on `sink` and return the resulting UOp.
  # NOTE(review): this span carries two alternative bodies back to back (a UOpGraph/linearize
  # path and a direct graph_rewrite path). As written, only the first one runs; every statement
  # after the first `return` is unreachable. Confirm which body is meant to be live.
  #from tinygrad.codegen.uopgraph import expander, constant_folder
  #together = PatternMatcher(expander.patterns + constant_folder.patterns)
  #return graph_rewrite(sink, together)
  # wrap in a SINK, linearize the whole graph, and hand back the last emitted uop
  out = UOpGraph(UOp(UOps.SINK, None, (sink,)))
  out.linearize()
  return out.uops[-1]
  # alternative body: rewrite `sink` with the expander and constant_folder patterns combined
  from tinygrad.codegen.uopgraph import expander, constant_folder
  together = PatternMatcher(expander.patterns + constant_folder.patterns)
  return graph_rewrite(sink, together)
  #out = UOpGraph(UOp(UOps.SINK, None, (sink,)))
  #out.linearize()
  #return out.uops[-1]
|
||||
|
||||
class TestExpander(unittest.TestCase):
|
||||
@unittest.skip
|
||||
def test_expand_add_broadcast(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
|
||||
sink = expander_rewrite(e1+3)
|
||||
@@ -292,7 +291,6 @@ class TestExpander(unittest.TestCase):
|
||||
assert sink.op is UOps.VECTORIZE and len(sink.src) == 4
|
||||
self.assertListEqual([x.arg for x in sink.src], [0,1,2,3])
|
||||
|
||||
@unittest.skip
|
||||
def test_contract_axis_1(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (1,))
|
||||
@@ -302,7 +300,6 @@ class TestExpander(unittest.TestCase):
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src], [0,4,8,12])
|
||||
self.assertListEqual([x.arg for x in sink.src[3].src], [3,7,11,15])
|
||||
|
||||
@unittest.skip
|
||||
def test_contract_axis_2(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (2,))
|
||||
@@ -312,7 +309,6 @@ class TestExpander(unittest.TestCase):
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src], [0,1,2,3])
|
||||
self.assertListEqual([x.arg for x in sink.src[3].src], [12,13,14,15])
|
||||
|
||||
@unittest.skip
|
||||
def test_contract_mid(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(8)), ((1,2),(2,2),(3,2)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), (2,))
|
||||
@@ -324,7 +320,6 @@ class TestExpander(unittest.TestCase):
|
||||
self.assertListEqual([x.arg for x in sink.src[2].src], [4,6])
|
||||
self.assertListEqual([x.arg for x in sink.src[3].src], [5,7])
|
||||
|
||||
@unittest.skip
|
||||
def test_expand_same_axis(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
|
||||
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4*x) for x in range(4)), ((1,4),))
|
||||
@@ -332,7 +327,6 @@ class TestExpander(unittest.TestCase):
|
||||
assert sink.op is UOps.EXPAND and len(sink.src) == 4
|
||||
self.assertListEqual([x.arg for x in sink.src], [0,5,10,15])
|
||||
|
||||
@unittest.skip
|
||||
def test_expand_different_axis(self, flip=False):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4*x) for x in range(4)), ((1,4),))
|
||||
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
|
||||
@@ -357,7 +351,6 @@ class TestExpander(unittest.TestCase):
|
||||
assert sink.op is UOps.CONST
|
||||
self.assertEqual(sink.arg, 3*4)
|
||||
|
||||
@unittest.skip
|
||||
def test_double_expand(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
|
||||
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((2,4),))
|
||||
@@ -367,7 +360,6 @@ class TestExpander(unittest.TestCase):
|
||||
assert sink.arg == ((1, 2), (2, 4))
|
||||
self.assertListEqual([x.arg for x in sink.src], [0,1,2,3,4,5,6,7])
|
||||
|
||||
@unittest.skip
|
||||
def test_double_expand_reverse(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
|
||||
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,4),))
|
||||
@@ -377,7 +369,6 @@ class TestExpander(unittest.TestCase):
|
||||
assert sink.arg == ((1, 4), (2, 2))
|
||||
self.assertListEqual([x.arg for x in sink.src], [0, 4, 1, 5, 2, 6, 3, 7])
|
||||
|
||||
@unittest.skip
|
||||
def test_double_expand_middle(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,2),(3,2)))
|
||||
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,2),(3,2)))
|
||||
|
||||
@@ -5,7 +5,7 @@ from collections import defaultdict
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType, ImageDType
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps, exec_alu
|
||||
from tinygrad.helpers import DEBUG, getenv, flatten, all_same, dedup, TRANSCENDENTAL, CI
|
||||
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, prod, CI
|
||||
from tinygrad.codegen.uops import UOp, UOps, END_FOR_UOP, type_verify
|
||||
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
|
||||
@@ -68,122 +68,7 @@ class PatternMatcher:
|
||||
if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match
|
||||
return None
|
||||
|
||||
def expand_nodes(parents:Set[UOp], expands:List[UOp], base:UOp) -> List[UOp]:
  """Return one rewritten copy of `base` for every combination of indices over the expanded axes.

  `parents` is the set of nodes used to build the child map, `expands` are the nodes being expanded
  (only the first axis id, `arg[0][0]`, and the sources of each are used), and `base` is the node
  whose per-combination replacement is returned (or `base` itself if it was never replaced).
  """
  # just in case, dedup expands
  expands = dedup(expands)

  # get children and define_accs
  children = defaultdict(list)
  define_accs = []
  for p in parents:
    if p.op is UOps.PHI:
      # axes listed in arg[7] of any WMMA parent are excluded from the accumulator's axes
      wmma_reduce_axes = flatten([x.arg[7] for x in p.parents if x.op is UOps.WMMA])
      parent_expands_for_acc = [x.arg[0][0] for x in p.parents if x in expands and x.arg[0][0] not in wmma_reduce_axes]
      # remember the PHI's first source together with the expand axes it depends on
      define_accs.append((p.src[0], parent_expands_for_acc))
    for x in p.src:
      children[x].append(p)

  # get nodes on the path from root to the expand node
  # (breadth-first walk over children starting at the expands; the dict is used as an ordered set)
  on_path: Dict[UOp, None] = {}
  search = expands[:]
  while len(search):
    t = search.pop(0)
    for cc in children[t]:
      if cc in on_path: continue
      on_path[cc] = None
      search.append(cc)

  # toposort the nodes on the path
  # TODO: library!
  in_degree: DefaultDict[UOp, int] = defaultdict(int)
  for n in on_path:
    for x in children[n]:
      in_degree[x] += 1
  toposort: List[UOp] = []
  search2 = [p for p in on_path if in_degree[p] == 0]
  # NOTE(review): nothing in this function ever adds to `seen`, so the guard below never fires
  seen: Set[UOp] = set()
  while len(search2):
    n = search2.pop(0)
    if n in seen: continue
    toposort.append(n)
    for x in children[n]:
      in_degree[x] -= 1
      if in_degree[x] == 0:
        search2.append(x)

  # get replacements by index
  # (axis id -> list of source indices; expands sharing an axis must have the same number of sources)
  replacements: Dict[int, List[int]] = {}
  for r in expands:
    if r.arg[0][0] in replacements: assert len(replacements[r.arg[0][0]]) == len(r.src)
    else: replacements[r.arg[0][0]] = list(range(0, len(r.src)))

  # every combination of one index per axis
  rps = list(itertools.product(*replacements.values()))

  acc_number = 0
  replaces: List[Dict[UOp, UOp]] = []
  # accumulators are shared between combinations that agree on all axes the accumulator depends on
  acc_cache: Dict[Tuple[Tuple[UOp, int, int], ...], UOp] = {}
  for rp in rps:
    rpk = dict(zip(replacements.keys(), rp))
    # each expand is replaced by the source selected for its axis in this combination
    replace = {r:r.src[rpk[r.arg[0][0]]] for r in expands}
    for d, acc_parents in define_accs:
      acc_index = tuple((d,x,rpk[x]) for x in acc_parents)
      if acc_index in acc_cache:
        replace[d] = acc_cache[acc_index]
      else:
        # new accumulator: same node with a fresh number appended to its arg
        replace[d] = acc_cache[acc_index] = UOp(d.op, d.dtype, d.src, d.arg + (acc_number,))
        acc_number += 1
    replaces.append(replace)

  # rebuild every node on the path, in topological order, once per combination
  for cc in toposort:
    if cc.op is UOps.BARRIER:
      # a BARRIER is not duplicated: one node gathers the replaced sources of all combinations
      super_replace = UOp(cc.op, cc.dtype, sum([tuple(replace.get(x, x) for x in cc.src) for replace in replaces], ()), cc.arg)
      for replace in replaces:
        replace[cc] = super_replace
    else:
      for replace in replaces:
        tcc = replace.get(cc, cc) # NOTE: handle expands that are already replaced
        replace[cc] = UOp(tcc.op, tcc.dtype, tuple(replace.get(x, x) for x in tcc.src), tcc.arg)
  return [x.get(base, base) for x in replaces]
|
||||
|
||||
# ***** reduce+image+contract handling *****
|
||||
|
||||
def expand_wmma(wmma):
  """Expand the parents of a WMMA over the axes named in its last two arg entries.

  Returns None when no EXPAND parent uses one of those axes; otherwise the first
  node produced by expand_nodes.
  """
  def _on_wmma_axis(u):
    axis = u.arg[0][0]
    return axis in wmma.arg[-1] or axis in wmma.arg[-2]

  matching = [u for u in wmma.parents if u.op is UOps.EXPAND and _on_wmma_axis(u)]
  if not matching: return None
  # TODO: assert that these are all the same. they have to be
  return expand_nodes(wmma.sparents, matching, wmma)[0]
|
||||
|
||||
acc_number = 0
|
||||
def replace_reduce(root):
  """Rewrite a REDUCE into DEFINE_ACC + a chain of ALU ops + PHI.

  EXPAND sources of the reduce (and any EXPAND parent on the same axis) are unrolled with
  expand_nodes; the remaining reduce sources are kept as sources of the accumulator.
  Uses and increments the module-level acc_number.
  """
  global acc_number
  direct = [u for u in root.src[1:] if u.op is UOps.EXPAND]

  # add other expands for float4. TODO: should be a faster way
  axes = [u.arg[0][0] for u in direct]
  same_axis = [u for u in root.parents if u.op is UOps.EXPAND and u.arg[0][0] in axes]
  expands = dedup(direct + same_axis)

  terms = expand_nodes(root.parents, expands, root.src[0]) if len(expands) else [root.src[0]]

  # identity element: 0 for SUM, the dtype minimum otherwise
  scalar = root.dtype.scalar()
  init = dtypes.as_const(0, root.dtype.scalar()) if root.arg is ReduceOps.SUM else dtypes.min(root.dtype.scalar())
  const = UOp.const(scalar, init)
  kept_srcs = tuple(u for u in root.src[1:] if u not in expands)
  acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + kept_srcs, (acc_number,))
  acc_number += 1

  alu_for = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}
  folded = acc
  for term in terms:
    folded = UOp.alu(alu_for[cast(ReduceOps, root.arg)], folded, term)
  return UOp(UOps.PHI, folded.dtype, (acc, folded))
|
||||
|
||||
def replace_contract(root:UOp):
  """Turn a CONTRACT into a VECTORIZE of the per-index copies of its source.

  Every EXPAND parent on an axis listed in root.arg must have as many sources as the
  contract dtype has elements.
  """
  parents = root.parents
  dtype = cast(DType, root.dtype)
  expands: List[UOp] = []
  for u in parents:
    if u.op is UOps.EXPAND and u.arg[0][0] in root.arg: expands.append(u)
  expand_lens = [dtype.count, *(len(u.src) for u in expands)]
  assert all_same(expand_lens), expand_lens
  lanes = expand_nodes(parents, expands, root.src[0])
  # TODO: why is this needed?
  if len(lanes) == 1: lanes = lanes*dtype.count
  return UOp(UOps.VECTORIZE, dtype, tuple(lanes))
|
||||
# ***** image handling *****
|
||||
|
||||
def fix_image_idx(ls:UOp):
|
||||
if ls.src[1].dtype is None or ls.src[1].dtype.count != 1: return None
|
||||
@@ -203,25 +88,6 @@ def fix_image_idx(ls:UOp):
|
||||
return ret
|
||||
return UOp(ls.op, ls.dtype, (ls.src[0], image_idx) + ls.src[2:], ls.arg)
|
||||
|
||||
def cast_reduce(cst):
  """Merge a vector-typed node over matching reduces into one reduce over a VECTORIZE.

  Returns None when cst has a scalar dtype or when its sources do not all share the same
  arg and trailing sources.
  """
  # not for normal CAST. TODO: the merging one shouldn't be CAST
  is_scalar = cst.dtype.scalar() == cst.dtype
  if is_scalar or not all_same([(x.arg, x.src[1:]) for x in cst.src]): return None
  first = cst.src[0]
  vectorized = UOp(UOps.VECTORIZE, cst.dtype, tuple(x.src[0] for x in cst.src))
  return UOp(UOps.REDUCE, vectorized.dtype, (vectorized, *first.src[1:]), first.arg)
|
||||
|
||||
# rewrite rules that handle CONTRACT nodes via replace_contract
contractor = PatternMatcher([
  # contracts
  (UOp(UOps.CONTRACT).name("root"), replace_contract),
])
|
||||
|
||||
# rewrite rules for REDUCE (replace_reduce), WMMA (expand_wmma) and LOAD/STORE (fix_image_idx)
reducer = PatternMatcher([
  (UOp(UOps.REDUCE).name("root"), replace_reduce),
  (UOp(UOps.WMMA).name("wmma"), expand_wmma),
  # image indexing. TODO: why can't this just go after the float stuff?
  (UPat({UOps.LOAD, UOps.STORE}, name="ls"), fix_image_idx),
])
|
||||
|
||||
# ***** float4 handling *****
|
||||
|
||||
def float4_expand_load(load, buf, ex, idx=UOp.const(dtypes.int, 0), idx2=None):
|
||||
@@ -483,6 +349,123 @@ constant_folder = PatternMatcher([
|
||||
|
||||
# constant_folder's patterns followed by the float4_folding patterns
constant_folder_w_f4 = PatternMatcher(constant_folder.patterns + float4_folding.patterns)
|
||||
|
||||
# *** uop expander ***
|
||||
|
||||
def _expand_arg_to_idx(arg:Tuple[Tuple[int, int], ...], rpk:Dict[int, int]):
|
||||
idx, mul = 0, 1
|
||||
for axis,m in arg[::-1]:
|
||||
idx += rpk[axis] * mul
|
||||
mul *= m
|
||||
return idx
|
||||
|
||||
def do_expand(root:UOp):
  """Push EXPAND sources of `root` outward: build one copy of `root` per index combination and wrap them in an EXPAND.

  Returns None (no rewrite) when there is nothing to expand. Otherwise returns an EXPAND whose
  sources are the per-combination copies and whose arg is the tuple of (axis, size) pairs expanded.
  """
  if root.op is UOps.REDUCE:
    # REDUCE only looks at its first source; the remaining sources are the reduce axes
    if root.src[0].op is not UOps.EXPAND: return None
    reduce_expand_args = flatten([x.arg for x in root.src[1:] if x.op is UOps.EXPAND])
    # expand only the axes of src[0] that are not being reduced
    expand_args = tuple(x for x in root.src[0].arg if x not in reduce_expand_args)
    if len(expand_args) == 0: return None
    dont_expand_args = tuple(x for x in root.src[0].arg if x in reduce_expand_args)
  else:
    expands = [x for x in root.src if x.op is UOps.EXPAND]
    if len(expands) == 0: return None
    # union of all (axis, size) pairs seen on the EXPAND sources, in sorted order
    expand_args = tuple(sorted(dedup(flatten([x.arg for x in expands]))))
    if root.op is UOps.WMMA:
      # axes named in the last two entries of the WMMA arg are left alone
      dont_expand_args = tuple(x for x in expand_args if x[0] in root.arg[-1] or x[0] in root.arg[-2])
      expand_args = tuple(x for x in expand_args if x not in dont_expand_args)
    else:
      dont_expand_args = ()
  new_srcs = []
  # one copy of root per combination of positions on the expanded axes
  for choices in itertools.product(*[range(x[1]) for x in expand_args]):
    rpk = dict(zip([x[0] for x in expand_args], choices))
    new_src = []
    for src in root.src:
      if src.op is UOps.EXPAND:
        # gather the elements of this EXPAND for every position on the non-expanded axes
        lnew_src = []
        for lchoices in itertools.product(*[range(x[1]) for x in dont_expand_args]):
          lrpk = {**rpk, **dict(zip([x[0] for x in dont_expand_args], lchoices))}
          lnew_src.append(src.src[_expand_arg_to_idx(src.arg, lrpk)])
        if len(dont_expand_args):
          if root.op is UOps.WMMA:
            new_src.append(lnew_src[0]) # TODO: is this always right?
          else:
            # keep the non-expanded axes as a smaller EXPAND on the source
            new_src.append(UOp(UOps.EXPAND, root.dtype, tuple(lnew_src), dont_expand_args))
        else:
          # with no non-expanded axes the inner product yields exactly one (empty) choice
          assert len(lnew_src) == 1
          new_src.append(lnew_src[0])
      else:
        new_src.append(src)
    new_srcs.append(UOp(root.op, root.dtype, tuple(new_src), root.arg))
  if root.op is UOps.EXPAND:
    # root is itself an EXPAND: merge its own axes with the new ones and re-index into a single EXPAND
    expand_args, old_expand_args = tuple(sorted(root.arg+expand_args)), expand_args
    assert len(expand_args) == (len(old_expand_args) + len(root.arg))
    new_new_srcs = []
    for choices in itertools.product(*[range(x[1]) for x in expand_args]):
      rpk = dict(zip([x[0] for x in expand_args], choices))
      # pick the copy for the outer axes, then the element of that copy for root's own axes
      new_new_srcs.append(new_srcs[_expand_arg_to_idx(old_expand_args, rpk)].src[_expand_arg_to_idx(root.arg, rpk)])
    new_srcs = new_new_srcs
  assert prod([x[1] for x in expand_args]) == len(new_srcs)
  return UOp(UOps.EXPAND, root.dtype, tuple(new_srcs), expand_args)
|
||||
|
||||
acc_number = 0
|
||||
def do_reduce_with_expand(root):
  """Lower a REDUCE into DEFINE_ACC, a chain of ALU ops and a PHI.

  EXPAND entries in root.src[1:] are split in two groups:
    * expands_reduce: all of their (axis, size) pairs appear in root.src[0].arg (and src[0] is an EXPAND);
      the elements of src[0] are folded into the accumulator one by one.
    * expands_non_reduce: every other EXPAND reduce axis.
  Non-EXPAND entries of root.src[1:] become sources of the DEFINE_ACC.
  Uses and increments the module-level acc_number.
  """
  global acc_number
  expands = [x for x in root.src[1:] if x.op is UOps.EXPAND]
  expands_reduce = [x for x in expands if root.src[0].op is UOps.EXPAND and all(y in root.src[0].arg for y in x.arg)]
  expands_non_reduce = [x for x in expands if x not in expands_reduce]
  # identity element: 0 for SUM, the dtype minimum otherwise
  const = UOp.const(root.dtype.scalar(), dtypes.as_const(0, root.dtype.scalar()) if root.arg is ReduceOps.SUM else dtypes.min(root.dtype.scalar()))
  ret = acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(x for x in root.src[1:] if x.op is not UOps.EXPAND), (acc_number,))
  acc_number += 1
  alu_op = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, root.arg)]
  if len(expands_reduce):
    assert root.src[0].op is UOps.EXPAND
    expand_reduce_args = dedup(flatten([x.arg for x in expands_reduce]))
    # the reduced axes must account for every element of the expanded source
    assert prod([y[1] for y in expand_reduce_args]) == len(root.src[0].src)
    for xx in root.src[0].src:
      ret = UOp.alu(alu_op, ret, xx)
  else:
    ret = UOp.alu(alu_op, ret, root.src[0])
  ret = UOp(UOps.PHI, ret.dtype, (acc, ret))
  # Reducing over an axis the value is not expanded on repeats the same value `size` times.
  # For SUM that equals value*size; for MAX the result is the value itself, so it must not be scaled.
  if root.arg is ReduceOps.SUM and len(expands_non_reduce):
    ret = ret * prod([sz for _,sz in flatten([x.arg for x in expands_non_reduce])])
  return ret
|
||||
|
||||
def do_contract(con:UOp):
  """Rewrite a CONTRACT: fold one axis of its EXPAND source into a vector.

  * source is not an EXPAND, or the EXPAND lacks the contracted axis: VECTORIZE of the source repeated.
  * single-axis EXPAND on the contracted axis: VECTORIZE of the EXPAND's elements.
  * otherwise: an EXPAND over the remaining axes whose elements are VECTORIZEs along the contracted axis.
  """
  ex = con.src[0]
  assert con.dtype is not None
  width = con.dtype.count

  def _repeated():
    # CONTRACT without EXPAND repeats the element VECTORIZED
    return UOp(UOps.VECTORIZE, con.dtype, con.src*width)

  if ex.op is not UOps.EXPAND: return _repeated()
  # simple CONTRACT and EXPAND cancel out
  if len(ex.arg) == 1 and len(con.arg) == 1 and ex.arg[0][0] in con.arg:
    return UOp(UOps.VECTORIZE, con.dtype, ex.src)
  # complex CONTRACT may only remove one axis from EXPAND
  assert len(con.arg) == 1, "contract arg one is all that's supported"
  axes = [x[0] for x in ex.arg]
  # CONTRACT without EXPAND (still) repeats the element VECTORIZED
  if con.arg[0] not in axes: return _repeated()
  split_index = axes.index(con.arg[0])
  assert width == ex.arg[split_index][1], "contract arg must match"
  # elements sharing one position on the contracted axis (and the axes before it) are contiguous
  number_after = prod([x[1] for x in ex.arg[split_index+1:]])
  to_join = [ex.src[start:start+number_after] for start in range(0, len(ex.src), number_after)]
  srcs = [UOp(UOps.VECTORIZE, con.dtype, tuple(lanes))
          for start in range(0, len(to_join), width)
          for lanes in zip(*to_join[start:start+width])]
  remaining = tuple(x for x in ex.arg if x[0] != con.arg[0])
  return UOp(UOps.EXPAND, con.dtype, tuple(srcs), remaining)
|
||||
|
||||
# Rewrite rules that expand EXPAND/CONTRACT/REDUCE nodes.
# A callback that returns None lets the following patterns try the same node, so a REDUCE
# that do_expand declines falls through to do_reduce_with_expand.
expander = PatternMatcher([
  # ops that get duplicated once per index combination of their EXPAND sources
  (UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE,
         UOps.VECTORIZE, UOps.REDUCE, UOps.EXPAND, UOps.IF}, name="root"), do_expand),
  (UOp(UOps.REDUCE).name("root"), do_reduce_with_expand),
  (UOp(UOps.CONTRACT).name("con"), do_contract),
  # remove EXPANDs from SINK
  # (each EXPAND source is replaced by its elements; no rewrite if the source count is unchanged)
  (UOp(UOps.SINK).name("root"),
    lambda root: UOp(UOps.SINK, root.dtype, a, root.arg)
      if len(a:=tuple(flatten(x.src if x.op is UOps.EXPAND else (x,) for x in root.src))) != len(root.src) else None),
  # BARRIERs aren't actually expanded
  # (one BARRIER over all elements of the EXPAND, repeated as every element of the new EXPAND)
  (UOp(UOps.BARRIER, src=(UOp(UOps.EXPAND).name("ex"),)), lambda ex: UOp(UOps.EXPAND, None, (UOp(UOps.BARRIER, None, ex.src),)*len(ex.src), ex.arg)),
  # image indexing (needs to be here)
  (UPat({UOps.LOAD, UOps.STORE}, name="ls"), fix_image_idx),
  # empty EXPAND is NOOP
  (UOp(UOps.EXPAND, src=(UOp.var('x'),), arg=()), lambda x: x),
])
|
||||
|
||||
# *** uop graph ***
|
||||
|
||||
def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], in_degree:Dict[UOp, int]):
|
||||
@@ -565,16 +548,9 @@ class UOpGraph:
|
||||
sink = graph_rewrite(sink, self.folder)
|
||||
if extra_pm: sink = graph_rewrite(sink, PatternMatcher(self.folder.patterns+extra_pm.patterns))
|
||||
|
||||
# expand
|
||||
UOpGraph.cnt += 1
|
||||
if UOpGraph.cnt != getenv("DEBUG_EXPAND", 0):
|
||||
# do contracts/reduces
|
||||
sink = graph_rewrite(sink, contractor)
|
||||
sink = graph_rewrite(sink, reducer)
|
||||
|
||||
# do upcasts (after reduce unrolls and rewrites)
|
||||
expands = list(sorted(x for x in sink.sparents if x.op is UOps.EXPAND))
|
||||
new_nodes = expand_nodes(sink.sparents, expands, sink)
|
||||
sink = UOp(UOps.SINK, None, tuple(flatten([x.src for x in new_nodes]))) # merge the sinks
|
||||
if UOpGraph.cnt != getenv("DEBUG_EXPAND", 0): sink = graph_rewrite(sink, expander)
|
||||
|
||||
# do graph rewrite (2)
|
||||
sink = graph_rewrite(sink, self.folder)
|
||||
|
||||
Reference in New Issue
Block a user