mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
rewrite the linearizer (#9885)
* random speedups [pr] * speeding up linearizer * test_gemm passes * progress * test_gemm passes * working * simpler * blockstart unneeded * simpler * bugfix * work * don't compare * faster * progress * cleanups * work * cleanups * working * reorder * name is dumb * fix tests * lin2 works * clean ctx * mostly bottom up * passes * same speed now * new lin is faster * dedup * lines and tuples * track that * lin * revert that * tests should pass * merge siblings * cleaner expression * only lin2 * finally, some speed * simpler * fix unmergables with blockends
This commit is contained in:
@@ -1,148 +1,21 @@
|
||||
from __future__ import annotations
|
||||
import collections, heapq
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.ops import UOp, Ops, PatternMatcher, UPat, graph_rewrite, GroupOp
|
||||
import heapq
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, replace
|
||||
from tinygrad.ops import UOp, Ops, graph_rewrite, PatternMatcher, UPat, GroupOp
|
||||
from tinygrad.dtype import PtrDType
|
||||
from tinygrad.helpers import dedup, partition, all_same, flatten
|
||||
from tinygrad.spec import type_verify
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import dedup, flatten, partition
|
||||
|
||||
# ops that are never pulled into a BLOCK: top-level definitions, constants, and the block ops themselves
DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST, *GroupOp.Block}
|
||||
|
||||
def disp(y:UOp) -> str:
  """Short human-readable label for a context UOp (used by BasicBlock.__repr__)."""
  op = y.op
  if op is Ops.BLOCKSTART:
    # BLOCKSTART is shown as a 'w'-prefixed label of the op it wraps
    return "w"+disp(y.src[0])
  if op is Ops.IF:
    return f'IF{id(y)}'
  if op is Ops.RANGE:
    return str(y.arg)
  return "<NONE>"
|
||||
|
||||
@dataclass(frozen=True)
class BasicBlock:
  """Payload (arg) of a BLOCK/BLOCKEND UOp: an ordered run of UOps plus the loop/if context they execute in."""
  # the RANGE/IF (and BLOCKSTART) UOps this block is nested inside
  ctx: tuple[UOp, ...]
  # the ops of this block, in execution order
  lst: tuple[UOp, ...]
  # for BLOCKEND only: the RANGE/IF this block closes
  end: UOp|None = None
  # deterministic ordering via tuplize of ctx+lst (used when blocks are sorted)
  def __lt__(self, o:BasicBlock): return tuple(x.tuplize for x in self.ctx+self.lst) < tuple(x.tuplize for x in o.ctx+o.lst)
  def __repr__(self):
    return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+\
           f"{[disp(y) for y in self.ctx]} {len(self.lst)}" + "\n" + '\n'.join([str(x.op) for x in self.lst])
