fuzz_linearizer: add additional DEBUG info for comparison errors (#3866)

This commit is contained in:
Francis Lam
2024-03-21 15:58:10 -07:00
committed by GitHub
parent bc482729d0
commit 3c0478bfab

View File

@@ -9,7 +9,7 @@ from tinygrad.codegen.linearizer import Linearizer, UOp
from tinygrad.codegen.kernel import Opt
from tinygrad.features.search import get_linearizer_actions, bufs_from_lin
from tinygrad.features.graph import print_tree
from tinygrad.helpers import getenv, from_mv, prod, colored, Context
from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG
from tinygrad.ops import LazyOp
def tuplize_uops(uops:List[UOp]) -> Tuple:
@@ -94,7 +94,14 @@ def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_trut
try:
np.testing.assert_allclose(result, ground_truth, rtol=rtol, atol=atol)
except AssertionError:
except AssertionError as e:
if DEBUG >= 2:
print(f"COMPARE_ERROR details: {e}")
mismatch_indices = np.where(~np.isclose(result, ground_truth, rtol=rtol, atol=atol))
mismatched_result = result[mismatch_indices]
mismatched_ground_truth = ground_truth[mismatch_indices]
for i, idx in enumerate(mismatch_indices[0]):
print(f"mismatch at {idx=}: result={mismatched_result[i]} <> ground_truth={mismatched_ground_truth[i]}")
return ("COMPARE_ERROR", rawbufs, var_vals, ground_truth,)
return ("PASS", rawbufs, var_vals, ground_truth,)