mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-05 20:24:57 -05:00
fuzz uop schedules (#5345)
* basic blocks + cleanups * fixups * elif is better for future me * fuzz_schedule_max_paths * fix linter
This commit is contained in:
7
test/external/fuzz_schedule.py
vendored
7
test/external/fuzz_schedule.py
vendored
@@ -1,6 +1,6 @@
|
||||
import itertools
|
||||
import numpy as np
|
||||
from typing import DefaultDict, Dict, List, Set, Tuple, TypeVar
|
||||
from typing import DefaultDict, Dict, List, Set, Tuple, TypeVar, Union
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.engine.realize import CustomOp, capturing, lower_schedule_item
|
||||
from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv
|
||||
@@ -10,6 +10,7 @@ from tinygrad.ops import LoadOps
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
|
||||
ctx_vars = { MULTIOUTPUT: (0, 1) }
|
||||
FUZZ_SCHEDULE_MAX_PATHS = getenv("FUZZ_SCHEDULE_MAX_PATHS", 10)
|
||||
|
||||
def fuzz_schedule(outs:List[LazyBuffer]):
|
||||
# find toposorts across all tunable params
|
||||
@@ -73,7 +74,7 @@ def _exec_si(si:ScheduleItem, seed:int):
|
||||
ei.run()
|
||||
|
||||
T = TypeVar("T")
|
||||
def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:DefaultDict[T, int]) -> List[Tuple[T, ...]]:
|
||||
def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:Union[DefaultDict[T, int], Dict[T, int]]) -> List[Tuple[T, ...]]:
|
||||
visited: Set[T] = set()
|
||||
ret: List[Tuple[T, ...]] = []
|
||||
path: List[T] = []
|
||||
@@ -85,7 +86,7 @@ def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:DefaultDict[T, i
|
||||
path.append(v)
|
||||
visited.add(v)
|
||||
recurse_paths(path)
|
||||
if len(ret) >= getenv("FUZZ_SCHEDULE_MAX_PATHS", 10): return
|
||||
if len(ret) >= FUZZ_SCHEDULE_MAX_PATHS: return
|
||||
# backtrack
|
||||
for u in graph[v]: in_degree[u] += 1
|
||||
path.pop()
|
||||
|
||||
Reference in New Issue
Block a user