|
||||
|
||||
def append_to_block(ctx:tuple[dict[UOp, tuple[UOp, ...]], dict[UOp, list[UOp]]], x:UOp):
  """Grow BLOCK x by pulling in srcs whose children are all already in the block.

  ctx is (block_ctxs, children): the precomputed block context of every UOp and the child map.
  Srcs with the same context are appended to this block; srcs with a different context seed new
  BLOCKs (wrapped in BLOCKENDs as needed). Returns the grown BLOCK, or None if nothing changed.
  """
  block_ctxs, children = ctx
  in_this_block = set(x.arg.lst)

  # collections to build
  new_srcs: list[UOp] = []
  to_append: list[UOp] = []
  old_blocks: dict[tuple[UOp, ...], UOp] = {}
  new_blocks: dict[tuple[UOp, ...], list[UOp]] = {}

  seen_u = set()
  for u in x.src:
    if u.op is Ops.BLOCK:
      if u not in seen_u:
        # merge sibling blocks. NOTE: blocks must only have one output source
        assert u.arg.ctx not in old_blocks, "sibling should never have been created"
        old_blocks[u.arg.ctx] = u
    elif u.op not in DONT_PLACE_IN_BLOCK and set(children[u]).issubset(in_this_block):
      if u not in seen_u:
        # if it can go in blocks and all its children are in the block, we add it to the block
        if (block_ctx:=block_ctxs[u]) == x.arg.ctx:
          # if it's the same context, we place the UOp in this block and append the parents to its srcs
          new_srcs.extend(u.src)
          to_append.append(u)
        else:
          # if it's a different context, we create a new block with this UOp
          new_blocks.setdefault(block_ctx, []).append(u)
    else:
      # otherwise, we keep it in the srcs
      new_srcs.append(u)
    seen_u.add(u)
  if len(to_append) == 0 and len(new_blocks) == 0: return None

  for rng,lst in new_blocks.items():
    srcs = flatten(y.src for y in lst)
    if (old_block:=old_blocks.pop(rng, None)) is not None:
      # NOTE: order shouldn't matter here
      srcs.extend(old_block.src)
      lst.extend(old_block.arg.lst)
    new_block = UOp(Ops.BLOCK, dtypes.void, tuple(dedup(srcs)), BasicBlock(rng, tuple(lst)))
    lrng = list(rng)
    # close (with BLOCKENDs) every range of the new block that isn't part of this block's ctx
    for r in rng[::-1]:
      if r not in x.arg.ctx and r.op is not Ops.BLOCKSTART:
        lrng.remove(r)
        new_block = UOp(Ops.BLOCKEND, src=(new_block,),
                        arg=BasicBlock(tuple(lrng), (UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,)),), r))
    new_srcs.append(new_block)
  return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(list(old_blocks.values())+new_srcs)), BasicBlock(x.arg.ctx, tuple(to_append)+x.arg.lst))
|
||||
|
||||
# rewrite rules that build BLOCKs top-down: the SINK seeds the first block, then append_to_block grows it
make_basic_blocks = PatternMatcher([
  (UPat(Ops.SINK, name="x"), lambda x: UOp(Ops.BLOCK, src=x.src, arg=BasicBlock((), (x,)))), (UPat(Ops.BLOCK, name="x"), append_to_block),
])
|
||||
|
||||
def block_merge(ctx, x:UOp):
  """Merge BLOCK/BLOCKEND x with mergeable BLOCK srcs, or finalize a BLOCKEND whose range is fully placed.

  ctx is the children map. Returns the merged UOp, or None if nothing can be merged.
  """
  # ctx is children here
  if x.op is Ops.BLOCKEND:
    # if it's a BLOCKEND, see if we are done with placement. if all the children of the range are in here
    in_this_block = set(x.arg.lst)
    if len([y for y in ctx[x.arg.end] if y not in in_this_block]) == 0:
      # find the parent block that has the BLOCKSTART in the ctx
      parent_blocks = [y for y in x.src if y.op is Ops.BLOCK and UOp(Ops.BLOCKSTART, src=(x.arg.end,)) in y.arg.ctx]
      assert len(parent_blocks) <= 1, "should never have two parent blocks"
      if len(parent_blocks) == 1:
        parent_block = parent_blocks[0]
        # range needs DEFINE_ACC to be before the range (never in DEFINE_ACC for if)
        early_ops, late_ops = partition(x.arg.lst, lambda y: y.op is Ops.DEFINE_ACC and x.arg.end in y.src)
        # NOTE: we have to add a barrier at the start if barrier is used in the range
        if x.op is Ops.BLOCKEND and any(y.op is Ops.BARRIER for y in late_ops) and late_ops[-1].op is Ops.ENDRANGE:
          late_ops = [UOp(Ops.BARRIER)] + late_ops
        return UOp(Ops.BLOCK, dtypes.void, tuple(y for y in x.src if y is not parent_block)+parent_block.src,
                   BasicBlock(tuple(y for y in x.arg.ctx if y is not x.arg.end), tuple(early_ops)+parent_block.arg.lst+tuple(late_ops)))

  new_srcs: list[UOp] = []
  to_append: list[UOp] = []
  new_ctx = x.arg.ctx
  placed = set()
  for u in x.src:
    if u.op is Ops.BLOCK and (tuple(u.arg.ctx) == tuple(x.arg.ctx) or (x.arg.end is not None and x.arg.end in u.arg.ctx)):
      # NOTE: this can't appear in srcs twice or it would be a BLOCKFORK
      new_ctx += tuple(y for y in u.arg.ctx if y not in x.arg.ctx)
      new_srcs.extend(u.src)
      to_append.extend(u.arg.lst)
    elif u.op is Ops.BLOCKFORK and x.src.count(u) == u.arg: # block fork appears # of times in srcs
      if u not in placed:
        new_srcs.extend(u.src)
        placed.add(u)
    else:
      # keep it in srcs
      new_srcs.append(u)
  if len(to_append) == 0 and len(placed) == 0: return None
  return UOp(x.op, dtypes.void, tuple(new_srcs),
             BasicBlock(tuple(dedup(sorted(new_ctx, key=lambda x: x.tuplize))), tuple(to_append)+x.arg.lst, x.arg.end))
