FP8 support on NVIDIA (#8631)

* squashed fp8 commits

* tensorcore start

* minor changes

* pre-commit

* pylint

* Delete fp8mul.cu

* clean

* small bugfix

* fix test_dtype

* fix test_dtype_alu

* add EMULATE_CUDA_SM89

* fix ci

* fix test_linearizer

* fix test_linearizer

* fix swizzle

* add debug to simple_matmul

* fixed swizzle

* python emulator

* refactor python emulator

* setup fix

* numpy setup

* ml_dtypes only in emulate_cuda_sm89

* fix pylint

* fix tests

* fix mypy

* fix mypy

* fix ruff

* done python emulator

* add acc type

* tests

* mypy

* clean code

* add cuda tensor core tests to CI

* minor fix

* clean test_dtype.py

* clean cstyle.py

* clean test_ops.py

* fix test

* fix test

* whitespaces

* pylint

* pylint

* amd?

* amd?

* amd

* reduce lines

* mockgpu remove

* fix

* ruff

* ruff

* fix mypy

* ruff

* test only for cuda

* fixed formatting

* small fixes

* small fix

* least_upper_dtype if fp8s not supported

* log and reciprocal are supported for fp8s

* ops python fixes

* dtypes.fp8s use

* e4m3 + e5m2 result dtype test (see the sketch after this list)

* truncate linter fix
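
Several of the items above concern fp8 dtype promotion (the least_upper_dtype fallback and the e4m3 + e5m2 result-dtype test). The sketch below is a hedged illustration of that behavior, not code from this PR: least_upper_dtype, dtypes.fp8e4m3 and dtypes.fp8e5m2 are names taken from this change, while the assumed promotion target (a wider float such as half) is a guess.

# hedged sketch of mixed-fp8 promotion; the promoted dtype printed here is an assumption
from tinygrad import Tensor, dtypes
from tinygrad.dtype import least_upper_dtype

a = Tensor.ones(4, dtype=dtypes.fp8e4m3)
b = Tensor.ones(4, dtype=dtypes.fp8e5m2)
# neither fp8 format is a superset of the other, so the result should promote to a wider float
print(least_upper_dtype(a.dtype, b.dtype))
print((a + b).dtype)  # expected to match least_upper_dtype(fp8e4m3, fp8e5m2)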

---------

Co-authored-by: pkotzbach <pawkotz@gmail.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Co-authored-by: chenyu <chenyu@fastmail.com>
pkotzbach authored 2025-04-09 03:54:04 +02:00, committed by GitHub
parent 5d85765327, commit 2c8e4ea865
13 changed files with 203 additions and 45 deletions

@@ -3003,6 +3003,17 @@ class TestOpsUint8(unittest.TestCase):
      lambda x: x.type(torch.uint8).min(),
      lambda x: x.cast(dtypes.uint8).min(), forward_only=True, vals=[[0, 128, 255, 64, 32, 16]])

@unittest.skipUnless("CUDA" in Device.get_available_devices() and Device.DEFAULT == "PYTHON" and getenv("EMULATE_CUDA_SM89"),
                     "only for emulated CUDA")
class TestOpsFp8s(unittest.TestCase):
  def _compare_to_cuda(self, shp_a, shp_b, op, dtype):
    a = Tensor.rand(shp_a, dtype=dtype)
    b = Tensor.rand(shp_b, dtype=dtype)
    np.testing.assert_equal(op(a, b).numpy(), op(a.to("CUDA"), b.to("CUDA")).numpy())

  def test_gemm_fp8e4m3(self): self._compare_to_cuda((64, 64), (64, 64), lambda x, y: x.matmul(y), dtypes.fp8e4m3)
  def test_gemm_fp8e5m2(self): self._compare_to_cuda((64, 64), (64, 64), lambda x, y: x.matmul(y), dtypes.fp8e5m2)

if __name__ == '__main__':
  np.random.seed(1337)
  unittest.main(verbosity=2)
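
For context, a hedged standalone version of the comparison these tests run: it assumes the process was started with PYTHON=1 (so Device.DEFAULT is the Python emulator) and EMULATE_CUDA_SM89=1, matching the skipUnless guard above, and that a real CUDA device is available to compare against.

# minimal sketch, not part of the diff: redo the TestOpsFp8s gemm comparison by hand
# assumes PYTHON=1 EMULATE_CUDA_SM89=1 in the environment and a working CUDA backend
import numpy as np
from tinygrad import Tensor, dtypes

a = Tensor.rand(64, 64, dtype=dtypes.fp8e4m3)
b = Tensor.rand(64, 64, dtype=dtypes.fp8e4m3)

emulated = a.matmul(b).numpy()                         # fp8 gemm on the emulated SM89 device
reference = a.to("CUDA").matmul(b.to("CUDA")).numpy()  # the same gemm on real hardware
np.testing.assert_equal(emulated, reference)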