reuse fuzz_linearizer.compare_linearizer in test_linearizer_failures (#3861)

This commit is contained in:
chenyu
2024-03-21 14:12:27 -04:00
committed by GitHub
parent 33dd99acf4
commit 30fa03243e
2 changed files with 20 additions and 15 deletions

View File

@@ -25,6 +25,7 @@ def get_fuzz_rawbufs(lin):
rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True, size=rawbufs[0].size+RED_AREA_SIZE)
with Context(DEBUG=0):
for rawbuf in rawbufs[1:]:
# TODO: better rand values based on dtype
t = Tensor.uniform((rawbuf.size,), dtype=rawbuf.dtype)
if isinstance(ld:=t.realize().lazydata, LazyBuffer) and ld.realized: rawbuf.copyin(ld.realized.as_buffer())
return rawbufs
@@ -58,6 +59,7 @@ def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
return "PASS"
def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_truth=None, rtol=1e-2, atol=1e-2):
# TODO: raise specific fuzzing errors instead of str, and propagate the error message
try:
if rawbufs is None:
rawbufs = get_fuzz_rawbufs(lin)
@@ -65,7 +67,11 @@ def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_trut
rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True) # get a new output buffer
except BaseException:
return ("RAWBUFS_ERROR", rawbufs, var_vals, ground_truth,)
if var_vals is None: var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast[0].vars()}
if var_vals is None:
# TODO: handle symbolic max case
var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast[0].vars()}
if ground_truth is None:
unoptimized = Linearizer(*lin.ast)
unoptimized.required_optimizations()
@@ -77,7 +83,13 @@ def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_trut
if (run_msg := run_linearizer(lin, rawbufs, var_vals)) != "PASS":
return (run_msg, rawbufs, var_vals, ground_truth,)
result = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np)
return ("PASS" if np.allclose(result, ground_truth, rtol=rtol, atol=atol) else "COMPARE_ERROR", rawbufs, var_vals, ground_truth,)
try:
np.testing.assert_allclose(result, ground_truth, rtol=rtol, atol=atol)
except AssertionError:
return ("COMPARE_ERROR", rawbufs, var_vals, ground_truth,)
return ("PASS", rawbufs, var_vals, ground_truth,)
def fuzz_linearizer(lin: Linearizer):
SEED = getenv("SEED", 42)