|
||||
|
||||
# rewrite rules that merge blocks upward and collapse nested BLOCKFORKs
pm_block_merge = PatternMatcher([
  (UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), block_merge),
  # double BLOCKFORK multiplies the forking (like if there's 3 forks into 2 forks, that's 6 total forks)
  (UPat(Ops.BLOCKFORK, name="f", src=(UPat(Ops.BLOCKFORK, name="f2"),)), lambda f,f2: f.replace(src=f2.src, arg=f.arg*f2.arg)),
])
|
||||
|
||||
def block_finalize(block:UOp):
  """Fold the remaining source-less parents of the final BLOCK into its list and strip the trailing SINK.

  Returns a single srcless BLOCK containing every UOp in order, or None if block already has no srcs.
  """
  if len(block.src) == 0: return None
  _uops = sorted(dedup(block.src), key=lambda x: x.tuplize)
  # at this point only leaf, non-block ops may remain outside the block
  assert all(len(x.src) == 0 and x.op not in {Ops.BLOCK, Ops.BLOCKSTART, Ops.BLOCKEND, Ops.BLOCKFORK} for x in _uops)
  _uops += block.arg.lst
  # strip the SINK
  assert _uops[-1].op is Ops.SINK, "doesn't end with SINK"
  # BUGFIX: the SINK must actually be dropped (the assert above guarantees it is the last element)
  return UOp(Ops.BLOCK, arg=BasicBlock((), tuple(_uops[:-1])))
|
||||
|
||||
# final rewrite: collapse everything into one srcless BLOCK via block_finalize
pm_block_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="block"), block_finalize)])
|
||||
|
||||
# NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
|
||||
def block_reorder(in_block:UOp):
|
||||
in_this_block = set(in_block.arg.lst)
|
||||
local_children: collections.defaultdict[UOp, list[UOp]] = collections.defaultdict(list)
|
||||
in_degree: collections.defaultdict[UOp, int] = collections.defaultdict(int)
|
||||
def block_reorder(lst:list[UOp]) -> list[UOp]:
|
||||
in_this_block = set(lst)
|
||||
local_children: defaultdict[UOp, list[UOp]] = defaultdict(list)
|
||||
in_degree: defaultdict[UOp, int] = defaultdict(int)
|
||||
priorities:dict[UOp, int] = {}
|
||||
|
||||
# get local children and assign priorities
|
||||
for u in reversed(in_block.arg.lst):
|
||||
for u in reversed(lst):
|
||||
for s in u.src:
|
||||
if s in in_this_block:
|
||||
local_children[s].append(u)
|
||||
@@ -158,7 +31,7 @@ def block_reorder(in_block:UOp):
|
||||
def push(u:UOp): heapq.heappush(queue, (priorities[u], u.tuplize, u))
|
||||
|
||||
# place the first ones that don't have deps
|
||||
for u in in_block.arg.lst:
|
||||
for u in lst:
|
||||
if u not in in_degree: push(u)
|
||||
|
||||
newlst = []
|
||||
@@ -169,68 +42,184 @@ def block_reorder(in_block:UOp):
|
||||
in_degree[u] -= 1
|
||||
if in_degree[u] == 0: push(u)
|
||||
|
||||
assert len(newlst) == len(in_block.arg.lst), f"len mismatch {len(newlst)} != {len(in_block.arg.lst)}"
|
||||
return in_block.replace(arg=BasicBlock(in_block.arg.ctx, tuple(newlst)))
|
||||
assert len(newlst) == len(lst), f"len mismatch {len(newlst)} != {len(lst)}"
|
||||
return newlst
|
||||
|
||||
def upsettingly_promote_blockend(be:UOp):
  """Force every BLOCK src of this BLOCKEND to adopt the BLOCKEND's ctx. Returns None when nothing changes."""
  promoted: list[UOp] = []
  for b in be.src:
    # only BLOCK srcs get their ctx overwritten; everything else passes through untouched
    if b.op is Ops.BLOCK:
      promoted.append(b.replace(arg=BasicBlock(be.arg.ctx, b.arg.lst)))
    else:
      promoted.append(b)
  new_srcs = tuple(promoted)
  if new_srcs == be.src: return None
  return be.replace(src=new_srcs)
