diff --git a/tinygrad/codegen/linearize.py b/tinygrad/codegen/linearize.py index ebe86ad77d..d3c73559c6 100644 --- a/tinygrad/codegen/linearize.py +++ b/tinygrad/codegen/linearize.py @@ -2,12 +2,11 @@ from __future__ import annotations from typing import List, Dict, Tuple, Optional import collections from dataclasses import dataclass -from tinygrad.ops import type_verify, UOp, Ops, PatternMatcher, UPat, graph_rewrite +from tinygrad.ops import type_verify, UOp, Ops, PatternMatcher, UPat, graph_rewrite, GroupOp from tinygrad.dtype import dtypes, PtrDType from tinygrad.helpers import dedup, flatten, partition -DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST, - Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART} +DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST, *GroupOp.Block} def disp(y:UOp) -> str: if y.op is Ops.BLOCKSTART: return "w"+disp(y.src[0]) @@ -29,32 +28,44 @@ def append_to_block(ctx:Tuple[Dict[UOp, Tuple[UOp, ...]], Dict[UOp, List[UOp]]], block_ctxs, children = ctx new_srcs: List[UOp] = [] to_append: List[UOp] = [] + old_blocks: Dict[Tuple[UOp, ...], UOp] = {} new_blocks: Dict[Tuple[UOp, ...], List[UOp]] = {} - bb: BasicBlock = x.arg - in_this_block = set(bb.lst) + in_this_block = set(x.arg.lst) for u in x.src: - if u.op in DONT_PLACE_IN_BLOCK or len([y for y in children[u] if y not in in_this_block]) > 0: - # if it's a fork or not placed, we don't place it - new_srcs.append(u) - elif (block_ctx:=block_ctxs[u]) == bb.ctx: - # if it's the same context, we place the UOp in this block and append the parents to it's srcs - new_srcs += list(u.src) - to_append.append(u) + if u.op is Ops.BLOCK: + # merge sibling blocks. NOTE: blocks must only have one output source + assert u.arg.ctx not in old_blocks, "sibiling should never have been created" + old_blocks[u.arg.ctx] = u + elif u.op not in DONT_PLACE_IN_BLOCK and set(children[u]).issubset(in_this_block): + # if it can go in blocks and all its children are in the block, we add it to the block + if (block_ctx:=block_ctxs[u]) == x.arg.ctx: + # if it's the same context, we place the UOp in this block and append the parents to its srcs + new_srcs += list(u.src) + to_append.append(u) + else: + # if it's a different context, we create a new block with this UOp + new_blocks.setdefault(block_ctx, []).append(u) else: - # otherwise, we create a new block with this UOp - new_blocks.setdefault(block_ctx, []).append(u) + # otherwise, we keep it in the srcs + new_srcs.append(u) if len(to_append) == 0 and len(new_blocks) == 0: return None for rng,lst in new_blocks.items(): - new_block = UOp(Ops.BLOCK, dtypes.void, tuple(dedup(flatten(y.src for y in lst))), BasicBlock(rng, tuple(lst))) + srcs = flatten(y.src for y in lst) + if (old_block:=old_blocks.get(rng, None)) is not None: + # NOTE: order shouldn't matter here + srcs += list(old_block.src) + lst += list(old_block.arg.lst) + del old_blocks[rng] + new_block = UOp(Ops.BLOCK, dtypes.void, tuple(dedup(srcs)), BasicBlock(rng, tuple(lst))) lrng = list(rng) for r in rng[::-1]: - if r not in bb.ctx and r.op is not Ops.BLOCKSTART: + if r not in x.arg.ctx and r.op is not Ops.BLOCKSTART: lrng.remove(r) new_block = UOp(Ops.BLOCKEND, src=(new_block,), arg=BasicBlock(tuple(lrng), (UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,)),), r)) new_srcs.append(new_block) - return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(new_srcs)), BasicBlock(bb.ctx, tuple(to_append)+bb.lst)) + return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(list(old_blocks.values())+new_srcs)), BasicBlock(x.arg.ctx, tuple(to_append)+x.arg.lst)) make_basic_blocks = PatternMatcher([ (UPat(Ops.SINK, name="x"), lambda x: UOp(Ops.BLOCK, src=x.src, arg=BasicBlock((), (x,)))), diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 8f00161d18..713563b021 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -173,6 +173,7 @@ class GroupOp: # meta ops Meta = {Ops.COPY, Ops.EMPTY, Ops.BUFFER_VIEW} Buffer = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID} + Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART} # BinaryOps that can be flipped Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.XOR, Ops.AND, Ops.OR}