mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
try blocks
This commit is contained in:
69
test/external/fuzz_uops.py
vendored
69
test/external/fuzz_uops.py
vendored
@@ -1,41 +1,48 @@
|
||||
import heapq
|
||||
from collections import defaultdict, deque
|
||||
from typing import DefaultDict, Dict, List, Set, Tuple
|
||||
from test.external.fuzz_schedule import find_all_toposorts
|
||||
from tinygrad.codegen.uops import UOp, UOps
|
||||
|
||||
def fuzz_uops(graph:DefaultDict[UOp, List[UOp]], in_degree:DefaultDict[UOp, int], queue:List[Tuple[int, UOp]], push, loops_children:Dict[UOp, Set[UOp]]):
|
||||
uops = find_all_topsorts(graph, in_degree, queue, push, loops_children)[-1]
|
||||
uops: List[UOp] = []
|
||||
visited: Set[UOp] = set()
|
||||
blocks: List[List[UOp]] = []
|
||||
for root, subgraph in loops_children.items():
|
||||
blocks.append(block:=[root, *subgraph])
|
||||
visited = visited.union(block)
|
||||
|
||||
blocks.append([])
|
||||
for node in in_degree:
|
||||
if node in visited: continue
|
||||
blocks[-1].append(node)
|
||||
visited.add(node)
|
||||
|
||||
for b in blocks: uops.extend(*get_paths(b))
|
||||
|
||||
print("---------------")
|
||||
for u in uops: print(u)
|
||||
return uops
|
||||
|
||||
MAX_PATHS = 2
|
||||
def find_all_topsorts(graph:DefaultDict[UOp, List[UOp]], in_degree:DefaultDict[UOp, int], queue:List[Tuple[int, UOp]], push, loops_children:Dict[UOp, Set[UOp]]):
|
||||
visited: Set[UOp] = set()
|
||||
ret: List[Tuple[UOp, ...]] = []
|
||||
def get_paths(block:List[UOp]) -> List[List[UOp]]:
|
||||
graph: DefaultDict[UOp, List[UOp]] = defaultdict(list)
|
||||
in_degree: Dict[UOp, int] = {}
|
||||
for u in block:
|
||||
in_degree[u] = 0
|
||||
for x in u.vin:
|
||||
graph[x].append(u)
|
||||
in_degree[u] += 1
|
||||
if x not in in_degree: in_degree[x] = 0
|
||||
|
||||
queue = deque(x for x, deg in in_degree.items() if deg == 0)
|
||||
path: List[UOp] = []
|
||||
ifs: List[UOp] = []
|
||||
global_bufs: List[UOp] = []
|
||||
|
||||
# find a path
|
||||
while queue:
|
||||
_,x = heapq.heappop(queue)
|
||||
if in_degree[x] != 0 or x in visited: continue
|
||||
# insert to path
|
||||
path.append(x)
|
||||
visited.add(x)
|
||||
for u in graph[x]:
|
||||
in_degree[u] -= 1
|
||||
push(u)
|
||||
n = queue.popleft()
|
||||
path.append(n)
|
||||
for x in graph[n]:
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0: queue.append(x)
|
||||
|
||||
# modify the path
|
||||
for x in path:
|
||||
if x.uop is UOps.DEFINE_ACC and len(x.vin):
|
||||
path.remove(x)
|
||||
path.insert(min(path.index(l) for l in x.vin), x)
|
||||
elif x.uop is UOps.IF: path.insert(len(path)-1, UOp(UOps.ENDIF, None, (x,)))
|
||||
for u, ss in loops_children.items():
|
||||
last_op = max(path.index(s) for s in ss)
|
||||
path.insert(last_op+1, UOp(UOps.ENDLOOP, None, (u,)))
|
||||
|
||||
# add to paths
|
||||
assert path[-1].uop is UOps.SINK, f"didn't end with SINK, ended with {path[-1]}"
|
||||
ret.append(tuple(path[:-1]))
|
||||
return ret
|
||||
if x.uop is UOps.SINK: path.remove(x)
|
||||
assert all(degree == 0 for u, degree in in_degree.items() if u.uop is not UOps.SINK)
|
||||
return [path]
|
||||
|
||||
Reference in New Issue
Block a user