From bb98bae7510e5ffcc8e08365606175205509c018 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 4 Dec 2024 15:11:29 +0800 Subject: [PATCH] local reordering in block (#8029) * local reordering in block * load (and parents) is highest priority * minor loads in order * comments * explicit depth * simpler * matters less, but store early too --- tinygrad/codegen/linearize.py | 52 +++++++++++++++++++++++++++++++++-- tinygrad/ops.py | 4 +-- 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/tinygrad/codegen/linearize.py b/tinygrad/codegen/linearize.py index d3c73559c6..6609ca8d37 100644 --- a/tinygrad/codegen/linearize.py +++ b/tinygrad/codegen/linearize.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Dict, Tuple, Optional -import collections +from typing import List, Dict, Tuple, Optional, DefaultDict +import collections, heapq from dataclasses import dataclass from tinygrad.ops import type_verify, UOp, Ops, PatternMatcher, UPat, graph_rewrite, GroupOp from tinygrad.dtype import dtypes, PtrDType @@ -110,6 +110,51 @@ def block_merge(ctx, x:UOp): pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), block_merge),]) +# NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed +def block_reorder(ctx, in_block:UOp): + # only visit each block once + if in_block in ctx: return None + ctx[in_block] = None + + # get local children + in_this_block = set(in_block.arg.lst) + local_children: DefaultDict[UOp, List[UOp]] = collections.defaultdict(list) + in_degree: DefaultDict[UOp, int] = collections.defaultdict(int) + for u in in_block.arg.lst: + for s in u.src: + if s in in_this_block: + local_children[s].append(u) + in_degree[u] += 1 + + # assign priorities + priorities:Dict[UOp, int] = {} + def get_priority(u:UOp): + # put loads in the beginning of the block + priority = -1000 if u.op is Ops.LOAD else 0 + # prevent priority inversion + return min([priority] + [priorities[x] for x in local_children[u]]) + for u in in_block.arg.lst[::-1]: priorities[u] = get_priority(u) + + # placement queue + queue:List[Tuple[int, Tuple, UOp]] = [] + def push(u:UOp): heapq.heappush(queue, (priorities[u], u.tuplize, u)) + + # place the first ones that don't have deps + for u in in_block.arg.lst: + if u not in in_degree: push(u) + + newlst = [] + while queue: + _,_,x = heapq.heappop(queue) + newlst.append(x) + for u in local_children[x]: + in_degree[u] -= 1 + if in_degree[u] == 0: push(u) + assert len(newlst) == len(in_block.arg.lst), f"len mismatch {len(newlst)} != {len(in_block.arg.lst)}" + return in_block.replace(arg=BasicBlock(in_block.arg.ctx, tuple(newlst))) + +pm_block_reorder = PatternMatcher([(UPat(Ops.BLOCK, name="in_block"), block_reorder),]) + def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]: assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}" @@ -169,6 +214,9 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]: for u in v: new_forks[u] = out sink = sink.substitute(new_forks) + # reorder ops in block for speed + sink = graph_rewrite(sink, pm_block_reorder, ctx={}) + # final rewrite to merge all blocks into one sink = graph_rewrite(sink, pm_block_merge, ctx=children) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index da7862b36b..a2aaee2a63 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -127,8 +127,9 @@ class Ops(FastEnum): # UnaryOps CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702 - # loads before math + # load/store before math LOAD = auto() + STORE = auto() # early INDEX INDEX = auto() @@ -144,7 +145,6 @@ class Ops(FastEnum): WHERE = auto(); MULACC = auto() # noqa: E702 # assignment ops - STORE = auto() ASSIGN = auto() BIND = auto()