log optimized kernels and a script to compare with non-optimized ones (#3829)

* search: add BEAM_VERIFY option to validate search results

refactor fuzz_linearizer comparison to allow it to be used for
BEAM_VERIFY in device.py

* search: fix to verify the beam_search result and not the fastest

* search: fix typing and clean up

* device: remove imports from test and add LOGKERN options

LOGKERN output can be used with test/external/verify_kernel.py
to validate correctness

* fix example in verify_kernel.py

* cleanup fixes

* fix to use f-strings
This commit is contained in:
Francis Lam
2024-03-20 16:22:08 -07:00
committed by GitHub
parent 9d1d08fbb0
commit 6d5dec2fef
4 changed files with 105 additions and 48 deletions

View File

@@ -1,15 +1,17 @@
import random, traceback, ctypes
from typing import List, Tuple
from typing import List, Tuple, DefaultDict
import numpy as np
from collections import defaultdict
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.codegen.linearizer import Linearizer, UOp
from tinygrad.codegen.kernel import Opt
from tinygrad.features.search import get_linearizer_actions, bufs_from_lin
from tinygrad.tensor import Tensor
from tinygrad.features.graph import print_tree
from tinygrad.helpers import getenv, from_mv, prod, colored, Context
from tinygrad.device import Device, Compiled
from tinygrad.codegen.linearizer import UOp
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import LazyOp
def tuplize_uops(uops:List[UOp]) -> Tuple:
    """Convert a list of uops into a nested tuple.

    Every uop contributes one ``(uop, dtype, inputs, arg)`` entry, where
    ``inputs`` holds, for each element of the uop's ``vin``, the position
    that element has in ``uops`` (as reported by ``uops.index``).
    """
    entries = []
    for op in uops:
        input_positions = tuple(uops.index(src) for src in op.vin)
        entries.append((op.uop, op.dtype, input_positions, op.arg))
    return tuple(entries)
@@ -24,7 +26,7 @@ def get_fuzz_rawbufs(lin):
with Context(DEBUG=0):
for rawbuf in rawbufs[1:]:
t = Tensor.uniform((rawbuf.size,), dtype=rawbuf.dtype)
rawbuf.copyin(t.realize().lazydata.realized.as_buffer())
if isinstance(ld:=t.realize().lazydata, LazyBuffer) and ld.realized: rawbuf.copyin(ld.realized.as_buffer())
return rawbufs
def get_fuzz_rawbuf_like(rawbuf, zero=False, size=None):
@@ -38,32 +40,43 @@ def get_fuzz_rawbuf_like(rawbuf, zero=False, size=None):
def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
    """Compile ``lin`` with the module-level ``device`` and run it once.

    Args:
        lin: the linearizer to compile and execute.
        rawbufs: buffers handed to the program; when None they are created
            with ``bufs_from_lin(lin)``.
        var_vals: values for the AST variables; when None every variable
            returned by ``lin.ast[0].vars()`` is bound to its ``min``.

    Returns:
        "PASS" when both steps succeed, "COMPILE_ERROR" when building the
        program raised, "EXEC_ERROR" when running it raised. Failures are
        printed (ast, applied opts, traceback) and never re-raised.
    """
    if rawbufs is None: rawbufs = bufs_from_lin(lin)
    # lin.ast is a sequence of ops (it is unpacked with * elsewhere), so the variables come from its first entry
    if var_vals is None: var_vals = {v: v.min for v in lin.ast[0].vars()}

    # TODO: images needs required_optimization
    try:
        prg = device.to_program(lin)
    except Exception:
        print(lin.ast)
        print(lin.applied_opts)
        traceback.print_exc()
        print("COMPILE FAILED!!")
        return "COMPILE_ERROR"

    try:
        # run exactly once; wait=True blocks until the run has finished
        prg(rawbufs, var_vals, wait=True, do_update_stats=False)
    except Exception:
        print(lin.ast)
        print(lin.applied_opts)
        traceback.print_exc()
        print("EXEC FAILED!!")
        return "EXEC_ERROR"

    return "PASS"
def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_truth=None, rtol=1e-2, atol=1e-2):
    """Run ``lin`` and compare its output with an unoptimized baseline.

    Args:
        lin: the (optimized) linearizer under test.
        rawbufs: buffers to run with. When None, fresh ones come from
            ``get_fuzz_rawbufs(lin)``; otherwise ``rawbufs[0]`` is replaced
            in place with a new zeroed output buffer.
        var_vals: variable bindings. When None, each variable of
            ``lin.ast[0]`` gets a random int in ``[v.min, v.max]`` (or
            ``v.min`` when ``v.max`` is not an int).
        ground_truth: expected output array. When None it is produced by
            running ``Linearizer(*lin.ast)`` with only its required
            optimizations and copying out ``rawbufs[0]``.
        rtol: relative tolerance passed to ``np.allclose``.
        atol: absolute tolerance passed to ``np.allclose``.

    Returns:
        A tuple ``(status, rawbufs, var_vals, ground_truth)`` so callers can
        reuse the buffers, bindings and baseline on the next call. ``status``
        is "PASS", "RAWBUFS_ERROR", "BASELINE_ERROR", "COMPARE_ERROR", or the
        non-"PASS" value returned by ``run_linearizer`` for ``lin``.
    """
    try:
        if rawbufs is None:
            rawbufs = get_fuzz_rawbufs(lin)
        else:
            rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True)  # get a new output buffer
    except Exception:
        # Exception (not BaseException) so Ctrl-C / SystemExit still stop the run;
        # print the cause instead of swallowing it, as run_linearizer does
        traceback.print_exc()
        print("RAWBUFS FAILED!!")
        return ("RAWBUFS_ERROR", rawbufs, var_vals, ground_truth)

    if var_vals is None:
        var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast[0].vars()}

    if ground_truth is None:
        unoptimized = Linearizer(*lin.ast)
        unoptimized.required_optimizations()
        if run_linearizer(unoptimized, rawbufs, var_vals) != "PASS":
            return ("BASELINE_ERROR", rawbufs, var_vals, ground_truth)
        ground_truth = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np).copy()
        # the baseline just wrote the expected values into rawbufs[0]; swap in a zeroed
        # buffer so a kernel that fails to write its output cannot pass on stale data
        rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True)

    if (run_msg := run_linearizer(lin, rawbufs, var_vals)) != "PASS":
        return (run_msg, rawbufs, var_vals, ground_truth)

    result = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np)
    status = "PASS" if np.allclose(result, ground_truth, rtol=rtol, atol=atol) else "COMPARE_ERROR"
    return (status, rawbufs, var_vals, ground_truth)
