mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fuzz_linearizer: add additional DEBUG info for comparison errors (#3866)
This commit is contained in:
11
test/external/fuzz_linearizer.py
vendored
11
test/external/fuzz_linearizer.py
vendored
@@ -9,7 +9,7 @@ from tinygrad.codegen.linearizer import Linearizer, UOp
|
||||
from tinygrad.codegen.kernel import Opt
|
||||
from tinygrad.features.search import get_linearizer_actions, bufs_from_lin
|
||||
from tinygrad.features.graph import print_tree
|
||||
from tinygrad.helpers import getenv, from_mv, prod, colored, Context
|
||||
from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG
|
||||
from tinygrad.ops import LazyOp
|
||||
|
||||
def tuplize_uops(uops:List[UOp]) -> Tuple:
|
||||
@@ -94,7 +94,14 @@ def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_trut
|
||||
|
||||
try:
|
||||
np.testing.assert_allclose(result, ground_truth, rtol=rtol, atol=atol)
|
||||
except AssertionError:
|
||||
except AssertionError as e:
|
||||
if DEBUG >= 2:
|
||||
print(f"COMPARE_ERROR details: {e}")
|
||||
mismatch_indices = np.where(~np.isclose(result, ground_truth, rtol=rtol, atol=atol))
|
||||
mismatched_result = result[mismatch_indices]
|
||||
mismatched_ground_truth = ground_truth[mismatch_indices]
|
||||
for i, idx in enumerate(mismatch_indices[0]):
|
||||
print(f"mismatch at {idx=}: result={mismatched_result[i]} <> ground_truth={mismatched_ground_truth[i]}")
|
||||
return ("COMPARE_ERROR", rawbufs, var_vals, ground_truth,)
|
||||
|
||||
return ("PASS", rawbufs, var_vals, ground_truth,)
|
||||
|
||||
Reference in New Issue
Block a user