pm_force_upcast_block = PatternMatcher([(UPat(Ops.BLOCKEND, name="be"), upsettingly_promote_blockend)])
|
||||
# ***** basic block *****
|
||||
|
||||
def disp(y:UOp) -> str:
  """Short display label for a context UOp, used only in BasicBlock2.__repr__."""
  if y.op is Ops.IF: return f'IF{id(y)}'
  if y.op is Ops.RANGE: return str(y.arg)
  return "<NONE>"
|
||||
|
||||
@dataclass(frozen=True, eq=False)
class BasicBlock2:
  """Payload (arg) of a BLOCK/BLOCKEND UOp in the bottom-up linearizer."""
  # the ops of this block, in execution order
  lst: tuple[UOp, ...]
  # the RANGE/IF UOps this block executes inside
  ctx: tuple[UOp, ...]
  # for BLOCKEND only: the RANGE/IF this block closes
  end: UOp|None = None
  # how many times this block is consumed as a src (merge_block requires all consumers before merging)
  cnt: int = 0
  # the ctx this block's children see, when it differs from ctx (e.g. after RANGE/IF/STORE/ASSIGN)
  child_ctx: tuple[UOp, ...]|None = None
  # blocks are identity-compared (eq=False); ordering them is a bug
  def __lt__(self, _:BasicBlock2): raise RuntimeError("no comparing basic blocks")
  def __repr__(self):
    return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+f'f{self.cnt} '+\
           f"{[disp(y) for y in self.ctx]} {[disp(y) for y in self.child_ctx] if self.child_ctx is not None else '-'} "+\
           f"{len(self.lst)}" + "\n" + '\n'.join([str(x.op) for x in self.lst])
  # the ctx seen by children: child_ctx when set, otherwise this block's own ctx
  def last_ctx(self): return self.child_ctx if self.child_ctx is not None else self.ctx
|
||||
|
||||
def _sort_ctx(inp):
  """Dedup *inp* and return it as a tuple in deterministic (tuplize) order."""
  ordered = sorted(dedup(inp), key=lambda z: z.tuplize)
  return tuple(ordered)
