mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-06 04:35:00 -05:00
log optimized kernels and a script to compare with non-optimized ones (#3829)
* search: add BEAM_VERIFY option to validate search results; refactor the fuzz_linearizer comparison so it can be used for BEAM_VERIFY in device.py
* search: fix to verify the beam_search result and not the fastest
* search: fix typing and clean up
* device: remove imports from test and add LOGKERN options; LOGKERN output can be used with test/external/verify_kernel.py to validate correctness
* fix example in verify_kernel.py
* cleanup fixes
* fix to use f-strings
This commit is contained in:
85
test/external/fuzz_linearizer.py
vendored
85
test/external/fuzz_linearizer.py
vendored
@@ -1,15 +1,17 @@
|
||||
import random, traceback, ctypes
|
||||
from typing import List, Tuple
|
||||
from typing import List, Tuple, DefaultDict
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
from extra.optimization.helpers import load_worlds, ast_str_to_lin
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.codegen.linearizer import Linearizer, UOp
|
||||
from tinygrad.codegen.kernel import Opt
|
||||
from tinygrad.features.search import get_linearizer_actions, bufs_from_lin
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.features.graph import print_tree
|
||||
from tinygrad.helpers import getenv, from_mv, prod, colored, Context
|
||||
from tinygrad.device import Device, Compiled
|
||||
from tinygrad.codegen.linearizer import UOp
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.ops import LazyOp
|
||||
|
||||
def tuplize_uops(uops:List[UOp]) -> Tuple:
  """Convert a list of uops into an immutable nested tuple.

  Each entry becomes (uop, dtype, parent_positions, arg), where parent_positions
  holds the position inside `uops` of every element listed in that entry's `vin`.
  """
  flattened = []
  for node in uops:
    # parents are stored as positions in the list rather than as object references
    parent_positions = tuple(uops.index(src) for src in node.vin)
    flattened.append((node.uop, node.dtype, parent_positions, node.arg))
  return tuple(flattened)
|
||||
|
||||
@@ -24,7 +26,7 @@ def get_fuzz_rawbufs(lin):
|
||||
with Context(DEBUG=0):
|
||||
for rawbuf in rawbufs[1:]:
|
||||
t = Tensor.uniform((rawbuf.size,), dtype=rawbuf.dtype)
|
||||
rawbuf.copyin(t.realize().lazydata.realized.as_buffer())
|
||||
if isinstance(ld:=t.realize().lazydata, LazyBuffer) and ld.realized: rawbuf.copyin(ld.realized.as_buffer())
|
||||
return rawbufs
|
||||
|
||||
def get_fuzz_rawbuf_like(rawbuf, zero=False, size=None):
|
||||
@@ -38,32 +40,43 @@ def get_fuzz_rawbuf_like(rawbuf, zero=False, size=None):
|
||||
|
||||
def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
  """Compile and execute `lin` once and report the outcome as a status string.

  Args:
    lin: the linearizer (kernel) to run.
    rawbufs: buffers passed to the program; defaults to bufs_from_lin(lin).
    var_vals: variable bindings; defaults to each variable's `min`, taken from lin.ast[0].vars().

  Returns:
    "COMPILE_ERROR" if building the program raised, "EXEC_ERROR" if running it raised,
    otherwise "PASS". Failures are printed (ast, applied opts, traceback) and never re-raised.
  """
  if rawbufs is None: rawbufs = bufs_from_lin(lin)
  if var_vals is None: var_vals = {v: v.min for v in lin.ast[0].vars()}

  # TODO: images needs required_optimization
  try:
    # NOTE(review): `device` is a module-level name defined outside the visible hunk
    prg = device.to_program(lin)
  except Exception:
    print(lin.ast)
    print(lin.applied_opts)
    traceback.print_exc()
    print("COMPILE FAILED!!")
    return "COMPILE_ERROR"

  try:
    prg(rawbufs, var_vals, wait=True, do_update_stats=False)
  except Exception:
    print(lin.ast)
    print(lin.applied_opts)
    traceback.print_exc()
    print("EXEC FAILED!!")
    return "EXEC_ERROR"

  return "PASS"
|
||||
|
||||
def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_truth=None, rtol=1e-2, atol=1e-2):
  """Run `lin` and compare its output buffer against a ground-truth result.

  Args:
    lin: the (optimized) linearizer under test.
    rawbufs: buffers to run with. When None they come from get_fuzz_rawbufs(lin); otherwise
      rawbufs[0] is replaced in place with a fresh zeroed output buffer.
    var_vals: variable bindings. When None each variable gets a random value in [min, max]
      (or `min` when `max` is not an int).
    ground_truth: expected output as a numpy array. When None it is produced by running an
      unoptimized Linearizer built from lin.ast, with only required_optimizations() applied.
    rtol, atol: tolerances forwarded to np.allclose.

  Returns:
    (status, rawbufs, var_vals, ground_truth) so callers can reuse the inputs across calls.
    status is "PASS", "COMPARE_ERROR", "RAWBUFS_ERROR", "BASELINE_ERROR", or the non-"PASS"
    status returned by run_linearizer for `lin`.
  """
  try:
    if rawbufs is None:
      rawbufs = get_fuzz_rawbufs(lin)
    else:
      rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True) # get a new output buffer
  except Exception:
    # Exception rather than BaseException, so Ctrl-C / SystemExit still abort the run
    traceback.print_exc()
    print("RAWBUFS FAILED!!")
    return ("RAWBUFS_ERROR", rawbufs, var_vals, ground_truth,)

  if var_vals is None:
    var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast[0].vars()}

  if ground_truth is None:
    unoptimized = Linearizer(*lin.ast)
    unoptimized.required_optimizations()
    if run_linearizer(unoptimized, rawbufs, var_vals) != "PASS":
      return ("BASELINE_ERROR", rawbufs, var_vals, ground_truth,)
    # copy: the output buffer is replaced below, and the baseline values must survive that
    ground_truth = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np).copy()
    # The baseline just wrote the expected values into rawbufs[0]. Swap in a zeroed buffer so a
    # kernel under test that skips some outputs cannot pass on the baseline's leftovers.
    try:
      rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True)
    except Exception:
      traceback.print_exc()
      print("RAWBUFS FAILED!!")
      return ("RAWBUFS_ERROR", rawbufs, var_vals, ground_truth,)

  if (run_msg := run_linearizer(lin, rawbufs, var_vals)) != "PASS":
    return (run_msg, rawbufs, var_vals, ground_truth,)

  result = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np)
  status = "PASS" if np.allclose(result, ground_truth, rtol=rtol, atol=atol) else "COMPARE_ERROR"
  return (status, rawbufs, var_vals, ground_truth,)
