fuzz uop schedules (#5345)

* basic blocks + cleanups

* fixups

* elif is better for future me

* fuzz_schedule_max_paths

* fix linter
This commit is contained in:
qazal
2024-07-09 15:24:56 +03:00
committed by GitHub
parent d5a68ae6b3
commit bee96a19ff
4 changed files with 47 additions and 58 deletions

View File

@@ -1,6 +1,6 @@
import itertools
import numpy as np
from typing import DefaultDict, Dict, List, Set, Tuple, TypeVar
from typing import DefaultDict, Dict, List, Set, Tuple, TypeVar, Union
from tinygrad.device import Buffer
from tinygrad.engine.realize import CustomOp, capturing, lower_schedule_item
from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv
@@ -10,6 +10,7 @@ from tinygrad.ops import LoadOps
from tinygrad.tensor import Tensor, _to_np_dtype
ctx_vars = { MULTIOUTPUT: (0, 1) }
FUZZ_SCHEDULE_MAX_PATHS = getenv("FUZZ_SCHEDULE_MAX_PATHS", 10)
def fuzz_schedule(outs:List[LazyBuffer]):
# find toposorts across all tunable params
@@ -73,7 +74,7 @@ def _exec_si(si:ScheduleItem, seed:int):
ei.run()
T = TypeVar("T")
def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:DefaultDict[T, int]) -> List[Tuple[T, ...]]:
def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:Union[DefaultDict[T, int], Dict[T, int]]) -> List[Tuple[T, ...]]:
visited: Set[T] = set()
ret: List[Tuple[T, ...]] = []
path: List[T] = []
@@ -85,7 +86,7 @@ def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:DefaultDict[T, i
path.append(v)
visited.add(v)
recurse_paths(path)
if len(ret) >= getenv("FUZZ_SCHEDULE_MAX_PATHS", 10): return
if len(ret) >= FUZZ_SCHEDULE_MAX_PATHS: return
# backtrack
for u in graph[v]: in_degree[u] += 1
path.pop()