fuzz uop schedules (#5345)

* basic blocks + cleanups

* fixups

* elif is better for future me

* fuzz_schedule_max_paths

* fix linter
This commit is contained in:
qazal
2024-07-09 15:24:56 +03:00
committed by GitHub
parent d5a68ae6b3
commit bee96a19ff
4 changed files with 47 additions and 58 deletions

View File

@@ -1,6 +1,6 @@
import itertools
import numpy as np
from typing import DefaultDict, Dict, List, Set, Tuple, TypeVar
from typing import DefaultDict, Dict, List, Set, Tuple, TypeVar, Union
from tinygrad.device import Buffer
from tinygrad.engine.realize import CustomOp, capturing, lower_schedule_item
from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv
@@ -10,6 +10,7 @@ from tinygrad.ops import LoadOps
from tinygrad.tensor import Tensor, _to_np_dtype
ctx_vars = { MULTIOUTPUT: (0, 1) }
FUZZ_SCHEDULE_MAX_PATHS = getenv("FUZZ_SCHEDULE_MAX_PATHS", 10)
def fuzz_schedule(outs:List[LazyBuffer]):
# find toposorts across all tunable params
@@ -73,7 +74,7 @@ def _exec_si(si:ScheduleItem, seed:int):
ei.run()
T = TypeVar("T")
def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:DefaultDict[T, int]) -> List[Tuple[T, ...]]:
def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:Union[DefaultDict[T, int], Dict[T, int]]) -> List[Tuple[T, ...]]:
visited: Set[T] = set()
ret: List[Tuple[T, ...]] = []
path: List[T] = []
@@ -85,7 +86,7 @@ def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:DefaultDict[T, i
path.append(v)
visited.add(v)
recurse_paths(path)
if len(ret) >= getenv("FUZZ_SCHEDULE_MAX_PATHS", 10): return
if len(ret) >= FUZZ_SCHEDULE_MAX_PATHS: return
# backtrack
for u in graph[v]: in_degree[u] += 1
path.pop()

View File

@@ -1,40 +1,55 @@
import itertools
from collections import defaultdict
import numpy as np
from dataclasses import replace
from typing import Dict, List, Set, Tuple
from tinygrad.codegen.uops import UOp, UOpGraph, UOps
from typing import DefaultDict, Dict, List, Tuple
from tinygrad.codegen.uops import END_FOR_UOP, UOp, UOpGraph
from tinygrad.device import Buffer, Device
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import DEBUG, colored, getenv
from tinygrad.helpers import DEBUG, colored
from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import _to_np_dtype
from test.external.fuzz_schedule import FUZZ_SCHEDULE_MAX_PATHS, find_all_toposorts
def fuzz_uops(graph:Dict[UOp, List[UOp]], in_degree:Dict[UOp, int], loops_children:Dict[UOp, Set[UOp]]):
  """Build candidate uop orderings from every toposort of `graph`, then add the scope-closing uops.

  Each ordering returned by find_all_toposorts must end with a SINK, which is dropped. For every IF an ENDIF is
  appended at the end of the ordering; for every RANGE an ENDRANGE is inserted right after the last uop of
  `loops_children[RANGE]` that is present in the ordering.

  Returns one list of uops per toposort.
  """
  # TODO: express DEFINE_ACC and loop children conditions in the graph, builtin.
  paths: List[List[UOp]] = []
  for p in find_all_toposorts(graph, in_degree):
    assert p[-1].op is UOps.SINK, f"didn't end with SINK, ended with {p[-1]}"
    path = list(p[:-1])
    paths.append(path)
    # NOTE: `path` grows while it is being walked; the END* uops added here are visited too, but they are
    # neither IF nor RANGE, so they add nothing further.
    for u in path:
      if u.op is UOps.IF:
        path.append(UOp(UOps.ENDIF, None, (u,)))
      elif u.op is UOps.RANGE:
        last_child = max(path.index(x) for x in loops_children[u] if x in path)
        path.insert(last_child+1, UOp(UOps.ENDRANGE, None, (u,)))
  return paths
def fuzz_uops(uops:UOpGraph) -> List[Tuple[UOp, ...]]:
  """Enumerate alternative orderings of the uops in `uops` without moving any uop across a scope boundary.

  The uops are split into consecutive blocks: a uop whose op is a key of END_FOR_UOP starts a new block, and a uop
  whose op is one of the end markers (second element of an END_FOR_UOP value) is placed alone in its own block.
  Every block is toposorted on its own, using only the src edges between uops of that block, and the per-block
  orderings are combined with a cartesian product.

  Returns the unique combined orderings in generation order; generation stops as soon as FUZZ_SCHEDULE_MAX_PATHS
  of them have been collected.
  """
  # the set of end markers does not depend on the uop being visited, so build it once
  end_ops = {x[1] for x in END_FOR_UOP.values()}
  blocks: List[List[UOp]] = [[]]
  for u in uops:
    if u.op in END_FOR_UOP: blocks.append([u])
    elif u.op in end_ops: blocks.extend([[u], []])
    else: blocks[-1].append(u)

  paths_for_block: Dict[int, List[Tuple[UOp, ...]]] = {}
  for bi, bb in enumerate(blocks):
    # set membership instead of scanning the block list for every src of every uop
    members = set(bb)
    children: DefaultDict[UOp, List[UOp]] = defaultdict(list)
    in_degree: Dict[UOp, int] = {}
    for u in bb:
      in_degree[u] = 0
      for x in u.src:
        # srcs that live in another block are not edges of this block's graph
        if x in members:
          children[x].append(u)
          in_degree[u] += 1
    paths_for_block[bi] = find_all_toposorts(children, in_degree)

  # dict keys serve as an insertion-ordered set of orderings
  paths: Dict[Tuple[UOp, ...], None] = {}
  for up in itertools.product(*paths_for_block.values()):
    paths[tuple(itertools.chain.from_iterable(up))] = None
    if len(paths) >= FUZZ_SCHEDULE_MAX_PATHS: break
  return list(paths)
class UOpsFuzzerRunner(CompiledRunner):
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
assert self.p.uops is not None and len(self.p.uops.fuzz_paths) >= 1
assert self.p.uops is not None and len(self.p.uops._fuzz_paths) >= 1
init_rawbufs, init_name = {x:x.as_buffer() for x in rawbufs}, self.p.function_name
init_globals = {i[0]:buf for i, buf in zip(self.p.globals, rawbufs)}
if DEBUG >= 1: print(colored(f"fuzzing {len(self.p.uops.fuzz_paths)} UOps permutations for {init_name}", "yellow"))
if DEBUG >= 1: print(colored(f"fuzzing {len(self.p.uops._fuzz_paths)} uop permutations for {init_name}", "yellow"))
super().__call__(rawbufs, var_vals, wait)
ground_truth = {x:np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in rawbufs}
for i, path in enumerate(self.p.uops.fuzz_paths):
for i, path in enumerate(self.p.uops._fuzz_paths):
# setup prg
uops = UOpGraph([])
uops._uops = list(path)
if DEBUG >= 6: uops.print()
if DEBUG >= 5: uops.print()
self.p = replace(self.p, name=(name:=f"{init_name}fuzz{i}"), src=Device[self.p.dname].renderer.render(name, uops), uops=uops)
if DEBUG >= 4: print(self.p.src)
self.lib = Device[self.p.dname].compiler.compile_cached(self.p.src)
@@ -49,30 +64,3 @@ class UOpsFuzzerRunner(CompiledRunner):
except AssertionError as e:
print(colored(name, "red"))
raise e
def find_all_toposorts(graph:Dict[UOp, List[UOp]], in_degree:Dict[UOp, int]) -> List[Tuple[UOp, ...]]:
  """Collect topological orderings of `graph` by backtracking, up to FUZZ_UOPS_MAX_PATHS (default 10) of them.

  graph: maps a uop to the uops that depend on it.
  in_degree: number of unplaced dependencies per uop; its keys define the node set. It is decremented and restored
    in place while searching, and is left modified when the search stops early at the path limit, so pass a copy
    if the original counts are still needed.

  A DEFINE_ACC is only placed once all of its srcs are in the ordering, and it is inserted in front of the earliest
  of them instead of being appended.

  Returns the unique complete orderings as tuples.
  Raises RuntimeError if no complete ordering is found.
  """
  max_paths = getenv("FUZZ_UOPS_MAX_PATHS", 10)
  visited: Set[UOp] = set()
  seen: Set[Tuple[UOp, ...]] = set()
  ret: List[Tuple[UOp, ...]] = []
  path: List[UOp] = []

  def recurse_paths(path:List[UOp]):
    for v, d in in_degree.items():
      if d != 0 or v in visited: continue
      if v.op is UOps.DEFINE_ACC and any(l not in path for l in v.src): continue
      for u in graph[v]: in_degree[u] -= 1
      if v.op is UOps.DEFINE_ACC: path.insert(min(path.index(l) for l in v.src), v)
      else: path.append(v)
      visited.add(v)
      recurse_paths(path)
      if len(ret) >= max_paths: return
      # backtrack
      for u in graph[v]: in_degree[u] += 1
      # v is not necessarily the last element (DEFINE_ACC is inserted earlier), so remove v itself
      path.remove(v)
      visited.remove(v)
    # a hoisted DEFINE_ACC lands in the same slot whenever it is visited, so different visit orders can
    # produce the same sequence; record each sequence once
    if len(path) == len(in_degree) and (order:=tuple(path)) not in seen:
      seen.add(order)
      ret.append(order)
  recurse_paths(path)

  if len(ret) == 0: raise RuntimeError("detected cycle in the graph")
  # verify all paths are unique
  assert len(ret) == len(set(ret))
  return ret

View File

@@ -26,6 +26,8 @@ class UOps(Enum):
# these two are not graph nodes
ENDRANGE = auto(); ENDIF = auto() # noqa: E702
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)}
def ufix(dtype: Optional[DType], x): return UOp.const(dtype, x) if not isinstance(x, UOp) else x
@dataclass(frozen=True, eq=False)
class UOp:
@@ -368,9 +370,9 @@ class UOpGraph:
def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.op is UOps.DEFINE_GLOBAL]
@property
def uops(self):
def uops(self) -> List[UOp]:
if self._uops is None: self.linearize()
return self._uops
return cast(List[UOp], self._uops)
def graph(self):
from tinygrad.engine.graph import graph_uops
@@ -412,8 +414,7 @@ class UOpGraph:
return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))
# scope children impact the toposort and END* insertion
end_for_uop = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)}
scope_children = {p:get_recursive_children(p, end_for_uop[p.op][0]) for p in reversed(in_degree) if p.op in end_for_uop}
scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP}
queue:List[Tuple[int, UOp]] = []
def push(u:UOp):
@@ -426,10 +427,6 @@ class UOpGraph:
for u in children:
if in_degree[u] == 0: push(u)
if getenv("FUZZ_UOPS", 0):
from test.external.fuzz_uops import fuzz_uops
self.fuzz_paths = fuzz_uops(children, in_degree.copy(), scope_children)
self._uops = []
while queue:
p,x = heapq.heappop(queue)
@@ -443,11 +440,14 @@ class UOpGraph:
if in_degree[u] == 0: push(u)
for u in (self._uops):
if u.op in end_for_uop: self._uops.insert(max([self._uops.index(l) for l in scope_children[u]])+1, UOp(end_for_uop[u.op][1], None, (u,)))
if u.op in END_FOR_UOP: self._uops.insert(max([self._uops.index(l) for l in scope_children[u]])+1, UOp(END_FOR_UOP[u.op][1], None, (u,)))
assert self._uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
self._uops = self._uops[:-1]
if getenv("FUZZ_UOPS"):
from test.external.fuzz_uops import fuzz_uops
self._fuzz_paths = fuzz_uops(self)
if do_type_verify: type_verify(self.uops)
# *** checker functions ***

View File

@@ -136,7 +136,7 @@ def get_runner(dname:str, ast:Tuple[LazyOp, ...]) -> CompiledRunner:
method_cache[ckey] = ret = CompiledRunner(replace(bret.p, dname=dname), bret.lib)
else:
prg: Program = get_linearizer(Device[dname].renderer, ast).to_program()
if hasattr(prg.uops, "fuzz_paths"):
if hasattr(prg.uops, "_fuzz_paths"):
from test.external.fuzz_uops import UOpsFuzzerRunner
return UOpsFuzzerRunner(replace(prg, dname=dname))
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, dname=dname))