def fuzz_linearizer(lin: Linearizer):
SEED = getenv("SEED", 42)
@@ -73,7 +86,8 @@ def fuzz_linearizer(lin: Linearizer):
print(lin.colored_shape())
seen_uops = {}
last_lins = [lin]
failures = defaultdict(list)
failures:DefaultDict[str, List[Tuple[Tuple[LazyOp,...],List[Opt]]]] = defaultdict(list)
rawbufs, var_vals, ground_truth = None, None, None
FUZZ_BEAM = getenv("FUZZ_BEAM", 0)
FUZZ_MAX_SIZE = getenv("FUZZ_MAX_SIZE", 0)
@@ -81,23 +95,6 @@ def fuzz_linearizer(lin: Linearizer):
print("skipping large kernel")
return failures
# get baseline unoptimized output
unoptimized = Linearizer(*lin.ast)
var_vals = {v: random.randint(v.min, v.max) for v in lin.ast[0].vars()}
try:
rawbufs = get_fuzz_rawbufs(lin)
except Exception:
traceback.print_exc()
print("RAWBUFS FAILED!!")
failures["RAWBUFS_ERROR"].append((unoptimized.ast, unoptimized.applied_opts))
return failures
if run_linearizer(unoptimized, rawbufs, var_vals) != "PASS":
failures["BASELINE_ERROR"].append((unoptimized.ast, unoptimized.applied_opts))
return failures
ground_truth = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np).copy()
for depth in range(getenv("DEPTH", 1 if FUZZ_BEAM else 10)):
next_lins = []
for lin in last_lins:
@@ -118,23 +115,15 @@ def fuzz_linearizer(lin: Linearizer):
seen_uops[tuops] = tuple(test_lin.applied_opts)
if not FUZZ_BEAM: print(test_lin.colored_shape())
# get a new output buffer
rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True)
if (msg := run_linearizer(test_lin, rawbufs, var_vals)) != "PASS":
(msg, rawbufs, var_vals, ground_truth) = compare_linearizer(test_lin, rawbufs, var_vals, ground_truth)
if msg != "PASS":
print(test_lin.ast)
print(test_lin.applied_opts)
print(msg)
failures[msg].append((test_lin.ast, test_lin.applied_opts))
continue
result = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np)
try:
# compare memoryviews directly
np.testing.assert_allclose(result, ground_truth, rtol=1e-2, atol=1e-2)
except AssertionError:
print(test_lin.ast)
print(test_lin.applied_opts)
traceback.print_exc()
print("COMPARE FAILED!!")
failures["COMPARE_ERROR"].append((test_lin.ast, test_lin.applied_opts))
continue
next_lins.append(test_lin)
last_lins = next_lins

