mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
* squashed fp8 commits
* tensorcore start
* minor changes
* pre-commit
* pylint
* Delete fp8mul.cu
* clean
* small bugfix
* fix test_dtype
* fix test_dtype_alu
* add EMULATE_CUDA_SM89
* fix ci
* fix test_linearizer
* fix test_linearizer
* fix swizzle
* add debug to simple_matmul
* fixed swizzle
* python emulator
* refactor python emulator
* setup fix
* numpy setup
* ml_dtypes only in emulate_cuda_sm89
* fix pylint
* fix tests
* fix mypy
* fix mypy
* fix ruff
* done python emulator
* add acc type
* tests
* mypy
* clean code
* add cuda tensor core tests to CI
* minor fix
* clean test_dtype.py
* clean cstyle.py
* clean test_ops.py
* fix test
* fix test
* whitespaces
* pylint
* pylint
* amd?
* amd?
* amd
* reduce lines
* mockgpu remove
* fix
* ruff
* ruff
* fix mypy
* ruff
* test only for cuda
* fixed formatting
* small fixes
* small fix
* least_upper_dtype if fp8s not supported
* log and reciprocal are supported for fp8s
* ops python fixes
* dtypes.fp8s use
* e4m3 + e5m2 result dtype test
* truncate linter fix

---------

Co-authored-by: pkotzbach <pawkotz@gmail.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Co-authored-by: chenyu <chenyu@fastmail.com>
48 lines
1.9 KiB
Python
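# simple_matmul benchmark: runs CNT matmuls of an MxK by a KxN matrix in the dtype selected
# by env vars and checks every result against a float32 numpy reference within ATOL/RTOL.
# Example invocations (illustrative only; the path assumes the file lives at extra/gemm/simple_matmul.py):
#   FP8E4M3=1 python3 extra/gemm/simple_matmul.py
#   INT=1 DEBUG_VALUES=1 python3 extra/gemm/simple_matmul.py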
import numpy as np
from tinygrad.helpers import getenv
from tinygrad.dtype import _to_np_dtype
from tinygrad import dtypes, Tensor

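# input dtype is picked by the first matching env var (HALF, BFLOAT16, FP8E4M3, FP8E5M2), defaulting to float32;
# the ACC_* variants pick the accumulation dtype, with None keeping the default.
# INT/UINT switch to int8/uint8 inputs with int32 accumulation.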
dtype_in = (dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else
            dtypes.fp8e4m3 if getenv("FP8E4M3") else dtypes.fp8e5m2 if getenv("FP8E5M2") else dtypes.float)
acc_dtype = (dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else
             dtypes.fp8e4m3 if getenv("ACC_FP8E4M3") else dtypes.fp8e5m2 if getenv("ACC_FP8E5M2") else None)
if getenv("INT"): dtype_in, acc_dtype = dtypes.int8, dtypes.int32
if getenv("UINT"): dtype_in, acc_dtype = dtypes.uint8, dtypes.int32

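# problem size (M x K times K x N), iteration count, comparison tolerances, and the value range
# used for integer inputs; all overridable via env vars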
N = getenv("N", 4096)
M = getenv("M", N)
K = getenv("K", N)
CNT = getenv("CNT", 10)
ATOL = getenv("ATOL", 1e-4)
RTOL = getenv("RTOL", 3e-2)
INT_LOW = getenv("INT_LOW", 0)
INT_HIGH = getenv("INT_HIGH", 10)

if __name__ == "__main__":
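  # builds a rows x cols Tensor of dtype_in: integer dtypes sample from [INT_LOW, INT_HIGH),
  # dtypes without a numpy equivalent (e.g. bfloat16) are generated as float32 and then cast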
  def init_matrix(rows, cols):
    rng = np.random.default_rng()
    # NOTE: numpy does not support bfloat16
    if (np_dtype := _to_np_dtype(dtype_in)) is None: np_dtype = np.float32
    if dtype_in in dtypes.ints:
      return Tensor(rng.integers(INT_LOW, INT_HIGH, (rows, cols), dtype=np_dtype)).realize()
    return Tensor(rng.random((rows, cols), dtype=np.float32).astype(np_dtype)).cast(dtype_in).realize()

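  # run CNT matmuls; RAND=1 regenerates the inputs on every iteration after the first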
  a, b = init_matrix(M, K), init_matrix(K, N)
  for i in range(CNT):
    if i > 0 and getenv("RAND", 0) != 0:
      a, b = init_matrix(M, K), init_matrix(K, N)
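    # dtype=acc_dtype is intended to set the accumulation dtype when one was selected above;
    # realize() forces the matmul kernel to actually run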
    c = a.matmul(b, dtype=acc_dtype).realize()

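    # check against a float32 numpy reference; DEBUG_VALUES=1 prints the mismatching entries before re-raising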
    ref = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
    res = c.numpy()
    try:
      np.testing.assert_allclose(res, ref, rtol=RTOL, atol=ATOL)
    except AssertionError as e:
      if getenv("DEBUG_VALUES", 0) > 0:
        mismatch = np.where(~np.isclose(res, ref, rtol=RTOL, atol=ATOL))
        print("Mismatch indices:", mismatch)
        print("Result :", res[mismatch])
        print("Ground truth :", ref[mismatch])
      raise e