tinygrad/extra/gemm/simple_matmul.py
pkotzbach 2c8e4ea865 FP8 support on NVIDIA (#8631)
* squashed fp8 commits

* tensorcore start

* minor changes

* pre-commit

* pylint

* Delete fp8mul.cu

* clean

* small bugfix

* fix test_dtype

* fix test_dtype_alu

* add EMULATE_CUDA_SM89

* fix ci

* fix test_linearizer

* fix test_linearizer

* fix swizzle

* add debug to simple_matmul

* fixed swizzle

* python emulator

* refactor python emulator

* setup fix

* numpy setup

* ml_dtypes only in emulate_cuda_sm89

* fix pylint

* fix tests

* fix mypy

* fix mypy

* fix ruff

* done python emulator

* add acc type

* tests

* mypy

* clean code

* add cuda tensor core tests to CI

* minor fix

* clean test_dtype.py

* clean cstyle.py

* clean test_ops.py

* fix test

* fix test

* whitespaces

* pylint

* pylint

* amd?

* amd?

* amd

* reduce lines

* mockgpu remove

* fix

* ruff

* ruff

* fix mypy

* ruff

* test only for cuda

* fixed formatting

* small fixes

* small fix

* least_upper_dtype if fp8s not supported

* log and reciprocal are supported for fp8s

* ops python fixes

* dtypes.fp8s use

* e4m3 + e5m2 result dtype test

* truncate linter fix

---------

Co-authored-by: pkotzbach <pawkotz@gmail.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Co-authored-by: chenyu <chenyu@fastmail.com>
2025-04-08 21:54:04 -04:00
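For context on the two fp8 formats named in the commits: e4m3 has 4 exponent and 3 mantissa bits (max finite value 448), e5m2 has 5 exponent and 2 mantissa bits (max finite value 57344). A minimal sketch of the difference using ml_dtypes, the package the Python emulator depends on per the commit log; the specific values printed here are illustrative:

import numpy as np
import ml_dtypes

# max finite values of the two fp8 formats
print(ml_dtypes.finfo(ml_dtypes.float8_e4m3fn).max)  # 448.0
print(ml_dtypes.finfo(ml_dtypes.float8_e5m2).max)    # 57344.0
# rounding a float32 value to fp8 loses mantissa precision
print(np.float32(0.3).astype(ml_dtypes.float8_e4m3fn))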

48 lines · 1.9 KiB · Python
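The script below benchmarks a single M x K by K x N matmul CNT times and checks the result against a float32 numpy reference. Input and accumulator dtypes are selected via environment variables (HALF, BFLOAT16, FP8E4M3, FP8E5M2, INT, UINT, and the ACC_* counterparts), e.g. FP8E4M3=1 python extra/gemm/simple_matmul.py.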

import numpy as np
from tinygrad.helpers import getenv
from tinygrad.dtype import _to_np_dtype
from tinygrad import dtypes, Tensor
dtype_in = (dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else
            dtypes.fp8e4m3 if getenv("FP8E4M3") else dtypes.fp8e5m2 if getenv("FP8E5M2") else dtypes.float)
acc_dtype = (dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else
             dtypes.fp8e4m3 if getenv("ACC_FP8E4M3") else dtypes.fp8e5m2 if getenv("ACC_FP8E5M2") else None)
if getenv("INT"): dtype_in = dtypes.int8acc_dtype = dtypes.int32
if getenv("UINT"): dtype_in, acc_dtype = dtypes.uint8, dtypes.int32
N = getenv("N", 4096)
M = getenv("M", N)
K = getenv("K", N)
CNT = getenv("CNT", 10)
ATOL = getenv("ATOL", 1e-4)
RTOL = getenv("RTOL", 3e-2)
INT_LOW = getenv("INT_LOW", 0)
INT_HIGH = getenv("INT_HIGH", 10)
if __name__ == "__main__":
  def init_matrix(rows, cols):
    rng = np.random.default_rng()
    # NOTE: numpy does not support bfloat16
    if (np_dtype := _to_np_dtype(dtype_in)) is None: np_dtype = np.float32
    if dtype_in in dtypes.ints:
      return Tensor(rng.integers(INT_LOW, INT_HIGH, (rows, cols), dtype=np_dtype)).realize()
    return Tensor(rng.random((rows, cols), dtype=np.float32).astype(np_dtype)).cast(dtype_in).realize()
  a, b = init_matrix(M, K), init_matrix(K, N)
  for i in range(CNT):
    if i > 0 and getenv("RAND", 0) != 0:
      a, b = init_matrix(M, K), init_matrix(K, N)
    c = a.matmul(b, dtype=acc_dtype).realize()
    ref = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
    res = c.numpy()
    try:
      np.testing.assert_allclose(res, ref, rtol=RTOL, atol=ATOL)
    except AssertionError as e:
      if getenv("DEBUG_VALUES", 0) > 0:
        mismatch = np.where(~np.isclose(res, ref, rtol=RTOL, atol=ATOL))
        print("Mismatch indices:", mismatch)
        print("Result       :", res[mismatch])
        print("Ground truth :", ref[mismatch])
      raise e
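For a quick interactive check of the fp8 path this PR adds, a minimal sketch along the lines of the script above; the 16x16 shapes, the fp32 accumulator, and the tolerances here are illustrative assumptions, and the fp8 dtypes need a backend that supports them (per the commit log, EMULATE_CUDA_SM89 with the Python emulator is one option):

import numpy as np
from tinygrad import Tensor, dtypes

# cast fp32 inputs down to fp8e4m3, as init_matrix does above
a = Tensor.rand(16, 16).cast(dtypes.fp8e4m3).realize()
b = Tensor.rand(16, 16).cast(dtypes.fp8e4m3).realize()
# accumulate in fp32 via matmul's dtype argument, like acc_dtype above
c = a.matmul(b, dtype=dtypes.float).realize()
# compare against a float32 numpy reference with a loose tolerance for fp8
ref = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
np.testing.assert_allclose(c.numpy(), ref, rtol=3e-2, atol=1e-4)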