57
test/external/verify_kernel.py vendored Normal file
View File

@@ -0,0 +1,57 @@
import argparse
from collections import defaultdict
from extra.optimization.helpers import kern_str_to_lin
from test.external.fuzz_linearizer import compare_linearizer
from tinygrad.helpers import colored
from tinygrad.features.graph import print_tree
# Use this with the LOGKERN options to verify that all executed kernels are valid and evaluate to the same ground truth results
# Example for GPT2:
# 1) Run the model to log all kernels: `PYTHONPATH=. LOGKERN=/tmp/gpt2_kerns.txt JIT=1 HALF=1 BEAM=2 CACHELEVEL=0 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing` # noqa: E501
# 2) Validate the kernel correctness: `PYTHONPATH=. python3 ./test/external/verify_kernel.py --file /tmp/gpt2_kerns.txt`
# Script entry point: load kernel strings from --kernel or --file, check each one
# with compare_linearizer, print a per-error report and raise if any kernel failed.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Verify the correctness of one or more kernel", formatter_class=argparse.ArgumentDefaultsHelpFormatter) # noqa: E501
    parser.add_argument("--kernel", type=str, default=None, help="a string of a tuple of (ast, applied_opts,)")
    parser.add_argument("--file", type=str, default=None, help="a file containing a tuple of ast and applied_opts, one per line")
    parser.add_argument("--rtol", type=float, default=1e-2, help="relative tolerance for numerical comparison")
    parser.add_argument("--atol", type=float, default=1e-2, help="absolute tolerance for numerical comparison")
    args = parser.parse_args()

    # --kernel takes precedence over --file when both are given
    if args.kernel is not None:
        print("loading kernel from args")
        kern_strs = [args.kernel]
    elif args.file is not None:
        print(f"loading kernel from file '{args.file}'")
        with open(args.file, 'r') as file:
            # one kernel per line; skip blank lines (e.g. a trailing newline) since they are not kernels
            kern_strs = [line for line in file if line.strip()]
    else:
        raise RuntimeError("no kernel specified; use --kernel or --file options")

    print(f"verifying {len(kern_strs)} kernels")
    failed_ids = []  # indices into kern_strs of the kernels that did not pass
    failures = defaultdict(list)  # status message -> list of (ast, applied_opts)
    for i, kern_str in enumerate(kern_strs):
        print(f"testing kernel {i}")
        test_lin = kern_str_to_lin(kern_str)
        for op in test_lin.ast: print_tree(op)
        print(test_lin.colored_shape())
        # rawbufs/var_vals/ground_truth are None, so every kernel is checked with fresh inputs and its own baseline
        if (msg:=compare_linearizer(test_lin, None, None, None, rtol=args.rtol, atol=args.atol)[0]) != "PASS":
            failed_ids.append(i)
            failures[msg].append((test_lin.ast, test_lin.applied_opts))

    # detailed report: every failing ast / opts pair, grouped by status message
    for msg, errors in failures.items():
        for i, (ast, opts) in enumerate(errors):
            print(f"{msg} {i} AST: {ast}")
            print(f"{msg} {i} OPTS: {opts}\n")

    print(f"tested {len(kern_strs)} kernels")
    if failures:
        print(f"{failed_ids=}")
        for msg, errors in failures.items():
            print(f"{msg}: {len(errors)}")
        raise RuntimeError(f"failed on {len(failed_ids)} kernels")
    else:
        print(colored("all passed", "green"))