mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-07 03:00:26 -04:00
reuse fuzz_linearizer.compare_linearizer in test_linearizer_failures (#3861)

test/external/fuzz_linearizer.py (16 changed lines)

@@ -25,6 +25,7 @@ def get_fuzz_rawbufs(lin):
   rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True, size=rawbufs[0].size+RED_AREA_SIZE)
   with Context(DEBUG=0):
     for rawbuf in rawbufs[1:]:
+      # TODO: better rand values based on dtype
       t = Tensor.uniform((rawbuf.size,), dtype=rawbuf.dtype)
       if isinstance(ld:=t.realize().lazydata, LazyBuffer) and ld.realized: rawbuf.copyin(ld.realized.as_buffer())
   return rawbufs

@@ -58,6 +59,7 @@ def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
   return "PASS"

 def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_truth=None, rtol=1e-2, atol=1e-2):
+  # TODO: raise specific fuzzing errors instead of str, and propagate the error message
   try:
     if rawbufs is None:
       rawbufs = get_fuzz_rawbufs(lin)

@@ -65,7 +67,11 @@ def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_trut
       rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True) # get a new output buffer
   except BaseException:
     return ("RAWBUFS_ERROR", rawbufs, var_vals, ground_truth,)
-  if var_vals is None: var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast[0].vars()}
+
+  if var_vals is None:
+    # TODO: handle symbolic max case
+    var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast[0].vars()}
+
   if ground_truth is None:
     unoptimized = Linearizer(*lin.ast)
     unoptimized.required_optimizations()

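The var_vals fallback above samples a concrete value for every symbolic variable in the kernel's ast; when a variable's max is itself symbolic it falls back to the minimum, which is what the new TODO flags. A minimal sketch of the dictionary it builds, assuming tinygrad's Variable from tinygrad.shape.symbolic (that import is not part of this diff):

# Illustration only, not part of the commit: the shape of the var_vals fallback.
import random
from tinygrad.shape.symbolic import Variable

v = Variable("i", 1, 10)  # a symbolic dimension with integer bounds
var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min)}
# maps the Variable to a concrete int in [1, 10]; a symbolic v.max collapses to v.min (the TODO case)
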
@@ -77,7 +83,13 @@ def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_trut
   if (run_msg := run_linearizer(lin, rawbufs, var_vals)) != "PASS":
     return (run_msg, rawbufs, var_vals, ground_truth,)
   result = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np)
-  return ("PASS" if np.allclose(result, ground_truth, rtol=rtol, atol=atol) else "COMPARE_ERROR", rawbufs, var_vals, ground_truth,)
+
+  try:
+    np.testing.assert_allclose(result, ground_truth, rtol=rtol, atol=atol)
+  except AssertionError:
+    return ("COMPARE_ERROR", rawbufs, var_vals, ground_truth,)
+
+  return ("PASS", rawbufs, var_vals, ground_truth,)

 def fuzz_linearizer(lin: Linearizer):
   SEED = getenv("SEED", 42)

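Taken together, the hunks above make compare_linearizer the single entry point that allocates buffers, runs the kernel, and checks it against the unoptimized ground truth, reporting a status string instead of raising. A minimal sketch of how a caller can drive it, assuming only the signature and status strings visible in this diff; check_kernel itself is a hypothetical helper, not code from the commit:

# Hypothetical driver, for illustration: interpret compare_linearizer's 4-tuple.
from test.external.fuzz_linearizer import compare_linearizer

def check_kernel(lin) -> bool:
  status, rawbufs, var_vals, ground_truth = compare_linearizer(lin)
  if status == "RAWBUFS_ERROR":   # buffer setup raised before the kernel ran
    return False
  if status == "COMPARE_ERROR":   # kernel ran but diverged from the ground truth
    return False
  return status == "PASS"         # any other string is an error from run_linearizer
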
test/test_linearizer_failures.py

@@ -6,7 +6,7 @@ from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.features.search import Opt, OptOps
 from tinygrad import Device, dtypes, Tensor
 from tinygrad.helpers import CI
-from test.external.fuzz_linearizer import run_linearizer, get_fuzz_rawbufs, get_fuzz_rawbuf_like
+from test.external.fuzz_linearizer import compare_linearizer
 from test.helpers import is_dtype_supported

 from tinygrad.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer

@@ -17,12 +17,6 @@ def helper_test_lin(lin: Linearizer, opts, failed_platforms):
   if any(b.dtype == dtypes.half for b in lin.membufs) and not is_dtype_supported(dtypes.half):
     return

-  rawbufs = get_fuzz_rawbufs(lin)
-  var_vals = {v: random.randint(v.min, v.max) for v in lin.ast[0].vars()}
-
-  assert run_linearizer(lin, rawbufs, var_vals) == "PASS" or Device.DEFAULT in failed_platforms, "Failed running non-optimized ast"
-  ground_truth = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np)
-
   for opt in opts:
     try:
       lin.apply_opt(opt)

@@ -31,13 +25,12 @@ def helper_test_lin(lin: Linearizer, opts, failed_platforms):
       assert Device.DEFAULT not in failed_platforms
       return

-  rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True)
-  linearizer_passed = (run_linearizer(lin, rawbufs, var_vals) == "PASS")
-  output_allclose = np.allclose(ground_truth, np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np), rtol=1e-2, atol=1e-2)
-  if Device.DEFAULT not in failed_platforms:
-    assert linearizer_passed and output_allclose, f"{linearizer_passed=}, {output_allclose=}"
+  compare_result = compare_linearizer(lin)
+  if compare_result[0] in ["PASS", "KernelOptError"]:
+    # it's considered fixed if we invalidated the opts
+    assert Device.DEFAULT not in failed_platforms, f"unexpected success on {Device.DEFAULT}"
   else:
-    assert not linearizer_passed or not output_allclose, f"{linearizer_passed=}, {output_allclose=}"
+    assert Device.DEFAULT in failed_platforms, f"failed on {Device.DEFAULT} with {compare_result[0]}"

 @unittest.skipIf(CI and Device.DEFAULT=="CUDA", "failed on CUDA CI")
 class TestLinearizerFailures(unittest.TestCase):

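The rewritten helper_test_lin collapses the old run/allclose bookkeeping into a single branch on compare_linearizer's status. A sketch of the decision it encodes, assuming only the statuses shown above; "KernelOptError" is grouped with "PASS" because invalidated opts count as fixed:

# Sketch of the check helper_test_lin now performs (illustration, not repo code).
def should_be_in_failed_platforms(status: str) -> bool:
  # "PASS" and "KernelOptError" mean the failure is gone (or the opts no longer
  # apply), so the current device must NOT be listed in failed_platforms;
  # any other status means it must be listed there.
  return status not in ("PASS", "KernelOptError")
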