mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
local reordering in block (#8029)
* local reordering in block * load (and parents) is highest priority * minor loads in order * comments * explicit depth * simpler * matters less, but store early too
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
import collections
|
||||
from typing import List, Dict, Tuple, Optional, DefaultDict
|
||||
import collections, heapq
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.ops import type_verify, UOp, Ops, PatternMatcher, UPat, graph_rewrite, GroupOp
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
@@ -110,6 +110,51 @@ def block_merge(ctx, x:UOp):
|
||||
|
||||
# rewrite rule: every Ops.BLOCKEND / Ops.BLOCK uop is matched (bound as "x") and handed to block_merge
pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), block_merge),])
|
||||
|
||||
# NOTE: any toposort should be valid here; unlike last time, this isn't required, it's just for speed
|
||||
def block_reorder(ctx, in_block:UOp):
  """Re-sort the uops inside one block so that loads (and whatever feeds them) are placed first.

  ctx is used as a "seen" map: a block already present in ctx is skipped (returns None),
  otherwise it is recorded in ctx before being processed.

  The new order is still a topological order of the block-local dependency graph: a uop is
  only made available for placement once every one of its sources that lives in this block
  has been placed. Among available uops the one with the smallest (priority, tuplize) wins.

  Returns in_block with its arg replaced by a BasicBlock holding the same ctx and the new order.
  """
  # each block is handled a single time
  if in_block in ctx: return None
  ctx[in_block] = None

  block_ops = in_block.arg.lst
  members = set(block_ops)

  # block-local edges: source -> uops in this block that consume it, plus unmet-source counts
  local_children: DefaultDict[UOp, List[UOp]] = collections.defaultdict(list)
  in_degree: DefaultDict[UOp, int] = collections.defaultdict(int)
  for consumer in block_ops:
    for src in consumer.src:
      if src not in members: continue
      local_children[src].append(consumer)
      in_degree[consumer] += 1

  # priorities are filled back to front, so every local consumer is already assigned when read
  priorities:Dict[UOp, int] = {}
  for op in reversed(block_ops):
    # loads go to the beginning of the block
    own = -1000 if op.op is Ops.LOAD else 0
    # a uop is never ranked behind something that depends on it (no priority inversion)
    priorities[op] = min([own, *(priorities[child] for child in local_children[op])])

  # min-heap of uops whose in-block sources have all been placed
  queue:List[Tuple[int, Tuple, UOp]] = []
  def enqueue(u:UOp) -> None: heapq.heappush(queue, (priorities[u], u.tuplize, u))

  # seed with everything that has no in-block source
  for op in block_ops:
    if op not in in_degree: enqueue(op)

  newlst = []
  while queue:
    placed = heapq.heappop(queue)[-1]
    newlst.append(placed)
    # placing a uop may release its local consumers
    for child in local_children[placed]:
      in_degree[child] -= 1
      if not in_degree[child]: enqueue(child)

  assert len(newlst) == len(in_block.arg.lst), f"len mismatch {len(newlst)} != {len(in_block.arg.lst)}"
  return in_block.replace(arg=BasicBlock(in_block.arg.ctx, tuple(newlst)))
|
||||
|
||||
# rewrite rule: every Ops.BLOCK uop is matched (bound as "in_block") and handed to block_reorder
pm_block_reorder = PatternMatcher([(UPat(Ops.BLOCK, name="in_block"), block_reorder),])
|
||||
|
||||
def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
|
||||
assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
|
||||
|
||||
@@ -169,6 +214,9 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
|
||||
for u in v: new_forks[u] = out
|
||||
sink = sink.substitute(new_forks)
|
||||
|
||||
# reorder ops in block for speed
|
||||
sink = graph_rewrite(sink, pm_block_reorder, ctx={})
|
||||
|
||||
# final rewrite to merge all blocks into one
|
||||
sink = graph_rewrite(sink, pm_block_merge, ctx=children)
|
||||
|
||||
|
||||
@@ -127,8 +127,9 @@ class Ops(FastEnum):
|
||||
# UnaryOps
|
||||
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
|
||||
|
||||
# loads before math
|
||||
# load/store before math
|
||||
LOAD = auto()
|
||||
STORE = auto()
|
||||
|
||||
# early INDEX
|
||||
INDEX = auto()
|
||||
@@ -144,7 +145,6 @@ class Ops(FastEnum):
|
||||
WHERE = auto(); MULACC = auto() # noqa: E702
|
||||
|
||||
# assignment ops
|
||||
STORE = auto()
|
||||
ASSIGN = auto()
|
||||
BIND = auto()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user