|
||||
|
||||
# ***** block context *****
|
||||
|
||||
@dataclass
class BlockContext:
  """Per-UOp bookkeeping for block creation: child counts and block contexts, built in one toposort pass."""
  # number of consumers of each UOp (a parent appearing twice in a src counts twice)
  child_count: dict[UOp, int]
  # the sorted RANGE/IF context each UOp executes in
  block_ctxs: dict[UOp, tuple[UOp, ...]]
  # the context a UOp passes to its children, when it differs from its own (RANGE/IF/STORE/ASSIGN)
  child_ctxs: dict[UOp, tuple[UOp, ...]]
  def last_ctx(self, u): return ret if (ret:=self.child_ctxs.get(u)) is not None else self.block_ctxs[u]
  @staticmethod
  def from_sink(sink:UOp) -> BlockContext:
    """Build the BlockContext for every UOp reachable from sink."""
    # get children and all block contexts
    ctx = BlockContext({}, {}, {})
    for u in sink.toposort:
      this_block_ctx: list[UOp] = []
      ctx.child_count[u] = 0

      # get children and accumulate the last_ctx
      for s in u.src:
        # NOTE: if a parent appears multiple times in the src, it counts multiple times as a child
        ctx.child_count[s] += 1
        this_block_ctx += ctx.last_ctx(s)

      # save the block ctx
      ctx.block_ctxs[u] = _sort_ctx(this_block_ctx)

      # RANGE/IF add to the next ctx
      # STORE/ASSIGN subtract from the next ctx
      if u.op in {Ops.RANGE, Ops.IF}: ctx.child_ctxs[u] = _sort_ctx(ctx.block_ctxs[u] + (u,))
      elif u.op is Ops.STORE:
        # ugh, deal with non-reduce locals. probably wrong
        if isinstance(u.src[0].dtype, PtrDType) and u.src[0].dtype.local:
          idx_context, store_context = ctx.last_ctx(u.src[0]), ctx.last_ctx(u.src[1])
          ctx.child_ctxs[u] = tuple([y for y in store_context if y not in idx_context and y.op is Ops.RANGE])
        else: ctx.child_ctxs[u] = ()
      elif u.op is Ops.ASSIGN:
        assert u.src[0].op is Ops.DEFINE_ACC
        # flow through assign, but remove the ranges the accumulator is defined over
        ctx.child_ctxs[u] = tuple([y for y in ctx.last_ctx(u.src[1]) if y not in u.src[0].src[1:]])
    return ctx
|
||||
|
||||
# ***** make blocks *****
|
||||
|
||||
# ops that are never pulled into a BLOCK: top-level definitions and constants
DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}
|
||||
|
||||
def add_blockends(base_block:UOp, new_ctx:tuple[UOp, ...], current_ctx:tuple[UOp, ...]):
  """Wrap base_block in BLOCKENDs for every RANGE/IF in new_ctx that current_ctx does not contain.

  Ends are added innermost-first (popping from the back of the ctx), each BLOCKEND holding the
  matching ENDIF/ENDRANGE op. Returns the (possibly wrapped) block.
  """
  ends_to_add = [z for z in new_ctx if z not in current_ctx]
  while len(ends_to_add):
    r:UOp = ends_to_add.pop(-1)
    # each BLOCKEND's ctx no longer contains the range it closes
    new_ctx = tuple([z for z in new_ctx if z is not r])
    end_uop = UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,))
    base_block = UOp(Ops.BLOCKEND, src=(base_block,), arg=BasicBlock2((end_uop,), tuple(new_ctx), end=r, cnt=1))
  return base_block
