mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
fuzz uop schedules (#5345)
* basic blocks + cleanups
* fixups
* elif is better for future me
* fuzz_schedule_max_paths
* fix linter
This commit is contained in:
7
test/external/fuzz_schedule.py
vendored
7
test/external/fuzz_schedule.py
vendored
@@ -1,6 +1,6 @@
|
||||
import itertools
|
||||
import numpy as np
|
||||
from typing import DefaultDict, Dict, List, Set, Tuple, TypeVar
|
||||
from typing import DefaultDict, Dict, List, Set, Tuple, TypeVar, Union
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.engine.realize import CustomOp, capturing, lower_schedule_item
|
||||
from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv
|
||||
@@ -10,6 +10,7 @@ from tinygrad.ops import LoadOps
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
|
||||
ctx_vars = { MULTIOUTPUT: (0, 1) }
|
||||
FUZZ_SCHEDULE_MAX_PATHS = getenv("FUZZ_SCHEDULE_MAX_PATHS", 10)
|
||||
|
||||
def fuzz_schedule(outs:List[LazyBuffer]):
|
||||
# find toposorts across all tunable params
|
||||
@@ -73,7 +74,7 @@ def _exec_si(si:ScheduleItem, seed:int):
|
||||
ei.run()
|
||||
|
||||
T = TypeVar("T")
|
||||
def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:DefaultDict[T, int]) -> List[Tuple[T, ...]]:
|
||||
def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:Union[DefaultDict[T, int], Dict[T, int]]) -> List[Tuple[T, ...]]:
|
||||
visited: Set[T] = set()
|
||||
ret: List[Tuple[T, ...]] = []
|
||||
path: List[T] = []
|
||||
@@ -85,7 +86,7 @@ def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:DefaultDict[T, i
|
||||
path.append(v)
|
||||
visited.add(v)
|
||||
recurse_paths(path)
|
||||
if len(ret) >= getenv("FUZZ_SCHEDULE_MAX_PATHS", 10): return
|
||||
if len(ret) >= FUZZ_SCHEDULE_MAX_PATHS: return
|
||||
# backtrack
|
||||
for u in graph[v]: in_degree[u] += 1
|
||||
path.pop()
|
||||
|
||||
78
test/external/fuzz_uops.py
vendored
78
test/external/fuzz_uops.py
vendored
@@ -1,40 +1,55 @@
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
from dataclasses import replace
|
||||
from typing import Dict, List, Set, Tuple
|
||||
from tinygrad.codegen.uops import UOp, UOpGraph, UOps
|
||||
from typing import DefaultDict, Dict, List, Tuple
|
||||
from tinygrad.codegen.uops import END_FOR_UOP, UOp, UOpGraph
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.helpers import DEBUG, colored, getenv
|
||||
from tinygrad.helpers import DEBUG, colored
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from test.external.fuzz_schedule import FUZZ_SCHEDULE_MAX_PATHS, find_all_toposorts
|
||||
|
||||
def fuzz_uops(graph:Dict[UOp, List[UOp]], in_degree:Dict[UOp, int], loops_children:Dict[UOp, Set[UOp]]):
  """Build fuzzed uop orderings from every toposort of `graph`.

  For each toposort returned by find_all_toposorts, the trailing SINK is dropped and the
  scope-closing uops are added: an ENDIF is appended for every IF, and an ENDRANGE is
  inserted right after the last uop of `loops_children[u]` that is present in the path
  for every RANGE `u`.

  Returns a list of paths, each a list of UOps.
  Raises AssertionError if a toposort does not end with a SINK uop.
  """
  paths: List[List[UOp]] = []
  # TODO: express DEFINE_ACC and loop children conditions in the graph, builtin.
  for p in find_all_toposorts(graph, in_degree):
    assert p[-1].op is UOps.SINK, f"didn't end with SINK, ended with {p[-1]}"
    order = p[:-1]
    path = list(order)
    paths.append(path)
    # walk the immutable toposort rather than `path` itself: the loop body appends to and
    # inserts into `path`, and mutating a list while iterating over it is fragile
    for u in order:
      if u.op is UOps.IF: path.append(UOp(UOps.ENDIF, None, (u,)))
      if u.op is UOps.RANGE:
        path.insert(max(path.index(x) for x in loops_children[u] if x in path)+1, UOp(UOps.ENDRANGE, None, (u,)))
  return paths
|
||||
def fuzz_uops(uops:UOpGraph) -> List[Tuple[UOp, ...]]:
  """Enumerate alternative orderings of the uops in `uops`.

  The uop sequence is split into blocks at scope boundaries: a uop whose op is a key of
  END_FOR_UOP starts a new block, and a uop whose op is one of the closing ops (second
  element of an END_FOR_UOP value) sits alone in its own block with a fresh block opened
  after it. Every block is toposorted independently, counting only dependencies between
  uops of the same block, and the per-block orderings are combined block by block.

  Returns the distinct combined orderings in the order they were generated; generation
  stops once FUZZ_SCHEDULE_MAX_PATHS of them have been collected.
  """
  # loop-invariant: the ops that close a scope
  scope_end_ops = {end_op for _, end_op in END_FOR_UOP.values()}
  blocks: List[List[UOp]] = [[]]
  for u in uops:
    if u.op in END_FOR_UOP: blocks.append([u])
    elif u.op in scope_end_ops: blocks.extend([[u], []])
    else: blocks[-1].append(u)

  paths_for_block: List[List[Tuple[UOp, ...]]] = []
  for bb in blocks:
    # set membership instead of scanning the block list for every source
    members = set(bb)
    children: DefaultDict[UOp, List[UOp]] = defaultdict(list)
    in_degree: Dict[UOp, int] = {}
    for u in bb:
      in_degree[u] = 0
      for x in u.src:
        if x in members:
          children[x].append(u)
          in_degree[u] += 1
    paths_for_block.append(find_all_toposorts(children, in_degree))

  # a dict is used as an insertion-ordered set to drop duplicate orderings
  paths: Dict[Tuple[UOp, ...], None] = {}
  for up in itertools.product(*paths_for_block):
    paths[tuple(uop for path in up for uop in path)] = None
    if len(paths) >= FUZZ_SCHEDULE_MAX_PATHS: break
  return list(paths)
|
||||
|
||||
class UOpsFuzzerRunner(CompiledRunner):
|
||||
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
|
||||
assert self.p.uops is not None and len(self.p.uops.fuzz_paths) >= 1
|
||||
assert self.p.uops is not None and len(self.p.uops._fuzz_paths) >= 1
|
||||
init_rawbufs, init_name = {x:x.as_buffer() for x in rawbufs}, self.p.function_name
|
||||
init_globals = {i[0]:buf for i, buf in zip(self.p.globals, rawbufs)}
|
||||
if DEBUG >= 1: print(colored(f"fuzzing {len(self.p.uops.fuzz_paths)} UOps permutations for {init_name}", "yellow"))
|
||||
if DEBUG >= 1: print(colored(f"fuzzing {len(self.p.uops._fuzz_paths)} uop permutations for {init_name}", "yellow"))
|
||||
|
||||
super().__call__(rawbufs, var_vals, wait)
|
||||
ground_truth = {x:np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in rawbufs}
|
||||
|
||||
for i, path in enumerate(self.p.uops.fuzz_paths):
|
||||
for i, path in enumerate(self.p.uops._fuzz_paths):
|
||||
# setup prg
|
||||
uops = UOpGraph([])
|
||||
uops._uops = list(path)
|
||||
if DEBUG >= 6: uops.print()
|
||||
if DEBUG >= 5: uops.print()
|
||||
self.p = replace(self.p, name=(name:=f"{init_name}fuzz{i}"), src=Device[self.p.dname].renderer.render(name, uops), uops=uops)
|
||||
if DEBUG >= 4: print(self.p.src)
|
||||
self.lib = Device[self.p.dname].compiler.compile_cached(self.p.src)
|
||||
@@ -49,30 +64,3 @@ class UOpsFuzzerRunner(CompiledRunner):
|
||||
except AssertionError as e:
|
||||
print(colored(name, "red"))
|
||||
raise e
|
||||
|
||||
def find_all_toposorts(graph:Dict[UOp, List[UOp]], in_degree:Dict[UOp, int]) -> List[Tuple[UOp, ...]]:
  """Enumerate topological orderings of `graph` by backtracking.

  `graph` maps a uop to its children and `in_degree` maps a uop to its number of unplaced
  parents; `in_degree` is mutated during the search. A DEFINE_ACC uop is only placed once
  all of its sources are in the path, and it is inserted in front of the earliest of them
  instead of being appended.

  The search stops once FUZZ_UOPS_MAX_PATHS (default 10) orderings have been found.

  Raises RuntimeError if no ordering exists, and AssertionError if the collected
  orderings are not unique.
  """
  visited: Set[UOp] = set()
  ret: List[Tuple[UOp, ...]] = []
  path: List[UOp] = []
  # read the limit once instead of on every backtracking step
  max_paths = getenv("FUZZ_UOPS_MAX_PATHS", 10)

  def recurse_paths(path:List[UOp]):
    for v, d in in_degree.items():
      if d != 0 or v in visited: continue
      if v.op is UOps.DEFINE_ACC and any(l not in path for l in v.src): continue
      for u in graph[v]: in_degree[u] -= 1
      # remember where v lands: DEFINE_ACC goes in the middle of the path, everything else at the end
      if v.op is UOps.DEFINE_ACC: pos = min(path.index(l) for l in v.src)
      else: pos = len(path)
      path.insert(pos, v)
      visited.add(v)
      recurse_paths(path)
      if len(ret) >= max_paths: return
      # backtrack
      for u in graph[v]: in_degree[u] += 1
      # remove v itself; a plain pop() would drop the wrong uop when v was inserted mid-path
      del path[pos]
      visited.remove(v)
    if len(path) == len(in_degree): ret.append(tuple(path))
  recurse_paths(path)

  if len(ret) == 0: raise RuntimeError("detected cycle in the graph")
  # verify all paths are unique
  assert len(ret) == len(set(ret))
  return ret
|
||||
|
||||
@@ -26,6 +26,8 @@ class UOps(Enum):
|
||||
# these two are not graph nodes
|
||||
ENDRANGE = auto(); ENDIF = auto() # noqa: E702
|
||||
|
||||
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)}
|
||||
|
||||
def ufix(dtype: Optional[DType], x):
  """Return `x` unchanged when it is already a UOp, otherwise the result of UOp.const(dtype, x)."""
  if isinstance(x, UOp): return x
  return UOp.const(dtype, x)
