local reordering in block (#8029)

* local reordering in block

* load (and parents) is highest priority

* minor loads in order

* comments

* explicit depth

* simpler

* matters less, but store early too
This commit is contained in:
George Hotz
2024-12-04 15:11:29 +08:00
committed by GitHub
parent 4cb630ac1c
commit bb98bae751
2 changed files with 52 additions and 4 deletions

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import List, Dict, Tuple, Optional
import collections
from typing import List, Dict, Tuple, Optional, DefaultDict
import collections, heapq
from dataclasses import dataclass
from tinygrad.ops import type_verify, UOp, Ops, PatternMatcher, UPat, graph_rewrite, GroupOp
from tinygrad.dtype import dtypes, PtrDType
@@ -110,6 +110,51 @@ def block_merge(ctx, x:UOp):
pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), block_merge),])
# NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
def block_reorder(ctx, in_block:UOp):
# only visit each block once
if in_block in ctx: return None
ctx[in_block] = None
# get local children
in_this_block = set(in_block.arg.lst)
local_children: DefaultDict[UOp, List[UOp]] = collections.defaultdict(list)
in_degree: DefaultDict[UOp, int] = collections.defaultdict(int)
for u in in_block.arg.lst:
for s in u.src:
if s in in_this_block:
local_children[s].append(u)
in_degree[u] += 1
# assign priorities
priorities:Dict[UOp, int] = {}
def get_priority(u:UOp):
# put loads in the beginning of the block
priority = -1000 if u.op is Ops.LOAD else 0
# prevent priority inversion
return min([priority] + [priorities[x] for x in local_children[u]])
for u in in_block.arg.lst[::-1]: priorities[u] = get_priority(u)
# placement queue
queue:List[Tuple[int, Tuple, UOp]] = []
def push(u:UOp): heapq.heappush(queue, (priorities[u], u.tuplize, u))
# place the first ones that don't have deps
for u in in_block.arg.lst:
if u not in in_degree: push(u)
newlst = []
while queue:
_,_,x = heapq.heappop(queue)
newlst.append(x)
for u in local_children[x]:
in_degree[u] -= 1
if in_degree[u] == 0: push(u)
assert len(newlst) == len(in_block.arg.lst), f"len mismatch {len(newlst)} != {len(in_block.arg.lst)}"
return in_block.replace(arg=BasicBlock(in_block.arg.ctx, tuple(newlst)))
pm_block_reorder = PatternMatcher([(UPat(Ops.BLOCK, name="in_block"), block_reorder),])
def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
@@ -169,6 +214,9 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
for u in v: new_forks[u] = out
sink = sink.substitute(new_forks)
# reorder ops in block for speed
sink = graph_rewrite(sink, pm_block_reorder, ctx={})
# final rewrite to merge all blocks into one
sink = graph_rewrite(sink, pm_block_merge, ctx=children)

View File

@@ -127,8 +127,9 @@ class Ops(FastEnum):
# UnaryOps
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
# loads before math
# load/store before math
LOAD = auto()
STORE = auto()
# early INDEX
INDEX = auto()
@@ -144,7 +145,6 @@ class Ops(FastEnum):
WHERE = auto(); MULACC = auto() # noqa: E702
# assignment ops
STORE = auto()
ASSIGN = auto()
BIND = auto()