explicit opts for test_linearizer_failures (#2299)

* explicit opts for test_linearizer_failures

* typo

* update the invalid check
This commit is contained in:
chenyu
2023-11-14 11:52:38 -05:00
committed by GitHub
parent 8916028ddd
commit fac8633ba8
2 changed files with 58 additions and 45 deletions

View File

@@ -1,11 +1,11 @@
import random, traceback
import numpy as np
from collections import Counter, defaultdict
from collections import Counter
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.search import get_linearizer_actions, bufs_from_lin, tuplize_uops
from tinygrad.graph import print_tree
from tinygrad.helpers import ImageDType, prod, getenv
from tinygrad.helpers import getenv
from tinygrad.ops import Device, Compiled, Interpreted
from tinygrad.lazy import vars_from_ast
@@ -19,6 +19,37 @@ class LB:
self.dtype = dtype
def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
if rawbufs is None: rawbufs = bufs_from_lin(lin)
if var_vals is None: var_vals = {v: v.min for v in vars_from_ast(lin.ast)}
# TODO: images needs required_optimization
if isinstance(device, Compiled):
try:
prg = device.to_program(lin)
except:
print(lin.ast)
traceback.print_exc()
print("COMPILE FAILED!!")
return "COMPILE_ERROR"
try:
prg.exec(rawbufs, var_vals, force_wait=True)
except:
print(lin.ast)
traceback.print_exc()
print("EXEC FAILED!!")
return "EXEC_ERROR"
else:
try:
device.exec_ast(lin.ast, output=LB(rawbufs[0], rawbufs[0].dtype), inputs=[LB(buf, buf.dtype) for buf in rawbufs[1:]])
except Exception as e:
print(lin.ast)
traceback.print_exc()
return str(type(e))
return "PASS"
def fuzz_linearizer(lin: Linearizer):
random.seed(42)
np.random.seed(42)
@@ -44,33 +75,11 @@ def fuzz_linearizer(lin: Linearizer):
# get a new output buffer
rawbufs[0] = type(rawbufs[0])(rawbufs[0].size, rawbufs[0].dtype)
var_vals = {v: random.randint(v.min, v.max) for v in vars_from_ast(lin.ast)}
# TODO: images needs required_optimization
if isinstance(device, Compiled):
try:
prg = device.to_program(lin)
except:
print(lin.ast)
traceback.print_exc()
print("COMPILE FAILED!!")
return "COMPILE_ERROR"
try:
prg.exec(rawbufs, var_vals, force_wait=True)
except:
print(lin.ast)
traceback.print_exc()
print("EXEC FAILED!!")
return "EXEC_ERROR"
else:
try:
device.exec_ast(lin.ast, output=LB(rawbufs[0], rawbufs[0].dtype), inputs=[LB(buf, buf.dtype) for buf in rawbufs[1:]])
except Exception as e:
print(lin.ast)
traceback.print_exc()
return e
if (msg := run_linearizer(lin, rawbufs, var_vals)) != "PASS":
print(f"{lin.applied_opts=}")
return msg
result = rawbufs[0].toCPU()
if ground_truth is None:
ground_truth = result
else:
@@ -79,11 +88,8 @@ def fuzz_linearizer(lin: Linearizer):
except AssertionError:
print(lin.ast)
traceback.print_exc()
print(f"{lin.applied_opts=}")
return "NOT_ALLCLOSE"
except Exception as e:
print(lin.ast)
traceback.print_exc()
return e
return "PASS"