clean up rewrite logic + merge siblings (#8026)

* clean up rewrite logic [pr]

* simpler

* merge sibling blocks

* no PR
This commit is contained in:
George Hotz
2024-12-04 13:26:16 +08:00
committed by GitHub
parent 004b2ecff5
commit fdd1e56827
2 changed files with 29 additions and 17 deletions

View File

@@ -2,12 +2,11 @@ from __future__ import annotations
from typing import List, Dict, Tuple, Optional
import collections
from dataclasses import dataclass
from tinygrad.ops import type_verify, UOp, Ops, PatternMatcher, UPat, graph_rewrite
from tinygrad.ops import type_verify, UOp, Ops, PatternMatcher, UPat, graph_rewrite, GroupOp
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.helpers import dedup, flatten, partition
DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST,
Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART}
DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST, *GroupOp.Block}
def disp(y:UOp) -> str:
if y.op is Ops.BLOCKSTART: return "w"+disp(y.src[0])
@@ -29,32 +28,44 @@ def append_to_block(ctx:Tuple[Dict[UOp, Tuple[UOp, ...]], Dict[UOp, List[UOp]]],
block_ctxs, children = ctx
new_srcs: List[UOp] = []
to_append: List[UOp] = []
old_blocks: Dict[Tuple[UOp, ...], UOp] = {}
new_blocks: Dict[Tuple[UOp, ...], List[UOp]] = {}
bb: BasicBlock = x.arg
in_this_block = set(bb.lst)
in_this_block = set(x.arg.lst)
for u in x.src:
if u.op in DONT_PLACE_IN_BLOCK or len([y for y in children[u] if y not in in_this_block]) > 0:
# if it's a fork or not placed, we don't place it
new_srcs.append(u)
elif (block_ctx:=block_ctxs[u]) == bb.ctx:
# if it's the same context, we place the UOp in this block and append the parents to it's srcs
new_srcs += list(u.src)
to_append.append(u)
if u.op is Ops.BLOCK:
# merge sibling blocks. NOTE: blocks must only have one output source
assert u.arg.ctx not in old_blocks, "sibiling should never have been created"
old_blocks[u.arg.ctx] = u
elif u.op not in DONT_PLACE_IN_BLOCK and set(children[u]).issubset(in_this_block):
# if it can go in blocks and all its children are in the block, we add it to the block
if (block_ctx:=block_ctxs[u]) == x.arg.ctx:
# if it's the same context, we place the UOp in this block and append the parents to its srcs
new_srcs += list(u.src)
to_append.append(u)
else:
# if it's a different context, we create a new block with this UOp
new_blocks.setdefault(block_ctx, []).append(u)
else:
# otherwise, we create a new block with this UOp
new_blocks.setdefault(block_ctx, []).append(u)
# otherwise, we keep it in the srcs
new_srcs.append(u)
if len(to_append) == 0 and len(new_blocks) == 0: return None
for rng,lst in new_blocks.items():
new_block = UOp(Ops.BLOCK, dtypes.void, tuple(dedup(flatten(y.src for y in lst))), BasicBlock(rng, tuple(lst)))
srcs = flatten(y.src for y in lst)
if (old_block:=old_blocks.get(rng, None)) is not None:
# NOTE: order shouldn't matter here
srcs += list(old_block.src)
lst += list(old_block.arg.lst)
del old_blocks[rng]
new_block = UOp(Ops.BLOCK, dtypes.void, tuple(dedup(srcs)), BasicBlock(rng, tuple(lst)))
lrng = list(rng)
for r in rng[::-1]:
if r not in bb.ctx and r.op is not Ops.BLOCKSTART:
if r not in x.arg.ctx and r.op is not Ops.BLOCKSTART:
lrng.remove(r)
new_block = UOp(Ops.BLOCKEND, src=(new_block,),
arg=BasicBlock(tuple(lrng), (UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,)),), r))
new_srcs.append(new_block)
return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(new_srcs)), BasicBlock(bb.ctx, tuple(to_append)+bb.lst))
return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(list(old_blocks.values())+new_srcs)), BasicBlock(x.arg.ctx, tuple(to_append)+x.arg.lst))
make_basic_blocks = PatternMatcher([
(UPat(Ops.SINK, name="x"), lambda x: UOp(Ops.BLOCK, src=x.src, arg=BasicBlock((), (x,)))),

View File

@@ -173,6 +173,7 @@ class GroupOp:
# meta ops
Meta = {Ops.COPY, Ops.EMPTY, Ops.BUFFER_VIEW}
Buffer = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID}
Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART}
# BinaryOps that can be flipped
Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.XOR, Ops.AND, Ops.OR}