|
||||
|
||||
def make_block_bottom_up(ctx:BlockContext, x:UOp):
  """Build the BLOCK containing x by walking its srcs and merging everything with the same context.

  Srcs with a different context become BLOCKSTART seeds (wrapped in BLOCKENDs as needed); srcs whose
  child count doesn't match stay as plain srcs ("unmergable"). Returns a BLOCK UOp whose arg holds
  the reordered op list.
  """
  if x.op is Ops.BLOCKSTART:
    # a BLOCKSTART seed carries (ctx, child_ctx) in its arg and the seed ops in its src
    current_ctx, child_ctx = x.arg
    lst = list(x.src)
    child_count = 1
  else:
    current_ctx, child_count, child_ctx = ctx.block_ctxs[x], ctx.child_count[x], ctx.child_ctxs.get(x, None)
    lst = [x]

  # count of times we've seen this block, or a seed for a new block if we can't merge it
  unmergable: defaultdict[UOp, int] = defaultdict(int)
  blockseeds = defaultdict(list)

  # add the srcs of this to the frontier
  # NOTE: things may be in here multiple times, that's okay
  frontier_nodes = list(flatten(y.src[::-1] for y in lst))
  while len(frontier_nodes):
    u = frontier_nodes.pop(0)
    if u.op not in DONT_PLACE_IN_BLOCK and ctx.child_count[u] == unmergable[u]+1:
      # count is correct
      if (newctx:=ctx.block_ctxs[u]) == current_ctx:
        # block has same context, merge it, and put the srcs on the frontier
        lst.append(u)
        frontier_nodes.extend(u.src[::-1])
      else:
        # block has different context, add it to blockseeds
        blockseeds[(newctx, ctx.child_ctxs.get(u, None))].append(u)
      del unmergable[u]
    else:
      # count is incorrect (or it's DONT_PLACE_IN_BLOCK), add it to unmergable
      unmergable[u] += 1

  # add unmergables to sources
  srcs = []
  for u,cnt in unmergable.items(): srcs += [add_blockends(u, ctx.block_ctxs[u], current_ctx)]*cnt

  # add blockseeds, with blockends as needed
  for (new_ctx, new_child_ctx), v in blockseeds.items():
    base_block = UOp(Ops.BLOCKSTART, src=tuple(v), arg=(new_ctx, new_child_ctx))
    srcs.append(add_blockends(base_block, new_ctx, current_ctx))

  # lst was built output-first; reverse before reordering for execution
  lst = block_reorder(lst[::-1])
  bb = BasicBlock2(tuple(lst), ctx=current_ctx, cnt=child_count, child_ctx=child_ctx)
  return UOp(Ops.BLOCK, src=tuple(srcs), arg=bb)
|
||||
|
||||
# bottom-up rewrite: every placeable op (not already a block) becomes/joins a BLOCK
block_create = PatternMatcher([(
  UPat(GroupOp.All-DONT_PLACE_IN_BLOCK.union({Ops.BLOCK, Ops.BLOCKEND}), name="x"), make_block_bottom_up)
])
|
||||
|
||||
# ***** block merging ****
|
||||
|
||||
def merge_block(x:UOp):
  """Merge into x every BLOCK src that is fully consumed here (all cnt uses are in x's srcs).

  A src BLOCK is mergeable into a BLOCK with identical ctx, or into a BLOCKEND whose end is in
  the src's ctx. Returns the merged UOp, or None if nothing is mergeable.
  """
  unmergable_blocks, mergable_blocks = [], []
  mergable_dict: defaultdict[UOp, int] = defaultdict(int)
  for y in x.src:
    if y.op is Ops.BLOCK and x.op is Ops.BLOCK and x.arg.ctx == y.arg.ctx: mergable_dict[y] += 1
    elif y.op is Ops.BLOCK and x.op is Ops.BLOCKEND and x.arg.end in y.arg.ctx: mergable_dict[y] += 1
    else: unmergable_blocks.append(y)
  for k,v in mergable_dict.items():
    # only merge when every consumer of the block is here, otherwise keep all v occurrences as srcs
    if v == k.arg.cnt: mergable_blocks.append(k)
    else: unmergable_blocks.extend([k]*v)
  if len(mergable_blocks) == 0: return None
  del mergable_dict

  # create the block
  arg = replace(x.arg, lst=tuple(flatten([y.arg.lst for y in mergable_blocks]))+x.arg.lst)
  return UOp(x.op, src=tuple(flatten([y.src for y in mergable_blocks])+unmergable_blocks), arg=arg)
