mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
clean up rewrite logic + merge siblings (#8026)
* clean up rewrite logic [pr]
* simpler
* merge sibling blocks
* no PR
This commit is contained in:
@@ -2,12 +2,11 @@ from __future__ import annotations
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
import collections
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.ops import type_verify, UOp, Ops, PatternMatcher, UPat, graph_rewrite
|
||||
from tinygrad.ops import type_verify, UOp, Ops, PatternMatcher, UPat, graph_rewrite, GroupOp
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.helpers import dedup, flatten, partition
|
||||
|
||||
DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST,
|
||||
Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART}
|
||||
DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST, *GroupOp.Block}
|
||||
|
||||
def disp(y:UOp) -> str:
|
||||
if y.op is Ops.BLOCKSTART: return "w"+disp(y.src[0])
|
||||
@@ -29,32 +28,44 @@ def append_to_block(ctx:Tuple[Dict[UOp, Tuple[UOp, ...]], Dict[UOp, List[UOp]]],
|
||||
block_ctxs, children = ctx
|
||||
new_srcs: List[UOp] = []
|
||||
to_append: List[UOp] = []
|
||||
old_blocks: Dict[Tuple[UOp, ...], UOp] = {}
|
||||
new_blocks: Dict[Tuple[UOp, ...], List[UOp]] = {}
|
||||
bb: BasicBlock = x.arg
|
||||
in_this_block = set(bb.lst)
|
||||
in_this_block = set(x.arg.lst)
|
||||
for u in x.src:
|
||||
if u.op in DONT_PLACE_IN_BLOCK or len([y for y in children[u] if y not in in_this_block]) > 0:
|
||||
# if it's a fork or not placed, we don't place it
|
||||
new_srcs.append(u)
|
||||
elif (block_ctx:=block_ctxs[u]) == bb.ctx:
|
||||
# if it's the same context, we place the UOp in this block and append the parents to it's srcs
|
||||
new_srcs += list(u.src)
|
||||
to_append.append(u)
|
||||
if u.op is Ops.BLOCK:
|
||||
# merge sibling blocks. NOTE: blocks must only have one output source
|
||||
assert u.arg.ctx not in old_blocks, "sibiling should never have been created"
|
||||
old_blocks[u.arg.ctx] = u
|
||||
elif u.op not in DONT_PLACE_IN_BLOCK and set(children[u]).issubset(in_this_block):
|
||||
# if it can go in blocks and all its children are in the block, we add it to the block
|
||||
if (block_ctx:=block_ctxs[u]) == x.arg.ctx:
|
||||
# if it's the same context, we place the UOp in this block and append the parents to its srcs
|
||||
new_srcs += list(u.src)
|
||||
to_append.append(u)
|
||||
else:
|
||||
# if it's a different context, we create a new block with this UOp
|
||||
new_blocks.setdefault(block_ctx, []).append(u)
|
||||
else:
|
||||
# otherwise, we create a new block with this UOp
|
||||
new_blocks.setdefault(block_ctx, []).append(u)
|
||||
# otherwise, we keep it in the srcs
|
||||
new_srcs.append(u)
|
||||
if len(to_append) == 0 and len(new_blocks) == 0: return None
|
||||
|
||||
for rng,lst in new_blocks.items():
|
||||
new_block = UOp(Ops.BLOCK, dtypes.void, tuple(dedup(flatten(y.src for y in lst))), BasicBlock(rng, tuple(lst)))
|
||||
srcs = flatten(y.src for y in lst)
|
||||
if (old_block:=old_blocks.get(rng, None)) is not None:
|
||||
# NOTE: order shouldn't matter here
|
||||
srcs += list(old_block.src)
|
||||
lst += list(old_block.arg.lst)
|
||||
del old_blocks[rng]
|
||||
new_block = UOp(Ops.BLOCK, dtypes.void, tuple(dedup(srcs)), BasicBlock(rng, tuple(lst)))
|
||||
lrng = list(rng)
|
||||
for r in rng[::-1]:
|
||||
if r not in bb.ctx and r.op is not Ops.BLOCKSTART:
|
||||
if r not in x.arg.ctx and r.op is not Ops.BLOCKSTART:
|
||||
lrng.remove(r)
|
||||
new_block = UOp(Ops.BLOCKEND, src=(new_block,),
|
||||
arg=BasicBlock(tuple(lrng), (UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,)),), r))
|
||||
new_srcs.append(new_block)
|
||||
return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(new_srcs)), BasicBlock(bb.ctx, tuple(to_append)+bb.lst))
|
||||
return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(list(old_blocks.values())+new_srcs)), BasicBlock(x.arg.ctx, tuple(to_append)+x.arg.lst))
|
||||
|
||||
make_basic_blocks = PatternMatcher([
|
||||
(UPat(Ops.SINK, name="x"), lambda x: UOp(Ops.BLOCK, src=x.src, arg=BasicBlock((), (x,)))),
|
||||
|
||||
@@ -173,6 +173,7 @@ class GroupOp:
|
||||
# meta ops
|
||||
Meta = {Ops.COPY, Ops.EMPTY, Ops.BUFFER_VIEW}
|
||||
Buffer = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID}
|
||||
Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART}
|
||||
|
||||
# BinaryOps that can be flipped
|
||||
Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.XOR, Ops.AND, Ops.OR}
|
||||
|
||||
Reference in New Issue
Block a user