Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
test: add fuzz_matmul and better debugging for simple_matmul (#4199)
also show unoptimized shape in verify_kernel
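The fuzz script sweeps (M, N, K) matmul shapes driven by the N_/M_/K_ START, STOP and STEP environment variables and checks every tinygrad result against a float32 numpy reference, collecting failing sizes instead of stopping at the first mismatch. A plausible invocation, assuming it is run from a tinygrad checkout (every flag shown is read by the script via getenv):

HALF=1 N_STOP=64 DEBUG_VALUES=1 python3 extra/gemm/fuzz_matmul.py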
42  extra/gemm/fuzz_matmul.py (new file)
@@ -0,0 +1,42 @@
import numpy as np
from tinygrad.helpers import getenv
from tinygrad import dtypes, Tensor
dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None
N_START = getenv("N_START", 1)
M_START = getenv("M_START", 1)
K_START = getenv("K_START", 1)
N_STOP = getenv("N_STOP", 32)
M_STOP = getenv("M_STOP", N_STOP)
K_STOP = getenv("K_STOP", N_STOP)
N_STEP = getenv("N_STEP", 1)
M_STEP = getenv("M_STEP", 1)
K_STEP = getenv("K_STEP", 1)
ATOL = getenv("ATOL", 1e-4)
RTOL = getenv("RTOL", 3e-2)

if __name__ == "__main__":
  failed = []
  for M in range(M_START, M_STOP+1, M_STEP):
    for N in range(N_START, N_STOP+1, N_STEP):
      for K in range(K_START, K_STOP+1, K_STEP):
        print(f"testing {M=} {N=} {K=}")
        a, b = Tensor.rand(M, K, dtype=dtype_in).realize(), Tensor.rand(K, N, dtype=dtype_in).realize()
        c = a.matmul(b, acc_dtype=acc_dtype).realize()
        comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
        nc = c.numpy()
        try:
          np.testing.assert_allclose(nc, comp, atol=ATOL, rtol=RTOL)
        except AssertionError as e:
          failed.append((M,N,K,))
          if getenv("DEBUG_VALUES") > 0:
            indices = np.where(~np.isclose(nc, comp, rtol=RTOL, atol=ATOL))
            non_matching_elements_nc = nc[indices]
            non_matching_elements_comp = comp[indices]
            print(indices)
            print("result :", non_matching_elements_nc)
            print("ground truth:", non_matching_elements_comp)
          print(e)
          pass
  print(f"failed sizes: {failed}")
  print(f"num failures: {len(failed)}")
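The DEBUG_VALUES branch above narrows a failing shape down to the individual mismatching elements with np.isclose. A minimal standalone sketch of that pattern on made-up toy arrays (illustration only, not part of the commit):

import numpy as np
nc   = np.array([1.0, 2.0, 3.5, 4.0], dtype=np.float32)  # device result
comp = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)  # float32 ground truth
indices = np.where(~np.isclose(nc, comp, rtol=3e-2, atol=1e-4))  # positions that fail the tolerance
print(indices)                         # (array([2]),)
print("result :", nc[indices])         # [3.5]
print("ground truth:", comp[indices])  # [3.]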
extra/gemm/simple_matmul.py (file name inferred from the commit title)
@@ -4,16 +4,28 @@ from tinygrad import dtypes, Tensor
 dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
 acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None
 N = getenv("N", 4096)
+M = getenv("M", N)
+K = getenv("K", N)
 CNT = getenv("CNT", 10)
 ATOL = getenv("ATOL", 1e-4)
 RTOL = getenv("RTOL", 3e-2)
 
 if __name__ == "__main__":
-  a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
+  a, b = Tensor.rand(M, K, dtype=dtype_in).realize(), Tensor.rand(K, N, dtype=dtype_in).realize()
   for i in range(CNT):
     if i > 0 and getenv("RAND", 0) != 0:
-      a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
+      a, b = Tensor.rand(M, K, dtype=dtype_in).realize(), Tensor.rand(K, N, dtype=dtype_in).realize()
     c = a.matmul(b, acc_dtype=acc_dtype).realize()
     comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
     nc = c.numpy()
-    np.testing.assert_allclose(nc, comp, atol=ATOL, rtol=RTOL)
+    try:
+      np.testing.assert_allclose(nc, comp, atol=ATOL, rtol=RTOL)
+    except AssertionError as e:
+      if getenv("DEBUG_VALUES") > 0:
+        indices = np.where(~np.isclose(nc, comp, rtol=RTOL, atol=ATOL))
+        non_matching_elements_nc = nc[indices]
+        non_matching_elements_comp = comp[indices]
+        print(indices)
+        print("result :", non_matching_elements_nc)
+        print("ground truth:", non_matching_elements_comp)
+      raise e
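With the new M and K knobs plus DEBUG_VALUES, simple_matmul can replay a single rectangular shape that fuzz_matmul flagged and show exactly which elements diverge. An illustrative invocation (the simple_matmul path is inferred from the commit title; the env-var names come from the diff above):

M=16 K=8 N=32 HALF=1 DEBUG_VALUES=1 python3 extra/gemm/simple_matmul.py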
5  test/external/verify_kernel.py (vendored)
@@ -3,6 +3,7 @@ from collections import defaultdict
 from extra.optimization.helpers import kern_str_to_lin
 from test.external.fuzz_linearizer import compare_linearizer
 from tinygrad.helpers import colored
+from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.features.graph import print_tree
 from tinygrad.features.search import time_linearizer
 
@@ -43,7 +44,9 @@ if __name__ == "__main__":
     print_tree(op)
     print(op)
     print(test_lin.applied_opts)
-    print(test_lin.colored_shape())
+    unoptimized_lin = Linearizer(*test_lin.ast)
+    unoptimized_lin.required_optimizations()
+    print(f"{unoptimized_lin.colored_shape()} -> {test_lin.colored_shape()}")
     (msg,rb,vv,gt) = compare_linearizer(test_lin, None, None, None, rtol=args.rtol, atol=args.atol)
     if msg != "PASS":
       failed_ids.append(i)