|
||||
|
||||
def remove_blockend(x:UOp):
  """Close out a BLOCKEND: once no src block still needs x's range, splice in the parent block and drop the end.

  Returns the resulting BLOCK, or None if the BLOCKEND cannot be removed yet.
  """
  # if there's any remaining blocks that need to go in this BLOCKEND, we don't remove it
  if any(x.arg.end in y.arg.ctx for y in x.src if y.op in {Ops.BLOCK, Ops.BLOCKEND}): return None

  # the parent block is the one that opened this range (the range is in its child_ctx)
  parent_blocks = [y for y in x.src if y.op is Ops.BLOCK and y.arg.child_ctx is not None and x.arg.end in y.arg.child_ctx]
  assert all_same(parent_blocks), f"should never have two parent blocks (has {len(parent_blocks)})"
  if len(parent_blocks) > 0:
    parent_block = parent_blocks[0]
    assert len(parent_blocks) == parent_block.arg.cnt
    # range needs DEFINE_ACC to be before the range (never in DEFINE_ACC for if)
    early_ops, late_ops = partition(x.arg.lst, lambda y: y.op is Ops.DEFINE_ACC and x.arg.end in y.src)
    # NOTE: we have to add a barrier at the start if barrier is used in the range
    if x.op is Ops.BLOCKEND and any(y.op is Ops.BARRIER for y in late_ops) and late_ops[-1].op is Ops.ENDRANGE:
      late_ops = [UOp(Ops.BARRIER)] + late_ops
    arg = BasicBlock2(tuple(early_ops)+parent_block.arg.lst+tuple(late_ops), tuple([y for y in x.arg.ctx if y is not x.arg.end]), cnt=x.arg.cnt)
    return UOp(Ops.BLOCK, src=tuple(y for y in x.src if y is not parent_block)+parent_block.src, arg=arg)
