mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-05 12:15:05 -05:00
update fuzz_linearizer (#3648)
included non-reduce kernel and kernel with variables. green msg when everything passed it's possible that creating rawbufs failed due to memory error, included that in failure cases
This commit is contained in:
21
test/external/fuzz_linearizer.py
vendored
21
test/external/fuzz_linearizer.py
vendored
@@ -7,7 +7,7 @@ from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.features.search import get_linearizer_actions, bufs_from_lin
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.features.graph import print_tree
|
||||
from tinygrad.helpers import getenv, from_mv, prod, Context
|
||||
from tinygrad.helpers import getenv, from_mv, prod, colored, Context
|
||||
from tinygrad.device import Device, Compiled
|
||||
from tinygrad.codegen.linearizer import UOp
|
||||
|
||||
@@ -70,7 +70,6 @@ def fuzz_linearizer(lin: Linearizer):
|
||||
np.random.seed(42)
|
||||
print_tree(lin.ast)
|
||||
print(lin.colored_shape())
|
||||
rawbufs = get_fuzz_rawbufs(lin)
|
||||
seen_uops = {}
|
||||
last_lins = [lin]
|
||||
failures = defaultdict(list)
|
||||
@@ -84,6 +83,15 @@ def fuzz_linearizer(lin: Linearizer):
|
||||
# get baseline unoptimized output
|
||||
unoptimized = Linearizer(lin.ast)
|
||||
var_vals = {v: random.randint(v.min, v.max) for v in lin.ast.vars()}
|
||||
|
||||
try:
|
||||
rawbufs = get_fuzz_rawbufs(lin)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
print("RAWBUFS FAILED!!")
|
||||
failures["RAWBUFS_ERROR"].append((unoptimized.ast, unoptimized.applied_opts))
|
||||
return failures
|
||||
|
||||
if run_linearizer(unoptimized, rawbufs, var_vals) != "PASS":
|
||||
failures["BASELINE_ERROR"].append((unoptimized.ast, unoptimized.applied_opts))
|
||||
return failures
|
||||
@@ -134,7 +142,7 @@ def fuzz_linearizer(lin: Linearizer):
|
||||
return failures
|
||||
|
||||
if __name__ == "__main__":
|
||||
ast_strs = load_worlds()
|
||||
ast_strs = load_worlds(filter_reduce=False, filter_novariable=False)
|
||||
print(f"{len(ast_strs)=}")
|
||||
tested = 0
|
||||
failures = defaultdict(list)
|
||||
@@ -151,5 +159,8 @@ if __name__ == "__main__":
|
||||
print(f"{msg} {i} AST: {ast}")
|
||||
print(f"{msg} {i} OPTS: {opts}\n")
|
||||
print(f"{tested=}")
|
||||
for msg, errors in failures.items():
|
||||
print(f"{msg}: {len(errors)}")
|
||||
if failures:
|
||||
for msg, errors in failures.items():
|
||||
print(f"{msg}: {len(errors)}")
|
||||
else:
|
||||
print(colored("all passed", "green"))
|
||||
|
||||
Reference in New Issue
Block a user