mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-06 12:44:58 -05:00
reuse fuzz_linearizer.compare_linearizer in test_linearizer_failures (#3861)
This commit is contained in:
16
test/external/fuzz_linearizer.py
vendored
16
test/external/fuzz_linearizer.py
vendored
@@ -25,6 +25,7 @@ def get_fuzz_rawbufs(lin):
|
||||
rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True, size=rawbufs[0].size+RED_AREA_SIZE)
|
||||
with Context(DEBUG=0):
|
||||
for rawbuf in rawbufs[1:]:
|
||||
# TODO: better rand values based on dtype
|
||||
t = Tensor.uniform((rawbuf.size,), dtype=rawbuf.dtype)
|
||||
if isinstance(ld:=t.realize().lazydata, LazyBuffer) and ld.realized: rawbuf.copyin(ld.realized.as_buffer())
|
||||
return rawbufs
|
||||
@@ -58,6 +59,7 @@ def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
|
||||
return "PASS"
|
||||
|
||||
def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_truth=None, rtol=1e-2, atol=1e-2):
|
||||
# TODO: raise specific fuzzing errors instead of str, and propagate the error message
|
||||
try:
|
||||
if rawbufs is None:
|
||||
rawbufs = get_fuzz_rawbufs(lin)
|
||||
@@ -65,7 +67,11 @@ def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_trut
|
||||
rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True) # get a new output buffer
|
||||
except BaseException:
|
||||
return ("RAWBUFS_ERROR", rawbufs, var_vals, ground_truth,)
|
||||
if var_vals is None: var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast[0].vars()}
|
||||
|
||||
if var_vals is None:
|
||||
# TODO: handle symbolic max case
|
||||
var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast[0].vars()}
|
||||
|
||||
if ground_truth is None:
|
||||
unoptimized = Linearizer(*lin.ast)
|
||||
unoptimized.required_optimizations()
|
||||
@@ -77,7 +83,13 @@ def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_trut
|
||||
if (run_msg := run_linearizer(lin, rawbufs, var_vals)) != "PASS":
|
||||
return (run_msg, rawbufs, var_vals, ground_truth,)
|
||||
result = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np)
|
||||
return ("PASS" if np.allclose(result, ground_truth, rtol=rtol, atol=atol) else "COMPARE_ERROR", rawbufs, var_vals, ground_truth,)
|
||||
|
||||
try:
|
||||
np.testing.assert_allclose(result, ground_truth, rtol=rtol, atol=atol)
|
||||
except AssertionError:
|
||||
return ("COMPARE_ERROR", rawbufs, var_vals, ground_truth,)
|
||||
|
||||
return ("PASS", rawbufs, var_vals, ground_truth,)
|
||||
|
||||
def fuzz_linearizer(lin: Linearizer):
|
||||
SEED = getenv("SEED", 42)
|
||||
|
||||
Reference in New Issue
Block a user