|
||||
@dataclass(frozen=True, eq=False)
|
||||
class UOp:
|
||||
@@ -368,9 +370,9 @@ class UOpGraph:
|
||||
def globals(self) -> List[Tuple[int, bool]]:
  """Collect the `arg` of every DEFINE_GLOBAL uop, in the order they appear in self.uops."""
  found: List[Tuple[int, bool]] = []
  for uop in self.uops:
    if uop.op is UOps.DEFINE_GLOBAL: found.append(uop.arg)
  return found
|
||||
|
||||
@property
|
||||
def uops(self):
|
||||
def uops(self) -> List[UOp]:
|
||||
if self._uops is None: self.linearize()
|
||||
return self._uops
|
||||
return cast(List[UOp], self._uops)
|
||||
|
||||
def graph(self):
|
||||
from tinygrad.engine.graph import graph_uops
|
||||
@@ -412,8 +414,7 @@ class UOpGraph:
|
||||
return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))
|
||||
|
||||
# scope children impact the toposort and END* insertion
|
||||
end_for_uop = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)}
|
||||
scope_children = {p:get_recursive_children(p, end_for_uop[p.op][0]) for p in reversed(in_degree) if p.op in end_for_uop}
|
||||
scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP}
|
||||
|
||||
queue:List[Tuple[int, UOp]] = []
|
||||
def push(u:UOp):
|
||||
@@ -426,10 +427,6 @@ class UOpGraph:
|
||||
for u in children:
|
||||
if in_degree[u] == 0: push(u)
|
||||
|
||||
if getenv("FUZZ_UOPS", 0):
|
||||
from test.external.fuzz_uops import fuzz_uops
|
||||
self.fuzz_paths = fuzz_uops(children, in_degree.copy(), scope_children)
|
||||
|
||||
self._uops = []
|
||||
while queue:
|
||||
p,x = heapq.heappop(queue)
|
||||
@@ -443,11 +440,14 @@ class UOpGraph:
|
||||
if in_degree[u] == 0: push(u)
|
||||
|
||||
for u in (self._uops):
|
||||
if u.op in end_for_uop: self._uops.insert(max([self._uops.index(l) for l in scope_children[u]])+1, UOp(end_for_uop[u.op][1], None, (u,)))
|
||||
if u.op in END_FOR_UOP: self._uops.insert(max([self._uops.index(l) for l in scope_children[u]])+1, UOp(END_FOR_UOP[u.op][1], None, (u,)))
|
||||
|
||||
assert self._uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
|
||||
self._uops = self._uops[:-1]
|
||||
|
||||
if getenv("FUZZ_UOPS"):
|
||||
from test.external.fuzz_uops import fuzz_uops
|
||||
self._fuzz_paths = fuzz_uops(self)
|
||||
if do_type_verify: type_verify(self.uops)
|
||||
|
||||
# *** checker functions ***
|
||||
|
||||
@@ -136,7 +136,7 @@ def get_runner(dname:str, ast:Tuple[LazyOp, ...]) -> CompiledRunner:
|
||||
method_cache[ckey] = ret = CompiledRunner(replace(bret.p, dname=dname), bret.lib)
|
||||
else:
|
||||
prg: Program = get_linearizer(Device[dname].renderer, ast).to_program()
|
||||
if hasattr(prg.uops, "fuzz_paths"):
|
||||
if hasattr(prg.uops, "_fuzz_paths"):
|
||||
from test.external.fuzz_uops import UOpsFuzzerRunner
|
||||
return UOpsFuzzerRunner(replace(prg, dname=dname))
|
||||
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, dname=dname))
|
||||
|
||||
Reference in New Issue
Block a user