Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 15:38:29 -05:00)
FP8 support on NVIDIA (#8631)
* squashed fp8 commits
* tensorcore start
* minor changes
* pre-commit
* pylint
* Delete fp8mul.cu
* clean
* small bugfix
* fix test_dtype
* fix test_dtype_alu
* add EMULATE_CUDA_SM89
* fix ci
* fix test_linearizer
* fix test_linearizer
* fix swizzle
* add debug to simple_matmul
* fixed swizzle
* python emulator
* refactor python emulator
* setup fix
* numpy setup
* ml_dtypes only in emulate_cuda_sm89
* fix pylint
* fix tests
* fix mypy
* fix mypy
* fix ruff
* done python emulator
* add acc type
* tests
* mypy
* clean code
* add cuda tensor core tests to CI
* minor fix
* clean test_dtype.py
* clean cstyle.py
* clean test_ops.py
* fix test
* fix test
* whitespaces
* pylint
* pylint
* amd?
* amd?
* amd
* reduce lines
* mockgpu remove
* fix
* ruff
* ruff
* fix mypy
* ruff
* test only for cuda
* fixed formatting
* small fixes
* small fix
* least_upper_dtype if fp8s not supported
* log and reciprocal are supported for fp8s
* ops python fixes
* dtypes.fp8s use
* e4m3 + e5m2 result dtype test
* truncate linter fix

---------

Co-authored-by: pkotzbach <pawkotz@gmail.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Co-authored-by: chenyu <chenyu@fastmail.com>
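The TestOpsFp8s tests added in the diff below only run when a real CUDA device is available for comparison, the PYTHON backend is the default device, and EMULATE_CUDA_SM89 is set. A minimal sketch of poking at the same FP8 path by hand is shown here; it is not part of the commit, and it assumes tinygrad's usual env-var device selection (PYTHON=1) and that the sm_89 emulator alone, without the CUDA cross-check, is enough for a quick local experiment.

```python
# Minimal sketch, not from the commit: run an FP8 matmul on the Python emulator
# with sm_89 (Ada) emulation enabled. Assumes tinygrad picks the default device
# from env vars (PYTHON=1) and that ml_dtypes is installed for the emulator.
import os
os.environ["PYTHON"] = "1"             # use the pure-Python backend as Device.DEFAULT
os.environ["EMULATE_CUDA_SM89"] = "1"  # emulate CUDA sm_89 FP8 tensor core behaviour

from tinygrad import Tensor, dtypes

a = Tensor.rand(64, 64, dtype=dtypes.fp8e4m3)
b = Tensor.rand(64, 64, dtype=dtypes.fp8e4m3)
print(a.matmul(b).numpy())  # the test below additionally checks this against the same op on a real CUDA device
```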
@@ -3003,6 +3003,17 @@ class TestOpsUint8(unittest.TestCase):
   lambda x: x.type(torch.uint8).min(),
   lambda x: x.cast(dtypes.uint8).min(), forward_only=True, vals=[[0, 128, 255, 64, 32, 16]])
 
+@unittest.skipUnless("CUDA" in Device.get_available_devices() and Device.DEFAULT == "PYTHON" and getenv("EMULATE_CUDA_SM89"),
+                     "only for emulated CUDA")
+class TestOpsFp8s(unittest.TestCase):
+  def _compare_to_cuda(self, shp_a, shp_b, op, dtype):
+    a = Tensor.rand(shp_a, dtype=dtype)
+    b = Tensor.rand(shp_b, dtype=dtype)
+    np.testing.assert_equal(op(a, b).numpy(), op(a.to("CUDA"), b.to("CUDA")).numpy())
+
+  def test_gemm_fp8e4m3(self): self._compare_to_cuda((64, 64), (64, 64), lambda x, y: x.matmul(y), dtypes.fp8e4m3)
+  def test_gemm_fp8e5m2(self): self._compare_to_cuda((64, 64), (64, 64), lambda x, y: x.matmul(y), dtypes.fp8e5m2)
+
 if __name__ == '__main__':
   np.random.seed(1337)
   unittest.main(verbosity=2)
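As a side note on the two formats exercised by the new gemm tests: the commit's Python emulator pulls in the ml_dtypes package (per the "ml_dtypes only in emulate_cuda_sm89" commit item), and that package can also be used on its own to see the range/precision trade-off between e4m3 and e5m2. A rough illustration follows; it uses only numpy and ml_dtypes, nothing here is tinygrad API.

```python
# Rough illustration of the two FP8 formats covered by the tests above, using the
# ml_dtypes package the commit adds for EMULATE_CUDA_SM89. e4m3 has 3 mantissa bits
# and a max finite value of 448; e5m2 has only 2 mantissa bits but a max finite value of 57344.
import numpy as np
import ml_dtypes

vals = np.array([0.1, 1.0, 3.14, 300.0], dtype=np.float32)
e4m3 = vals.astype(ml_dtypes.float8_e4m3fn).astype(np.float32)  # finer steps, smaller range
e5m2 = vals.astype(ml_dtypes.float8_e5m2).astype(np.float32)    # coarser steps, wider range

print("fp32:", vals)
print("e4m3:", e4m3)
print("e5m2:", e5m2)
```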