mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
simpler block reorder function [pr] (#8031)
* simpler block reorder function [pr]
* simpler
* block_reorder in substitute, so wasteful otherwise
* extend and count
* leave push logic for same order
* sort new ctx
* less loop
* Revert "less loop"
This reverts commit 30249d097a.
This commit is contained in:
@@ -26,11 +26,14 @@ class BasicBlock:
|
||||
|
||||
def append_to_block(ctx:Tuple[Dict[UOp, Tuple[UOp, ...]], Dict[UOp, List[UOp]]], x:UOp):
|
||||
block_ctxs, children = ctx
|
||||
in_this_block = set(x.arg.lst)
|
||||
|
||||
# collections to build
|
||||
new_srcs: List[UOp] = []
|
||||
to_append: List[UOp] = []
|
||||
old_blocks: Dict[Tuple[UOp, ...], UOp] = {}
|
||||
new_blocks: Dict[Tuple[UOp, ...], List[UOp]] = {}
|
||||
in_this_block = set(x.arg.lst)
|
||||
|
||||
for u in x.src:
|
||||
if u.op is Ops.BLOCK:
|
||||
# merge sibling blocks. NOTE: blocks must only have one output source
|
||||
@@ -40,7 +43,7 @@ def append_to_block(ctx:Tuple[Dict[UOp, Tuple[UOp, ...]], Dict[UOp, List[UOp]]],
|
||||
# if it can go in blocks and all its children are in the block, we add it to the block
|
||||
if (block_ctx:=block_ctxs[u]) == x.arg.ctx:
|
||||
# if it's the same context, we place the UOp in this block and append the parents to its srcs
|
||||
new_srcs += list(u.src)
|
||||
new_srcs.extend(u.src)
|
||||
to_append.append(u)
|
||||
else:
|
||||
# if it's a different context, we create a new block with this UOp
|
||||
@@ -52,11 +55,10 @@ def append_to_block(ctx:Tuple[Dict[UOp, Tuple[UOp, ...]], Dict[UOp, List[UOp]]],
|
||||
|
||||
for rng,lst in new_blocks.items():
|
||||
srcs = flatten(y.src for y in lst)
|
||||
if (old_block:=old_blocks.get(rng, None)) is not None:
|
||||
if (old_block:=old_blocks.pop(rng, None)) is not None:
|
||||
# NOTE: order shouldn't matter here
|
||||
srcs += list(old_block.src)
|
||||
lst += list(old_block.arg.lst)
|
||||
del old_blocks[rng]
|
||||
srcs.extend(old_block.src)
|
||||
lst.extend(old_block.arg.lst)
|
||||
new_block = UOp(Ops.BLOCK, dtypes.void, tuple(dedup(srcs)), BasicBlock(rng, tuple(lst)))
|
||||
lrng = list(rng)
|
||||
for r in rng[::-1]:
|
||||
@@ -80,13 +82,13 @@ def block_merge(ctx, x:UOp):
|
||||
if len([y for y in ctx[x.arg.end] if y not in in_this_block]) == 0:
|
||||
# find the parent block that has the BLOCKSTART in the ctx
|
||||
parent_blocks = [y for y in x.src if y.op is Ops.BLOCK and UOp(Ops.BLOCKSTART, src=(x.arg.end,)) in y.arg.ctx]
|
||||
assert len(parent_blocks) <= 1, "should never have two parent blocks"
|
||||
if len(parent_blocks) == 1:
|
||||
parent_block = parent_blocks[0]
|
||||
# range needs DEFINE_ACC to be before the range (never in DEFINE_ACC for if)
|
||||
early_ops, late_ops = partition(x.arg.lst, lambda y: y.op is Ops.DEFINE_ACC and x.arg.end in y.src)
|
||||
return UOp(Ops.BLOCK, dtypes.void, tuple(y for y in x.src if y is not parent_block)+parent_block.src,
|
||||
BasicBlock(tuple(y for y in x.arg.ctx if y is not x.arg.end), tuple(early_ops)+parent_block.arg.lst+tuple(late_ops)))
|
||||
assert not len(parent_blocks)
|
||||
|
||||
new_srcs: List[UOp] = []
|
||||
to_append: List[UOp] = []
|
||||
@@ -95,45 +97,36 @@ def block_merge(ctx, x:UOp):
|
||||
for u in x.src:
|
||||
if u.op is Ops.BLOCK and (tuple(u.arg.ctx) == tuple(x.arg.ctx) or (x.arg.end is not None and x.arg.end in u.arg.ctx)):
|
||||
# NOTE: this can't appear in srcs twice or it would be a BLOCKFORK
|
||||
new_ctx += u.arg.ctx
|
||||
new_srcs += list(u.src)
|
||||
to_append += u.arg.lst
|
||||
elif u.op is Ops.BLOCKFORK and len([y for y in x.src if y is u]) == u.arg: # block fork appears # of times in srcs
|
||||
new_ctx += tuple(y for y in u.arg.ctx if y not in x.arg.ctx)
|
||||
new_srcs.extend(u.src)
|
||||
to_append.extend(u.arg.lst)
|
||||
elif u.op is Ops.BLOCKFORK and x.src.count(u) == u.arg: # block fork appears # of times in srcs
|
||||
if u not in placed:
|
||||
new_srcs += list(u.src)
|
||||
new_srcs.extend(u.src)
|
||||
placed.add(u)
|
||||
else:
|
||||
# keep it in srcs
|
||||
new_srcs.append(u)
|
||||
if len(to_append) == 0 and len(placed) == 0: return None
|
||||
return UOp(x.op, dtypes.void, tuple(new_srcs), BasicBlock(tuple(dedup(new_ctx)), tuple(to_append)+x.arg.lst, x.arg.end))
|
||||
return UOp(x.op, dtypes.void, tuple(new_srcs), BasicBlock(tuple(sorted(new_ctx, key=lambda x: x.tuplize)), tuple(to_append)+x.arg.lst, x.arg.end))
|
||||
|
||||
pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), block_merge),])
|
||||
|
||||
# NOTE: any toposort should be valid here; unlike last time, this isn't required, it's just for speed
|
||||
def block_reorder(ctx, in_block:UOp):
|
||||
# only visit each block once
|
||||
if in_block in ctx: return None
|
||||
ctx[in_block] = None
|
||||
|
||||
# get local children
|
||||
def block_reorder(in_block:UOp):
|
||||
in_this_block = set(in_block.arg.lst)
|
||||
local_children: DefaultDict[UOp, List[UOp]] = collections.defaultdict(list)
|
||||
in_degree: DefaultDict[UOp, int] = collections.defaultdict(int)
|
||||
for u in in_block.arg.lst:
|
||||
priorities:Dict[UOp, int] = {}
|
||||
|
||||
# get local children and assign priorities
|
||||
for u in reversed(in_block.arg.lst):
|
||||
for s in u.src:
|
||||
if s in in_this_block:
|
||||
local_children[s].append(u)
|
||||
in_degree[u] += 1
|
||||
|
||||
# assign priorities
|
||||
priorities:Dict[UOp, int] = {}
|
||||
def get_priority(u:UOp):
|
||||
# put loads in the beginning of the block
|
||||
priority = -1000 if u.op is Ops.LOAD else 0
|
||||
# prevent priority inversion
|
||||
return min([priority] + [priorities[x] for x in local_children[u]])
|
||||
for u in in_block.arg.lst[::-1]: priorities[u] = get_priority(u)
|
||||
# put loads in the beginning of the block and prevent priority inversion
|
||||
priorities[u] = min([-1000 if u.op is Ops.LOAD else 0] + [priorities[x] for x in local_children[u]])
|
||||
|
||||
# placement queue
|
||||
queue:List[Tuple[int, Tuple, UOp]] = []
|
||||
@@ -150,11 +143,10 @@ def block_reorder(ctx, in_block:UOp):
|
||||
for u in local_children[x]:
|
||||
in_degree[u] -= 1
|
||||
if in_degree[u] == 0: push(u)
|
||||
|
||||
assert len(newlst) == len(in_block.arg.lst), f"len mismatch {len(newlst)} != {len(in_block.arg.lst)}"
|
||||
return in_block.replace(arg=BasicBlock(in_block.arg.ctx, tuple(newlst)))
|
||||
|
||||
pm_block_reorder = PatternMatcher([(UPat(Ops.BLOCK, name="in_block"), block_reorder),])
|
||||
|
||||
def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
|
||||
assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
|
||||
|
||||
@@ -181,7 +173,7 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
|
||||
else:
|
||||
# flow through everything else
|
||||
this_block_ctx += temp_block_ctxs[s]
|
||||
temp_block_ctxs[u] = dedup(sorted(this_block_ctx, key=lambda x: x.tuplize))
|
||||
temp_block_ctxs[u] = sorted(dedup(this_block_ctx), key=lambda x: x.tuplize)
|
||||
|
||||
# make final block_ctxs, add BLOCKSTART to block_ctxs for IF and RANGE
|
||||
block_ctxs: Dict[UOp, Tuple[UOp, ...]] = {}
|
||||
@@ -215,7 +207,7 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
|
||||
sink = sink.substitute(new_forks)
|
||||
|
||||
# reorder ops in block for speed
|
||||
sink = graph_rewrite(sink, pm_block_reorder, ctx={})
|
||||
sink = sink.substitute({u:newu for u in sink.toposort if u.op is Ops.BLOCK and (newu:=block_reorder(u)) is not u})
|
||||
|
||||
# final rewrite to merge all blocks into one
|
||||
sink = graph_rewrite(sink, pm_block_merge, ctx=children)
|
||||
|
||||
Reference in New Issue
Block a user