mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
fuzzer path search for DEFINE_ACC (#4656)
* insert acc * add test_ops * find toposorts * todo - not yet ready * remove the import * atol and childless children
This commit is contained in:
36
test/external/fuzz_uops.py
vendored
36
test/external/fuzz_uops.py
vendored
@@ -1,11 +1,10 @@
|
||||
import numpy as np
|
||||
from dataclasses import replace
|
||||
from typing import DefaultDict, Dict, List, Set
|
||||
from test.external.fuzz_schedule import find_all_toposorts
|
||||
from typing import DefaultDict, Dict, List, Set, Tuple
|
||||
from tinygrad.codegen.uops import UOp, UOpGraph, UOps
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.helpers import DEBUG, colored
|
||||
from tinygrad.helpers import DEBUG, colored, getenv
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
|
||||
def fuzz_uops(graph:DefaultDict[UOp, List[UOp]], in_degree:DefaultDict[UOp, int], loops_children:Dict[UOp, Set[UOp]]):
|
||||
@@ -17,7 +16,7 @@ def fuzz_uops(graph:DefaultDict[UOp, List[UOp]], in_degree:DefaultDict[UOp, int]
|
||||
for u in path:
|
||||
if u.uop is UOps.IF: path.append(UOp(UOps.ENDIF, None, (u,)))
|
||||
if u.uop is UOps.RANGE:
|
||||
path.insert(max(path.index(x) for x in loops_children[u])+1, UOp(UOps.ENDRANGE, None, (u,)))
|
||||
path.insert(max(path.index(x) for x in loops_children[u] if x in path)+1, UOp(UOps.ENDRANGE, None, (u,)))
|
||||
return paths
|
||||
|
||||
class UOpsFuzzerRunner(CompiledRunner):
|
||||
@@ -44,8 +43,35 @@ class UOpsFuzzerRunner(CompiledRunner):
|
||||
super().__call__(rawbufs, var_vals, wait)
|
||||
for i, x in enumerate(rawbufs):
|
||||
try:
|
||||
np.testing.assert_allclose(np.frombuffer(x.as_buffer(), x.dtype.np), ground_truth[x])
|
||||
np.testing.assert_allclose(np.frombuffer(x.as_buffer(), x.dtype.np), ground_truth[x], atol=1e-6, rtol=1e-6)
|
||||
if DEBUG >= 2: print(colored(name, "green"))
|
||||
except AssertionError as e:
|
||||
print(colored(name, "red"))
|
||||
raise e
|
||||
|
||||
def find_all_toposorts(graph:DefaultDict[UOp, List[UOp]], in_degree:DefaultDict[UOp, int]) -> List[Tuple[UOp, ...]]:
|
||||
visited: Set[UOp] = set()
|
||||
ret: List[Tuple[UOp, ...]] = []
|
||||
path: List[UOp] = []
|
||||
|
||||
def recurse_paths(path:List[UOp]):
|
||||
for v, d in in_degree.items():
|
||||
if d != 0 or v in visited: continue
|
||||
if v.uop is UOps.DEFINE_ACC and any(l not in path for l in v.vin): continue
|
||||
for u in graph[v]: in_degree[u] -= 1
|
||||
if v.uop is UOps.DEFINE_ACC: path.insert(min(path.index(l) for l in v.vin), v)
|
||||
else: path.append(v)
|
||||
visited.add(v)
|
||||
recurse_paths(path)
|
||||
if len(ret) >= getenv("FUZZ_UOPS_MAX_PATHS", 10): return
|
||||
# backtrack
|
||||
for u in graph[v]: in_degree[u] += 1
|
||||
path.pop()
|
||||
visited.remove(v)
|
||||
if len(path) == len(in_degree): ret.append(tuple(path))
|
||||
recurse_paths(path)
|
||||
|
||||
if len(ret) == 0: raise RuntimeError("detected cycle in the graph")
|
||||
# verify all paths are unique
|
||||
assert len(ret) == len(set(ret))
|
||||
return ret
|
||||
|
||||
Reference in New Issue
Block a user