|
||||
|
||||
# rewrite rules of the bottom-up linearizer: merge fully-consumed blocks, then retire BLOCKENDs
block_merge = PatternMatcher([
  (UPat((Ops.BLOCK, Ops.BLOCKEND), name="x"), merge_block),
  (UPat(Ops.BLOCKEND, name="x"), remove_blockend),
])
|
||||
|
||||
# ****** finalize ******
|
||||
|
||||
def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
|
||||
assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
|
||||
|
||||
# get children and all block contexts
|
||||
temp_block_ctxs: dict[UOp, list[UOp]] = {}
|
||||
children: dict[UOp, list[UOp]] = {}
|
||||
for u in sink.toposort:
|
||||
this_block_ctx: list[UOp] = []
|
||||
for s in u.src:
|
||||
# save children
|
||||
children.setdefault(s, []).append(u)
|
||||
# compute block ctx
|
||||
if s.op in {Ops.RANGE, Ops.IF}: this_block_ctx.append(s)
|
||||
# don't flow (fully) through assign and store
|
||||
elif s.op is Ops.STORE:
|
||||
idx_context, store_context = temp_block_ctxs[s.src[0]], temp_block_ctxs[s]
|
||||
this_block_ctx += [x for x in store_context if x not in idx_context and x.op is Ops.RANGE]
|
||||
elif s.op is Ops.ASSIGN:
|
||||
# flow though assign, but remove the ranges used in the assign
|
||||
assert s.src[0].op is Ops.DEFINE_ACC
|
||||
this_block_ctx += [x for x in temp_block_ctxs[s.src[1]] if x not in s.src[0].src[1:]]
|
||||
else:
|
||||
# flow though everything else
|
||||
this_block_ctx += temp_block_ctxs[s]
|
||||
temp_block_ctxs[u] = sorted(dedup(this_block_ctx), key=lambda x: x.tuplize)
|
||||
# get block context
|
||||
ctx = BlockContext.from_sink(sink)
|
||||
|
||||
# make final block_ctxs, add BLOCKSTART to block_ctxs for IF and RANGE
|
||||
block_ctxs: dict[UOp, tuple[UOp, ...]] = {}
|
||||
for u in sink.toposort:
|
||||
block_ctxs[u] = ((UOp(Ops.BLOCKSTART, src=(u,)),) + tuple(temp_block_ctxs[u])) if u.op in {Ops.IF, Ops.RANGE} else tuple(temp_block_ctxs[u])
|
||||
|
||||
# TODO: there's probably a clever way to remove this while loop
|
||||
while 1:
|
||||
sink = graph_rewrite(sink, make_basic_blocks, ctx=(block_ctxs, children))
|
||||
|
||||
# add BLOCKFORK (slow!)
|
||||
block_parent_count = collections.Counter(flatten([x.src for x in sink.toposort if x.op is Ops.BLOCK]))
|
||||
non_block_parents = set(flatten([x.src for x in sink.toposort if x.op is not Ops.BLOCK]))
|
||||
forks = {}
|
||||
for u,child_count in block_parent_count.items():
|
||||
if u.op not in DONT_PLACE_IN_BLOCK and child_count > 1 and u not in non_block_parents:
|
||||
# TODO: this is copied from append_to_block
|
||||
new_block = UOp(Ops.BLOCK, src=u.src, arg=BasicBlock(block_ctxs[u], (u,)))
|
||||
rng = block_ctxs[u]
|
||||
lrng = list(rng)
|
||||
for r in rng[::-1]:
|
||||
# if none of the children of u are in the same context, we need a BLOCKEND
|
||||
if all(r not in block_ctxs[c] for c in children[u]) and r.op is not Ops.BLOCKSTART:
|
||||
lrng.remove(r)
|
||||
new_block = UOp(Ops.BLOCKEND, src=(new_block,),
|
||||
arg=BasicBlock(tuple(lrng), (UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,)),), r))
|
||||
forks[u] = UOp(Ops.BLOCKFORK, src=(new_block,), arg=child_count)
|
||||
if not len(forks): break
|
||||
sink = sink.substitute(forks)
|
||||
# wrap all uops in blocks, already reordered
|
||||
sink = graph_rewrite(sink, block_create, ctx=ctx, name="Linearizer: Create Blocks", bottom_up=True)
|
||||
|
||||
# combine matching BLOCKENDS, the keys of this dictionary are the RANGE UOps, values are the BLOCKENDs
|
||||
blockends_to_arg: dict[UOp, list[UOp]] = {}
|
||||
@@ -240,26 +229,19 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
|
||||
for k,v in blockends_to_arg.items():
|
||||
# NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
|
||||
if len(v) > 1:
|
||||
out = UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCKEND, src=tuple(flatten(x.src for x in v)),
|
||||
arg=BasicBlock(tuple(dedup(flatten([y.arg.ctx for y in v]))), v[0].arg.lst, k)),), arg=len(v))
|
||||
bb = BasicBlock2(v[0].arg.lst, _sort_ctx(flatten([y.arg.ctx for y in v])), k, cnt=len(v))
|
||||
out = UOp(Ops.BLOCKEND, src=tuple(flatten([x.src for x in v])), arg=bb)
|
||||
for u in v: new_forks[u] = out
|
||||
sink = sink.substitute(new_forks)
|
||||
|
||||
# reorder ops in block for speed
|
||||
sink = sink.substitute({u:newu for u in sink.toposort if u.op is Ops.BLOCK and (newu:=block_reorder(u)) is not u})
|
||||
# merge blockends
|
||||
sink = graph_rewrite(sink, block_merge, name="Linearizer: Merge Blocks")
|
||||
assert sink.op is Ops.BLOCK and all(x.op in DONT_PLACE_IN_BLOCK for x in sink.src)
|
||||
|
||||
# final rewrite to merge all blocks into one
|
||||
sink = graph_rewrite(sink, pm_block_merge, ctx=children)
|
||||
|
||||
# if there's BLOCKENDs left in the graph, we might have to merge. TODO: is there a better way to handle this?
|
||||
while (newsink := graph_rewrite(sink, pm_force_upcast_block)) is not sink:
|
||||
sink = graph_rewrite(newsink, pm_block_merge, ctx=children, name="bad_merge")
|
||||
|
||||
# there should just be one block left, with a few parents with 0 srcs (now done in a rewriter)
|
||||
sink = graph_rewrite(sink, pm_block_finalize)
|
||||
# place the early things
|
||||
lst = sorted(dedup(sink.src), key=lambda x: x.tuplize) + list(sink.arg.lst)
|
||||
|
||||
# sanity checks (NOTE: these can cause things to be skipped in BEAM)
|
||||
if not skip_check: type_verify(sink.arg.lst)
|
||||
if not skip_check: type_verify(lst)
|
||||
|
||||
# return the list. TODO: refactor to return the UOp
|
||||
return list(sink.arg.lst)
|
||||
return lst
|
||||
|
||||
Reference in New Issue
Block a user