From 8f65c1fafbf55191d45ca22cdd406cf7b1d4ff63 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 4 Dec 2024 17:57:35 +0800 Subject: [PATCH] simpler block reorder function [pr] (#8031) * simpler block reorder function [pr] * simpler * block_reorder in substitute, so wasteful otherwise * extend and count * leave push logic for same order * sort new ctx * less loop * Revert "less loop" This reverts commit 30249d097a8affba07a15c0ba1d6b56c9eea3006. --- tinygrad/codegen/linearize.py | 58 +++++++++++++++-------------------- 1 file changed, 25 insertions(+), 33 deletions(-) diff --git a/tinygrad/codegen/linearize.py b/tinygrad/codegen/linearize.py index 6609ca8d37..2fd2d1f25c 100644 --- a/tinygrad/codegen/linearize.py +++ b/tinygrad/codegen/linearize.py @@ -26,11 +26,14 @@ class BasicBlock: def append_to_block(ctx:Tuple[Dict[UOp, Tuple[UOp, ...]], Dict[UOp, List[UOp]]], x:UOp): block_ctxs, children = ctx + in_this_block = set(x.arg.lst) + + # collections to build new_srcs: List[UOp] = [] to_append: List[UOp] = [] old_blocks: Dict[Tuple[UOp, ...], UOp] = {} new_blocks: Dict[Tuple[UOp, ...], List[UOp]] = {} - in_this_block = set(x.arg.lst) + for u in x.src: if u.op is Ops.BLOCK: # merge sibling blocks. NOTE: blocks must only have one output source @@ -40,7 +43,7 @@ def append_to_block(ctx:Tuple[Dict[UOp, Tuple[UOp, ...]], Dict[UOp, List[UOp]]], # if it can go in blocks and all its children are in the block, we add it to the block if (block_ctx:=block_ctxs[u]) == x.arg.ctx: # if it's the same context, we place the UOp in this block and append the parents to its srcs - new_srcs += list(u.src) + new_srcs.extend(u.src) to_append.append(u) else: # if it's a different context, we create a new block with this UOp @@ -52,11 +55,10 @@ def append_to_block(ctx:Tuple[Dict[UOp, Tuple[UOp, ...]], Dict[UOp, List[UOp]]], for rng,lst in new_blocks.items(): srcs = flatten(y.src for y in lst) - if (old_block:=old_blocks.get(rng, None)) is not None: + if (old_block:=old_blocks.pop(rng, None)) is not None: # NOTE: order shouldn't matter here - srcs += list(old_block.src) - lst += list(old_block.arg.lst) - del old_blocks[rng] + srcs.extend(old_block.src) + lst.extend(old_block.arg.lst) new_block = UOp(Ops.BLOCK, dtypes.void, tuple(dedup(srcs)), BasicBlock(rng, tuple(lst))) lrng = list(rng) for r in rng[::-1]: @@ -80,13 +82,13 @@ def block_merge(ctx, x:UOp): if len([y for y in ctx[x.arg.end] if y not in in_this_block]) == 0: # find the parent block that has the BLOCKSTART in the ctx parent_blocks = [y for y in x.src if y.op is Ops.BLOCK and UOp(Ops.BLOCKSTART, src=(x.arg.end,)) in y.arg.ctx] + assert len(parent_blocks) <= 1, "should never have two parent blocks" if len(parent_blocks) == 1: parent_block = parent_blocks[0] # range needs DEFINE_ACC to be before the range (never in DEFINE_ACC for if) early_ops, late_ops = partition(x.arg.lst, lambda y: y.op is Ops.DEFINE_ACC and x.arg.end in y.src) return UOp(Ops.BLOCK, dtypes.void, tuple(y for y in x.src if y is not parent_block)+parent_block.src, BasicBlock(tuple(y for y in x.arg.ctx if y is not x.arg.end), tuple(early_ops)+parent_block.arg.lst+tuple(late_ops))) - assert not len(parent_blocks) new_srcs: List[UOp] = [] to_append: List[UOp] = [] @@ -95,45 +97,36 @@ def block_merge(ctx, x:UOp): for u in x.src: if u.op is Ops.BLOCK and (tuple(u.arg.ctx) == tuple(x.arg.ctx) or (x.arg.end is not None and x.arg.end in u.arg.ctx)): # NOTE: this can't appear in srcs twice or it would be a BLOCKFORK - new_ctx += u.arg.ctx - new_srcs += list(u.src) - to_append += u.arg.lst - elif u.op is Ops.BLOCKFORK and len([y for y in x.src if y is u]) == u.arg: # block fork appears # of times in srcs + new_ctx += tuple(y for y in u.arg.ctx if y not in x.arg.ctx) + new_srcs.extend(u.src) + to_append.extend(u.arg.lst) + elif u.op is Ops.BLOCKFORK and x.src.count(u) == u.arg: # block fork appears # of times in srcs if u not in placed: - new_srcs += list(u.src) + new_srcs.extend(u.src) placed.add(u) else: # keep it in srcs new_srcs.append(u) if len(to_append) == 0 and len(placed) == 0: return None - return UOp(x.op, dtypes.void, tuple(new_srcs), BasicBlock(tuple(dedup(new_ctx)), tuple(to_append)+x.arg.lst, x.arg.end)) + return UOp(x.op, dtypes.void, tuple(new_srcs), BasicBlock(tuple(sorted(new_ctx, key=lambda x: x.tuplize)), tuple(to_append)+x.arg.lst, x.arg.end)) pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), block_merge),]) # NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed -def block_reorder(ctx, in_block:UOp): - # only visit each block once - if in_block in ctx: return None - ctx[in_block] = None - - # get local children +def block_reorder(in_block:UOp): in_this_block = set(in_block.arg.lst) local_children: DefaultDict[UOp, List[UOp]] = collections.defaultdict(list) in_degree: DefaultDict[UOp, int] = collections.defaultdict(int) - for u in in_block.arg.lst: + priorities:Dict[UOp, int] = {} + + # get local children and assign priorities + for u in reversed(in_block.arg.lst): for s in u.src: if s in in_this_block: local_children[s].append(u) in_degree[u] += 1 - - # assign priorities - priorities:Dict[UOp, int] = {} - def get_priority(u:UOp): - # put loads in the beginning of the block - priority = -1000 if u.op is Ops.LOAD else 0 - # prevent priority inversion - return min([priority] + [priorities[x] for x in local_children[u]]) - for u in in_block.arg.lst[::-1]: priorities[u] = get_priority(u) + # put loads in the beginning of the block and prevent priority inversion + priorities[u] = min([-1000 if u.op is Ops.LOAD else 0] + [priorities[x] for x in local_children[u]]) # placement queue queue:List[Tuple[int, Tuple, UOp]] = [] @@ -150,11 +143,10 @@ def block_reorder(ctx, in_block:UOp): for u in local_children[x]: in_degree[u] -= 1 if in_degree[u] == 0: push(u) + assert len(newlst) == len(in_block.arg.lst), f"len mismatch {len(newlst)} != {len(in_block.arg.lst)}" return in_block.replace(arg=BasicBlock(in_block.arg.ctx, tuple(newlst))) -pm_block_reorder = PatternMatcher([(UPat(Ops.BLOCK, name="in_block"), block_reorder),]) - def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]: assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}" @@ -181,7 +173,7 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]: else: # flow though everything else this_block_ctx += temp_block_ctxs[s] - temp_block_ctxs[u] = dedup(sorted(this_block_ctx, key=lambda x: x.tuplize)) + temp_block_ctxs[u] = sorted(dedup(this_block_ctx), key=lambda x: x.tuplize) # make final block_ctxs, add BLOCKSTART to block_ctxs for IF and RANGE block_ctxs: Dict[UOp, Tuple[UOp, ...]] = {} @@ -215,7 +207,7 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]: sink = sink.substitute(new_forks) # reorder ops in block for speed - sink = graph_rewrite(sink, pm_block_reorder, ctx={}) + sink = sink.substitute({u:newu for u in sink.toposort if u.op is Ops.BLOCK and (newu:=block_reorder(u)) is not u}) # final rewrite to merge all blocks into one sink = graph_rewrite(sink, pm_block_merge, ctx=children)