python float8 support (#11960)

* basic support

* alu

* nan in exec_alu

* rand_for_dtype

* inf + 0.0

* finfo

* revert rand_for_dtype

* clean

* truncate fp8s inf

* spec ok

* float_to_fp8 nan/inf

* least_upper_dtype

* clean up

---------

Co-authored-by: b1tg <b1tg@users.noreply.github.com>
Author: b1tg
Date: 2025-09-18 21:17:09 +08:00
Committed by: GitHub
Parent: dbbc261075
Commit: 54c15d74a4
6 changed files with 56 additions and 17 deletions
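In short: with this change the fp8 dtypes (dtypes.fp8e4m3, dtypes.fp8e5m2) become usable end to end on the PYTHON backend, with values truncated to their 8-bit encodings on cast, store and load. A minimal sketch of what that enables (illustrative only, not taken from the diff; assumes the PYTHON backend is selected, e.g. via PYTHON=1):

from tinygrad import Tensor, dtypes
from tinygrad.dtype import truncate

a = Tensor([1.5, 2.5, 3.5], dtype=dtypes.fp8e4m3)  # these values are exactly representable in e4m3
print((a + a).numpy())                             # elementwise ALU ops on fp8 tensors; the numpy side sees float32
print(truncate[dtypes.fp8e5m2](1.7))               # collapse a Python float to an fp8e5m2-representable value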

View File

@@ -4,7 +4,7 @@ import torch
from typing import Any, List
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import getenv, DEBUG, CI
-from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8, _to_np_dtype, _to_torch_dtype
+from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8, _to_np_dtype, _to_torch_dtype, truncate
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad import Device, Tensor, dtypes
from hypothesis import assume, given, settings, strategies as strat
@@ -25,6 +25,7 @@ def get_available_cast_dtypes(dtype: DType) -> List[DType]:
def _to_torch_storage_type(dtype:DType):
if dtype == dtypes.bfloat16: return torch.float32
+if dtype in dtypes.fp8s: return torch.float32
return _to_torch_dtype(dtype)
def _test_to_np(a:Tensor, np_dtype, target):
@@ -47,12 +48,15 @@ def _test_cast(a:Tensor, target_dtype:DType):
# TODO: struct.pack cannot pack value > 65504 (max of half) into e format
a = (a > 65504).where(65504, a)
-_test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(_to_np_dtype(target_dtype))))
+expected = list(a.numpy().astype(_to_np_dtype(target_dtype)))
+if target_dtype in dtypes.fp8s: expected = list(map(lambda x: truncate[target_dtype](x), expected))
+_test_op(lambda: a.cast(target_dtype), target_dtype, expected)
def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
if isinstance(Device[Device.DEFAULT].renderer, PTXRenderer) and a.dtype == dtypes.int8 and target_dtype.itemsize != a.dtype.itemsize:
raise unittest.SkipTest("shape changing bitcast of int8 broken on PTX")
-expected = torch.tensor(a.tolist(), dtype=_to_torch_storage_type(a.dtype)).view(_to_torch_dtype(target_dtype))
-_test_op(lambda: a.bitcast(target_dtype), target_dtype, target or expected.tolist())
+expected = torch.tensor(a.tolist(), dtype=_to_torch_storage_type(a.dtype)).view(_to_torch_dtype(target_dtype)).tolist()
+if target_dtype in dtypes.fp8s: expected = list(map(lambda x: fp8_to_float(x, target_dtype), expected))
+_test_op(lambda: a.bitcast(target_dtype), target_dtype, target or expected)
class TestDType(unittest.TestCase):
DTYPE: Any = None
@@ -308,6 +312,8 @@ class TestBitCast(unittest.TestCase):
assume(not (isinstance(Device[Device.DEFAULT].renderer, PTXRenderer) and dt1 == dtypes.int8)) # TODO: bitcasting int8 fails in PTX
data = rand_for_dtype(dt1, 32).reshape(2, 2, 8)
expected = torch.tensor(data.tolist(), dtype=_to_torch_storage_type(dt1)).view(_to_torch_dtype(dt2))
+if dt2 in dtypes.fp8s:
+  expected = torch.tensor(list(map(lambda x: fp8_to_float(x, dt2), expected.view(-1).tolist()))).view_as(expected)
_test_op(lambda: Tensor(data, dtype=dt1).bitcast(dt2), dt2, expected.tolist())
def test_shape_change_bitcast_exceptions(self):
@@ -350,6 +356,9 @@ class TestBoolDType(TestDType): DTYPE = dtypes.bool
class TestBFloat16Type(TestDType): DTYPE = dtypes.bfloat16
+class TestFp8e4m3(TestDType): DTYPE = dtypes.fp8e4m3
+class TestFp8e5m2(TestDType): DTYPE = dtypes.fp8e5m2
class TestPtrDType(unittest.TestCase):
def test_vec_double(self):
dt1 = dtypes.float.vec(4).ptr().vec(4)

View File

@@ -1,6 +1,6 @@
import unittest, operator, math
from tinygrad import Tensor, dtypes, Device
-from tinygrad.dtype import DType
+from tinygrad.dtype import DType, truncate
from tinygrad.helpers import CI, getenv
from tinygrad.tensor import _to_np_dtype
from tinygrad.device import is_dtype_supported
@@ -8,7 +8,7 @@ from tinygrad.runtime.ops_python import from_storage_scalar
from tinygrad.renderer.ptx import PTXRenderer
import numpy as np
import pytest
-from hypothesis import given, strategies as strat, settings, HealthCheck
+from hypothesis import assume, given, strategies as strat, settings, HealthCheck
pytestmark = pytest.mark.filterwarnings("ignore")
@@ -48,6 +48,8 @@ class ht:
int64 = strat.integers(-9223372036854775808, 9223372036854775807)
bool = strat.booleans()
ht.bfloat16 = ht.uint16
+ht.fp8e4m3 = ht.uint8
+ht.fp8e5m2 = ht.uint8
def universal_test(a, b, dtype, op):
if not isinstance(op, tuple): op = (op, op)
@@ -57,8 +59,9 @@ def universal_test(a, b, dtype, op):
ta, tb = Tensor([a], dtype=dtype), Tensor([b], dtype=dtype)
tensor_value = (op[0](ta, tb)).numpy()
numpy_value = op[1](ta.numpy(), tb.numpy())
+if dtype in dtypes.fp8s: numpy_value = truncate[dtype](numpy_value)
if dtype in dtypes.floats:
-atol, rtol = {dtypes.bfloat16:(1e-3, 1e-2)}.get(dtype, (1e-10, 1e-7))
+atol, rtol = {dtypes.bfloat16:(1e-3, 1e-2), dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2:(1.0, 5e-1)}.get(dtype, (1e-10, 1e-7))
np.testing.assert_allclose(tensor_value, numpy_value, atol=atol, rtol=rtol)
else: np.testing.assert_equal(tensor_value, numpy_value)
@@ -71,8 +74,10 @@ def universal_test_unary(a, dtype, op):
out: Tensor = op[0](ta)
tensor_value = out.numpy()
numpy_value = op[1](ta.numpy())
+if dtype in dtypes.fp8s: numpy_value = truncate[dtype](numpy_value)
if dtype in dtypes.floats:
-atol, rtol = {dtypes.float16:(1e-3, 1e-2), dtypes.bfloat16:(1e-3, 2e-2)}.get(dtype, (1e-6, 1e-5))
+atol, rtol = { dtypes.float16:(1e-3, 1e-2), dtypes.bfloat16:(1e-3, 2e-2),
+               dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2: (1.0, 5e-1)}.get(dtype, (1e-6, 1e-5))
np.testing.assert_allclose(tensor_value, numpy_value, atol=atol, rtol=rtol)
else: np.testing.assert_equal(tensor_value, numpy_value)
@@ -111,6 +116,16 @@ class TestDTypeALU(unittest.TestCase):
def test_bfloat16(self, a, b, op):
universal_test(from_storage_scalar(a, dtypes.bfloat16), from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)
+@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), f"no fp8e4m3 on {Device.DEFAULT}")
+@given(ht.fp8e4m3, ht.fp8e4m3, strat.sampled_from(binary_operations))
+def test_fp8e4m3(self, a, b, op):
+  universal_test(from_storage_scalar(a, dtypes.fp8e4m3), from_storage_scalar(b, dtypes.fp8e4m3), dtypes.fp8e4m3, op)
+@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), f"no fp8e5m2 on {Device.DEFAULT}")
+@given(ht.fp8e5m2, ht.fp8e5m2, strat.sampled_from(binary_operations))
+def test_fp8e5m2(self, a, b, op):
+  universal_test(from_storage_scalar(a, dtypes.fp8e5m2), from_storage_scalar(b, dtypes.fp8e5m2), dtypes.fp8e5m2, op)
@given(ht.float32, strat.sampled_from(unary_operations))
def test_float32_unary(self, a, op): universal_test_unary(a, dtypes.float32, op)
@@ -122,6 +137,18 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.bfloat16, strat.sampled_from(unary_operations))
def test_bfloat16_unary(self, a, op): universal_test_unary(from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)
+@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), f"no fp8e4m3 on {Device.DEFAULT}")
+@given(ht.fp8e4m3, strat.sampled_from(unary_operations))
+def test_fp8e4m3_unary(self, a, op):
+  if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e4m3) != 0.0)
+  universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e4m3), dtypes.fp8e4m3, op)
+@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), f"no fp8e5m2 on {Device.DEFAULT}")
+@given(ht.fp8e5m2, strat.sampled_from(unary_operations))
+def test_fp8e5m2_unary(self, a, op):
+  if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e5m2) != 0.0)
+  universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e5m2), dtypes.fp8e5m2, op)
@given(ht.uint8, ht.uint8, strat.sampled_from(integer_binary_operations))
def test_uint8(self, a, b, op): universal_test(a, b, dtypes.uint8, op)

View File

@@ -21,7 +21,9 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float
if DEBUG >= 2: print(tensor.numpy())
try:
assert tensor.dtype == target_dtype
-np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2}.get(target_dtype, tol_target_dtype))
+np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2,
+                           dtypes.fp8e4m3:1e-1, dtypes.fp8e5m2:5e-1}.get(target_dtype, tol_target_dtype))
except AssertionError as e:
raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
@@ -576,10 +578,10 @@ class TestAutoCastType(unittest.TestCase):
def test_gradient_dtype(self):
old_default_float = dtypes.default_float
-for default_dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
+for default_dtype in dtypes.floats:
if not is_dtype_supported(default_dtype): continue
dtypes.default_float = default_dtype
-for dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
+for dtype in dtypes.floats:
if not is_dtype_supported(dtype): continue
if DEBUG >= 2:
print(f"testing {default_dtype=}, {dtype=}")

View File

@@ -328,9 +328,7 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
if device in {"CUDA", "NV"}: return not CI and not getenv(f"{device}_PTX")
if device in {"CPU"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"}
return device in {"AMD", "PYTHON"}
-if dtype in dtypes.fp8s:
-  # not supported yet - in progress
-  return False
+if dtype in dtypes.fp8s: return device == "PYTHON"
if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short,
dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half]
# for CI GPU and OSX, cl_khr_fp16 isn't supported
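For reference, a quick check of the resulting support matrix with the helper changed above (a sketch; True is expected only when the default device is PYTHON):

from tinygrad import Device, dtypes
from tinygrad.device import is_dtype_supported

print(Device.DEFAULT, is_dtype_supported(dtypes.fp8e4m3), is_dtype_supported(dtypes.fp8e5m2))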

View File

@@ -322,7 +322,7 @@ truncate: dict[DType, Callable] = {dtypes.bool: bool,
def _to_np_dtype(dtype:DType) -> type|None:
import numpy as np
-if dtype == dtypes.bfloat16: return np.float32
+if dtype in { dtypes.bfloat16, *dtypes.fp8s }: return np.float32
return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
import numpy as np
@@ -333,6 +333,7 @@ def _to_torch_dtype(dtype:DType) -> 'torch.dtype'|None: # type: ignore [name-de
import numpy as np, torch
if dtype == dtypes.uint64: return torch.uint64
if dtype == dtypes.bfloat16: return torch.bfloat16
+if dtype in dtypes.fp8s: return torch.uint8
# NOTE: torch doesn't expose this mapping with a stable API
try: return torch.from_numpy(np.array([], dtype=_to_np_dtype(dtype))).dtype
except TypeError: return None

View File

@@ -4,21 +4,23 @@
# this is the (living) definition of uops
from typing import Any, TYPE_CHECKING, cast
import pickle, base64, itertools, time, struct, sys
-from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, float_to_bf16
+from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, float_to_bf16, float_to_fp8, fp8_to_float
from tinygrad.helpers import all_same, getenv, flatten, get_single_element, EMULATE
from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.codegen.opt import tc
from tinygrad.uop.ops import exec_alu, python_alu, Ops, UOp, GroupOp
from tinygrad.renderer import Renderer
-def storage_fmt_for_dtype(dtype: DType): return 'H' if dtype == dtypes.bfloat16 else dtype.fmt
+def storage_fmt_for_dtype(dtype: DType): return 'H' if dtype == dtypes.bfloat16 else 'B' if dtype in dtypes.fp8s else dtype.fmt
def to_storage_scalar(x, dtype: DType):
if dtype == dtypes.bfloat16: return (struct.unpack('I', struct.pack('f', float_to_bf16(x)))[0] >> 16) & 0xFFFF
+if dtype in dtypes.fp8s: return float_to_fp8(float(x), dtype)
return x
def from_storage_scalar(x, dtype: DType):
if dtype == dtypes.bfloat16: return struct.unpack('f', struct.pack('I', (x & 0xFFFF) << 16))[0]
+if dtype in dtypes.fp8s: return fp8_to_float(int(x), dtype)
return x
def _load(m, i, dtype: DType):
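As a rough illustration of the storage scheme above (not part of the diff): on the PYTHON backend an fp8 value is now stored as a single byte ('B' struct format) and decoded back to a Python float on load. A minimal round-trip sketch using the helpers imported above:

from tinygrad.dtype import dtypes, float_to_fp8, fp8_to_float

raw = float_to_fp8(3.0, dtypes.fp8e5m2)  # 8-bit integer encoding, as produced by to_storage_scalar
val = fp8_to_float(raw, dtypes.fp8e5m2)  # decoded back to a Python float, as in from_storage_scalar
assert val == 3.0                        # 3.0 is exactly representable in e5m2, so the round trip is lossless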