|
||||
|
||||
def fuzz_linearizer(lin: Linearizer):
|
||||
SEED = getenv("SEED", 42)
|
||||
@@ -73,7 +86,8 @@ def fuzz_linearizer(lin: Linearizer):
|
||||
print(lin.colored_shape())
|
||||
seen_uops = {}
|
||||
last_lins = [lin]
|
||||
failures = defaultdict(list)
|
||||
failures:DefaultDict[str, List[Tuple[Tuple[LazyOp,...],List[Opt]]]] = defaultdict(list)
|
||||
rawbufs, var_vals, ground_truth = None, None, None
|
||||
|
||||
FUZZ_BEAM = getenv("FUZZ_BEAM", 0)
|
||||
FUZZ_MAX_SIZE = getenv("FUZZ_MAX_SIZE", 0)
|
||||
@@ -81,23 +95,6 @@ def fuzz_linearizer(lin: Linearizer):
|
||||
print("skipping large kernel")
|
||||
return failures
|
||||
|
||||
# get baseline unoptimized output
|
||||
unoptimized = Linearizer(*lin.ast)
|
||||
var_vals = {v: random.randint(v.min, v.max) for v in lin.ast[0].vars()}
|
||||
|
||||
try:
|
||||
rawbufs = get_fuzz_rawbufs(lin)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
print("RAWBUFS FAILED!!")
|
||||
failures["RAWBUFS_ERROR"].append((unoptimized.ast, unoptimized.applied_opts))
|
||||
return failures
|
||||
|
||||
if run_linearizer(unoptimized, rawbufs, var_vals) != "PASS":
|
||||
failures["BASELINE_ERROR"].append((unoptimized.ast, unoptimized.applied_opts))
|
||||
return failures
|
||||
ground_truth = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np).copy()
|
||||
|
||||
for depth in range(getenv("DEPTH", 1 if FUZZ_BEAM else 10)):
|
||||
next_lins = []
|
||||
for lin in last_lins:
|
||||
@@ -118,23 +115,15 @@ def fuzz_linearizer(lin: Linearizer):
|
||||
seen_uops[tuops] = tuple(test_lin.applied_opts)
|
||||
|
||||
if not FUZZ_BEAM: print(test_lin.colored_shape())
|
||||
# get a new output buffer
|
||||
rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True)
|
||||
if (msg := run_linearizer(test_lin, rawbufs, var_vals)) != "PASS":
|
||||
|
||||
(msg, rawbufs, var_vals, ground_truth) = compare_linearizer(test_lin, rawbufs, var_vals, ground_truth)
|
||||
if msg != "PASS":
|
||||
print(test_lin.ast)
|
||||
print(test_lin.applied_opts)
|
||||
print(msg)
|
||||
failures[msg].append((test_lin.ast, test_lin.applied_opts))
|
||||
continue
|
||||
|
||||
result = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np)
|
||||
try:
|
||||
# compare memoryviews directly
|
||||
np.testing.assert_allclose(result, ground_truth, rtol=1e-2, atol=1e-2)
|
||||
except AssertionError:
|
||||
print(test_lin.ast)
|
||||
print(test_lin.applied_opts)
|
||||
traceback.print_exc()
|
||||
print("COMPARE FAILED!!")
|
||||
failures["COMPARE_ERROR"].append((test_lin.ast, test_lin.applied_opts))
|
||||
continue
|
||||
next_lins.append(test_lin)
|
||||
|
||||
last_lins = next_lins
|
||||
|
||||
57
test/external/verify_kernel.py
vendored
Normal file
57
test/external/verify_kernel.py
vendored
Normal file
@@ -0,0 +1,57 @@
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
from extra.optimization.helpers import kern_str_to_lin
|
||||
from test.external.fuzz_linearizer import compare_linearizer
|
||||
from tinygrad.helpers import colored
|
||||
from tinygrad.features.graph import print_tree
|
||||
|
||||
# Use this with the LOGKERN options to verify that all executed kernels are valid and evaluate to the same ground truth results
|
||||
|
||||
# Example for GPT2:
|
||||
# 1) Run the model to log all kernels: `PYTHONPATH=. LOGKERN=/tmp/gpt2_kerns.txt JIT=1 HALF=1 BEAM=2 CACHELEVEL=0 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing` # noqa: E501
|
||||
# 2) Validate the kernel correctness: `PYTHONPATH=. python3 ./test/external/verify_kernel.py --file /tmp/gpt2_kerns.txt`
|
||||
|
||||
if __name__ == "__main__":
  # Command-line entry point: load kernel descriptions, run each through compare_linearizer,
  # print a per-category failure report, and raise if any kernel did not pass.
  parser = argparse.ArgumentParser(description="Verify the correctness of one or more kernel", formatter_class=argparse.ArgumentDefaultsHelpFormatter) # noqa: E501
  parser.add_argument("--kernel", type=str, default=None, help="a string of a tuple of (ast, applied_opts,)")
  parser.add_argument("--file", type=str, default=None, help="a file containing a tuple of ast and applied_opts, one per line")
  parser.add_argument("--rtol", type=float, default=1e-2, help="relative tolerance for numerical comparison")
  parser.add_argument("--atol", type=float, default=1e-2, help="absolute tolerance for numerical comparison")
  args = parser.parse_args()

  # --kernel takes precedence over --file when both are given
  if args.kernel is not None:
    print("loading kernel from args")
    kern_strs = [args.kernel]
  elif args.file is not None:
    print(f"loading kernel from file '{args.file}'")
    with open(args.file, 'r') as file:
      # skip blank lines (e.g. a trailing empty line) so they are not parsed as kernels
      kern_strs = [line for line in file if line.strip()]
  else:
    raise RuntimeError("no kernel specified; use --kernel or --file options")

  print(f"verifying {len(kern_strs)} kernels")

  failed_ids = []                # positions (within kern_strs) of kernels that did not pass
  failures = defaultdict(list)   # status message -> list of (ast, applied_opts)
  for i, kern_str in enumerate(kern_strs):
    print(f"testing kernel {i}")
    test_lin = kern_str_to_lin(kern_str)
    for op in test_lin.ast: print_tree(op)
    print(test_lin.colored_shape())
    # rawbufs, var_vals and ground_truth are passed as None, so each kernel is checked independently
    if (msg:=compare_linearizer(test_lin, None, None, None, rtol=args.rtol, atol=args.atol)[0]) != "PASS":
      failed_ids.append(i)
      failures[msg].append((test_lin.ast, test_lin.applied_opts))

  # detailed report: here `i` is the index within each failure category, not the kernel id
  for msg, errors in failures.items():
    for i, (ast, opts) in enumerate(errors):
      print(f"{msg} {i} AST: {ast}")
      print(f"{msg} {i} OPTS: {opts}\n")

  print(f"tested {len(kern_strs)} kernels")
  if failures:
    print(f"{failed_ids=}")
    for msg, errors in failures.items():
      print(f"{msg}: {len(errors)}")
    raise RuntimeError(f"failed on {len(failed_ids)} kernels")
  else:
    print(colored("all passed", "green"))
|
||||
Reference in New Issue
Block a user