mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-05 12:15:05 -05:00
Included non-reduce kernels and kernels with variables. Print a green message when everything passes. Creating rawbufs can fail due to a memory error, so that is now included in the failure cases.
167 lines
5.8 KiB
Python
167 lines
5.8 KiB
Python
import random, traceback, ctypes
|
|
from typing import List, Tuple
|
|
import numpy as np
|
|
from collections import defaultdict
|
|
from extra.optimization.helpers import load_worlds, ast_str_to_lin
|
|
from tinygrad.codegen.linearizer import Linearizer
|
|
from tinygrad.features.search import get_linearizer_actions, bufs_from_lin
|
|
from tinygrad.tensor import Tensor
|
|
from tinygrad.features.graph import print_tree
|
|
from tinygrad.helpers import getenv, from_mv, prod, colored, Context
|
|
from tinygrad.device import Device, Compiled
|
|
from tinygrad.codegen.linearizer import UOp
|
|
|
|
def tuplize_uops(uops:List[UOp]) -> Tuple:
  """Convert a uop list into a nested tuple usable as a dict key.

  Each uop becomes (uop, dtype, input positions, arg), where every input in
  `vin` is replaced by its index in `uops`.
  """
  rows = []
  for u in uops:
    rows.append((u.uop, u.dtype, tuple(map(uops.index, u.vin)), u.arg))
  return tuple(rows)
|
|
|
|
# Device object for Device.DEFAULT, looked up once at import time and shared by the helpers below.
device = Device[Device.DEFAULT]
|
|
|
|
def get_fuzz_rawbufs(lin):
  """Build the buffer list used to fuzz `lin`.

  Buffer 0 (the output) is replaced by a zeroed buffer; every other buffer is
  filled with values drawn from Tensor.uniform. Returns the list of buffers.
  """
  bufs = bufs_from_lin(lin)

  # Reallocate output buffer with additional area to detect out-of-bounds writes.
  # The extra area is only added when the device is a Compiled backend.
  red_area = 0
  if isinstance(device, Compiled): red_area = 1024
  out = bufs[0]
  bufs[0] = get_fuzz_rawbuf_like(out, zero=True, size=out.size + red_area)

  with Context(DEBUG=0):
    for buf in bufs[1:]:
      data = Tensor.uniform((buf.size,), dtype=buf.dtype).realize()
      buf.copyin(data.lazydata.realized.as_buffer())
  return bufs
|
|
|
|
def get_fuzz_rawbuf_like(rawbuf, zero=False, size=None):
  """Allocate a new buffer with the same type and dtype as `rawbuf` on Device.DEFAULT.

  Args:
    rawbuf: existing buffer used as the template.
    zero: when True, zero bytes are copied into the new buffer.
    size: number of elements for the new buffer; defaults to `rawbuf.size`.

  Returns:
    The newly constructed buffer.
  """
  new_buf = type(rawbuf)(Device.DEFAULT, rawbuf.size if size is None else size, rawbuf.dtype)
  if zero:
    with Context(DEBUG=0):
      # bytearray(n) is already filled with null bytes, so no explicit memset is needed before copying.
      new_buf.copyin(memoryview(bytearray(new_buf.size * new_buf.dtype.itemsize)))
  return new_buf
|
|
|
|
def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
  """Compile `lin` for the module-level device and execute it once.

  When `rawbufs` is None the buffers come from bufs_from_lin; when `var_vals`
  is None every variable of the AST is bound to its minimum.

  Returns "PASS", "COMPILE_ERROR" or "EXEC_ERROR". On either error the AST,
  the applied opts and the traceback are printed.
  """
  def report(banner):
    # Called from inside an except block, so print_exc reports the exception being handled.
    print(lin.ast)
    print(lin.applied_opts)
    traceback.print_exc()
    print(banner)

  if rawbufs is None: rawbufs = bufs_from_lin(lin)
  if var_vals is None: var_vals = {v: v.min for v in lin.ast.vars()}

  # TODO: images needs required_optimization
  try:
    prg = device.to_program(lin) if isinstance(device, Compiled) else device.get_runner(lin.ast)
  except Exception:
    report("COMPILE FAILED!!")
    return "COMPILE_ERROR"

  try:
    prg.exec(rawbufs, var_vals)
  except Exception:
    report("EXEC FAILED!!")
    return "EXEC_ERROR"

  return "PASS"
