try blocks

This commit is contained in:
qazal
2024-05-18 16:49:50 +03:00
parent 00dcb1ff4d
commit 25f8e3fe85

View File

@@ -1,41 +1,48 @@
import heapq
from collections import defaultdict, deque
from typing import DefaultDict, Dict, List, Set, Tuple
from test.external.fuzz_schedule import find_all_toposorts
from tinygrad.codegen.uops import UOp, UOps
def fuzz_uops(graph:DefaultDict[UOp, List[UOp]], in_degree:DefaultDict[UOp, int], queue:List[Tuple[int, UOp]], push, loops_children:Dict[UOp, Set[UOp]]):
uops = find_all_topsorts(graph, in_degree, queue, push, loops_children)[-1]
uops: List[UOp] = []
visited: Set[UOp] = set()
blocks: List[List[UOp]] = []
for root, subgraph in loops_children.items():
blocks.append(block:=[root, *subgraph])
visited = visited.union(block)
blocks.append([])
for node in in_degree:
if node in visited: continue
blocks[-1].append(node)
visited.add(node)
for b in blocks: uops.extend(*get_paths(b))
print("---------------")
for u in uops: print(u)
return uops
MAX_PATHS = 2
def find_all_topsorts(graph:DefaultDict[UOp, List[UOp]], in_degree:DefaultDict[UOp, int], queue:List[Tuple[int, UOp]], push, loops_children:Dict[UOp, Set[UOp]]):
visited: Set[UOp] = set()
ret: List[Tuple[UOp, ...]] = []
def get_paths(block:List[UOp]) -> List[List[UOp]]:
graph: DefaultDict[UOp, List[UOp]] = defaultdict(list)
in_degree: Dict[UOp, int] = {}
for u in block:
in_degree[u] = 0
for x in u.vin:
graph[x].append(u)
in_degree[u] += 1
if x not in in_degree: in_degree[x] = 0
queue = deque(x for x, deg in in_degree.items() if deg == 0)
path: List[UOp] = []
ifs: List[UOp] = []
global_bufs: List[UOp] = []
# find a path
while queue:
_,x = heapq.heappop(queue)
if in_degree[x] != 0 or x in visited: continue
# insert to path
path.append(x)
visited.add(x)
for u in graph[x]:
in_degree[u] -= 1
push(u)
n = queue.popleft()
path.append(n)
for x in graph[n]:
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)
# modify the path
for x in path:
if x.uop is UOps.DEFINE_ACC and len(x.vin):
path.remove(x)
path.insert(min(path.index(l) for l in x.vin), x)
elif x.uop is UOps.IF: path.insert(len(path)-1, UOp(UOps.ENDIF, None, (x,)))
for u, ss in loops_children.items():
last_op = max(path.index(s) for s in ss)
path.insert(last_op+1, UOp(UOps.ENDLOOP, None, (u,)))
# add to paths
assert path[-1].uop is UOps.SINK, f"didn't end with SINK, ended with {path[-1]}"
ret.append(tuple(path[:-1]))
return ret
if x.uop is UOps.SINK: path.remove(x)
assert all(degree == 0 for u, degree in in_degree.items() if u.uop is not UOps.SINK)
return [path]