|
|
|
|
|
|
def fuzz_linearizer(lin: Linearizer):
  """Apply optimization actions to `lin` and check each variant against the unoptimized kernel.

  The unoptimized kernel is run once to produce the reference output. Then, for up to DEPTH rounds,
  actions from get_linearizer_actions are applied (all of them when FUZZ_BEAM is set, otherwise one
  random choice per kernel), each new variant is executed, and its output is compared to the reference.

  Environment variables read: FUZZ_BEAM, FUZZ_MAX_SIZE, DEPTH.

  Returns:
    A defaultdict mapping a failure label to a list of (ast, applied_opts) tuples. Labels are
    "RAWBUFS_ERROR", "BASELINE_ERROR", "COMPARE_ERROR", or the non-"PASS" status returned by
    run_linearizer. It is empty when nothing failed or the kernel was skipped as too large.
  """
  random.seed(42)
  np.random.seed(42)
  print_tree(lin.ast)
  print(lin.colored_shape())
  seen_uops = {}
  last_lins = [lin]
  failures = defaultdict(list)

  FUZZ_BEAM = getenv("FUZZ_BEAM", 0)
  FUZZ_MAX_SIZE = getenv("FUZZ_MAX_SIZE", 0)
  if FUZZ_MAX_SIZE > 0 and prod(lin.full_shape) > FUZZ_MAX_SIZE:
    print("skipping large kernel")
    return failures

  # get baseline unoptimized output
  unoptimized = Linearizer(lin.ast)
  # Variable bindings are drawn once and reused for every variant below: the reference output is
  # only meaningful for kernels executed with the same variable values.
  var_vals = {v: random.randint(v.min, v.max) for v in lin.ast.vars()}

  try:
    rawbufs = get_fuzz_rawbufs(lin)
  except Exception:
    # buffer creation can fail (e.g. memory error); record it as a failure instead of crashing
    traceback.print_exc()
    print("RAWBUFS FAILED!!")
    failures["RAWBUFS_ERROR"].append((unoptimized.ast, unoptimized.applied_opts))
    return failures

  if run_linearizer(unoptimized, rawbufs, var_vals) != "PASS":
    failures["BASELINE_ERROR"].append((unoptimized.ast, unoptimized.applied_opts))
    return failures
  # copy, because the output buffer is replaced before every variant run
  ground_truth = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np).copy()

  for depth in range(getenv("DEPTH", 1 if FUZZ_BEAM else 10)):
    next_lins = []
    for lin in last_lins:
      actions = get_linearizer_actions(lin, include_0=False)
      if FUZZ_BEAM: print(f"testing {lin.applied_opts=} with {len(actions)} actions")
      if not actions: continue

      test_lins = list(actions.values())
      if not FUZZ_BEAM: test_lins = [random.choice(test_lins)]

      for test_lin in test_lins:
        if not FUZZ_BEAM and test_lin.applied_opts: print(f"applied opts: {test_lin.applied_opts}")

        # stop if kernel uops repeat
        tuops = tuplize_uops(test_lin.linearize().uops.uops)
        if tuops in seen_uops:
          continue
        seen_uops[tuops] = tuple(test_lin.applied_opts)

        if not FUZZ_BEAM: print(test_lin.colored_shape())
        # get a new output buffer
        rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True)
        if (msg := run_linearizer(test_lin, rawbufs, var_vals)) != "PASS":
          failures[msg].append((test_lin.ast, test_lin.applied_opts))
          continue

        result = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np)
        try:
          # compare the whole output buffer against the baseline output
          np.testing.assert_allclose(result, ground_truth, rtol=1e-2, atol=1e-2)
        except AssertionError:
          print(test_lin.ast)
          print(test_lin.applied_opts)
          traceback.print_exc()
          print("COMPARE FAILED!!")
          failures["COMPARE_ERROR"].append((test_lin.ast, test_lin.applied_opts))
          continue
        # only variants that passed are expanded further in the next round
        next_lins.append(test_lin)

    last_lins = next_lins
    if FUZZ_BEAM: print(f"depth={depth} total_lins={len(last_lins)} {failures=}")
  return failures
|
|
|
|
if __name__ == "__main__":
  # Fuzz the first FUZZ_N loaded kernels (all of them by default) and print a summary.
  ast_strs = load_worlds(filter_reduce=False, filter_novariable=False)
  print(f"{len(ast_strs)=}")
  limit = getenv("FUZZ_N", len(ast_strs))
  tested = 0
  failures = defaultdict(list)
  for i, ast in enumerate(ast_strs[:limit]):
    # IMAGE is only for GPU
    if "dtypes.image" in ast and Device.DEFAULT != "GPU": continue
    print(f"testing ast {i}")
    tested += 1
    kernel_failures = fuzz_linearizer(ast_str_to_lin(ast))
    # merge this kernel's failures into the global table, one entry at a time
    for label, entries in kernel_failures.items():
      for entry in entries:
        failures[label].append(entry)

  # detailed listing of every failure
  for msg, errors in failures.items():
    for idx, (failed_ast, opts) in enumerate(errors):
      print(f"{msg} {idx} AST: {failed_ast}")
      print(f"{msg} {idx} OPTS: {opts}\n")

  # summary: per-label counts, or a green message when nothing failed
  print(f"{tested=}")
  if not failures:
    print(colored("all passed", "green"))
  else:
    for msg, errors in failures.items():
      print(f"{msg}: {len(errors)}")
|