python float8 support (#11960)
* basic support
* alu
* nan in exec_alu
* rand_for_dtype
* inf + 0.0
* finfo
* revert rand_for_dtype
* clean
* truncate fp8s inf
* spec ok
* float_to_fp8 nan/inf
* least_upper_dtype
* clean up

Co-authored-by: b1tg <b1tg@users.noreply.github.com>
@@ -4,7 +4,7 @@ import torch
 from typing import Any, List
 from tinygrad.device import is_dtype_supported
 from tinygrad.helpers import getenv, DEBUG, CI
-from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8, _to_np_dtype, _to_torch_dtype
+from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8, _to_np_dtype, _to_torch_dtype, truncate
 from tinygrad.renderer.ptx import PTXRenderer
 from tinygrad import Device, Tensor, dtypes
 from hypothesis import assume, given, settings, strategies as strat
@@ -25,6 +25,7 @@ def get_available_cast_dtypes(dtype: DType) -> List[DType]:

 def _to_torch_storage_type(dtype:DType):
   if dtype == dtypes.bfloat16: return torch.float32
+  if dtype in dtypes.fp8s: return torch.float32
   return _to_torch_dtype(dtype)

 def _test_to_np(a:Tensor, np_dtype, target):
@@ -47,12 +48,15 @@ def _test_cast(a:Tensor, target_dtype:DType):
     # TODO: struct.pack cannot pack value > 65504 (max of half) into e format
     a = (a > 65504).where(65504, a)

-  _test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(_to_np_dtype(target_dtype))))
+  expected = list(a.numpy().astype(_to_np_dtype(target_dtype)))
+  if target_dtype in dtypes.fp8s: expected = list(map(lambda x: truncate[target_dtype](x), expected))
+  _test_op(lambda: a.cast(target_dtype), target_dtype, expected)
 def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
   if isinstance(Device[Device.DEFAULT].renderer, PTXRenderer) and a.dtype == dtypes.int8 and target_dtype.itemsize != a.dtype.itemsize:
     raise unittest.SkipTest("shape changing bitcast of int8 broken on PTX")
-  expected = torch.tensor(a.tolist(), dtype=_to_torch_storage_type(a.dtype)).view(_to_torch_dtype(target_dtype))
-  _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or expected.tolist())
+  expected = torch.tensor(a.tolist(), dtype=_to_torch_storage_type(a.dtype)).view(_to_torch_dtype(target_dtype)).tolist()
+  if target_dtype in dtypes.fp8s: expected = list(map(lambda x: fp8_to_float(x, target_dtype), expected))
+  _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or expected)

 class TestDType(unittest.TestCase):
   DTYPE: Any = None
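For fp8 target dtypes the numpy-computed expectation above is additionally passed through tinygrad's truncate[target_dtype], so out-of-range values are folded onto representable fp8 values before comparison. A rough standalone sketch of that saturation idea (the exact rounding and NaN/inf policy lives in tinygrad's truncate table, not here; the FP8_MAX constants are the OCP FP8 largest finite magnitudes and are assumed for illustration):

import numpy as np

# Largest finite magnitudes of the two OCP FP8 formats (assumed for this sketch).
FP8_MAX = {"fp8e4m3": 448.0, "fp8e5m2": 57344.0}

def saturate(x: np.ndarray, fmt: str) -> np.ndarray:
  # Clamp into the finite range; a real fp8 truncate also rounds each value to
  # the nearest representable encoding, which this sketch deliberately skips.
  m = FP8_MAX[fmt]
  return np.clip(x, -m, m)

print(saturate(np.array([1e6, -1e6, 3.0]), "fp8e4m3"))  # [ 448. -448.    3.]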
@@ -308,6 +312,8 @@ class TestBitCast(unittest.TestCase):
     assume(not (isinstance(Device[Device.DEFAULT].renderer, PTXRenderer) and dt1 == dtypes.int8)) # TODO: bitcasting int8 fails in PTX
     data = rand_for_dtype(dt1, 32).reshape(2, 2, 8)
     expected = torch.tensor(data.tolist(), dtype=_to_torch_storage_type(dt1)).view(_to_torch_dtype(dt2))
+    if dt2 in dtypes.fp8s:
+      expected = torch.tensor(list(map(lambda x: fp8_to_float(x, dt2), expected.view(-1).tolist()))).view_as(expected)
     _test_op(lambda: Tensor(data, dtype=dt1).bitcast(dt2), dt2, expected.tolist())

   def test_shape_change_bitcast_exceptions(self):
@@ -350,6 +356,9 @@ class TestBoolDType(TestDType): DTYPE = dtypes.bool

 class TestBFloat16Type(TestDType): DTYPE = dtypes.bfloat16

+class TestFp8e4m3(TestDType): DTYPE = dtypes.fp8e4m3
+class TestFp8e5m2(TestDType): DTYPE = dtypes.fp8e5m2
+
 class TestPtrDType(unittest.TestCase):
   def test_vec_double(self):
     dt1 = dtypes.float.vec(4).ptr().vec(4)
@@ -1,6 +1,6 @@
 import unittest, operator, math
 from tinygrad import Tensor, dtypes, Device
-from tinygrad.dtype import DType
+from tinygrad.dtype import DType, truncate
 from tinygrad.helpers import CI, getenv
 from tinygrad.tensor import _to_np_dtype
 from tinygrad.device import is_dtype_supported
@@ -8,7 +8,7 @@ from tinygrad.runtime.ops_python import from_storage_scalar
 from tinygrad.renderer.ptx import PTXRenderer
 import numpy as np
 import pytest
-from hypothesis import given, strategies as strat, settings, HealthCheck
+from hypothesis import assume, given, strategies as strat, settings, HealthCheck

 pytestmark = pytest.mark.filterwarnings("ignore")

@@ -48,6 +48,8 @@ class ht:
   int64 = strat.integers(-9223372036854775808, 9223372036854775807)
   bool = strat.booleans()
 ht.bfloat16 = ht.uint16
+ht.fp8e4m3 = ht.uint8
+ht.fp8e5m2 = ht.uint8

 def universal_test(a, b, dtype, op):
   if not isinstance(op, tuple): op = (op, op)
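ht.fp8e4m3 and ht.fp8e5m2 reuse the uint8 strategy because hypothesis draws a raw byte, which the tests then decode into a Python float via from_storage_scalar before building the Tensor, so any of the 256 bit patterns of each format can be drawn. A small usage sketch (the printed values assume the OCP e4m3/e5m2 bit layouts, under which 0x40 decodes to 2.0 in e4m3 and 0x3C to 1.0 in e5m2):

from tinygrad import dtypes
from tinygrad.runtime.ops_python import from_storage_scalar

# Decode raw fp8 bytes the same way the ALU tests hand them to universal_test.
print(from_storage_scalar(0x40, dtypes.fp8e4m3))  # expected 2.0 (exponent bias 7)
print(from_storage_scalar(0x3C, dtypes.fp8e5m2))  # expected 1.0 (exponent bias 15)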
@@ -57,8 +59,9 @@ def universal_test(a, b, dtype, op):
   ta, tb = Tensor([a], dtype=dtype), Tensor([b], dtype=dtype)
   tensor_value = (op[0](ta, tb)).numpy()
   numpy_value = op[1](ta.numpy(), tb.numpy())
+  if dtype in dtypes.fp8s: numpy_value = truncate[dtype](numpy_value)
   if dtype in dtypes.floats:
-    atol, rtol = {dtypes.bfloat16:(1e-3, 1e-2)}.get(dtype, (1e-10, 1e-7))
+    atol, rtol = {dtypes.bfloat16:(1e-3, 1e-2), dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2:(1.0, 5e-1)}.get(dtype, (1e-10, 1e-7))
     np.testing.assert_allclose(tensor_value, numpy_value, atol=atol, rtol=rtol)
   else: np.testing.assert_equal(tensor_value, numpy_value)

@@ -71,8 +74,10 @@ def universal_test_unary(a, dtype, op):
   out: Tensor = op[0](ta)
   tensor_value = out.numpy()
   numpy_value = op[1](ta.numpy())
+  if dtype in dtypes.fp8s: numpy_value = truncate[dtype](numpy_value)
   if dtype in dtypes.floats:
-    atol, rtol = {dtypes.float16:(1e-3, 1e-2), dtypes.bfloat16:(1e-3, 2e-2)}.get(dtype, (1e-6, 1e-5))
+    atol, rtol = { dtypes.float16:(1e-3, 1e-2), dtypes.bfloat16:(1e-3, 2e-2),
+                   dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2: (1.0, 5e-1)}.get(dtype, (1e-6, 1e-5))
     np.testing.assert_allclose(tensor_value, numpy_value, atol=atol, rtol=rtol)
   else: np.testing.assert_equal(tensor_value, numpy_value)

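The loose fp8 tolerances above line up with the formats' precision (this rationale is a reading of the change, not stated in the commit): e4m3 keeps only 3 mantissa bits and e5m2 only 2, so a single rounding near 1.0 can already shift a value by roughly one part in 8 or one part in 4.

# Spacing between neighbouring representable values just above 1.0:
print(2**-3, 2**-2)  # 0.125 (e4m3), 0.25 (e5m2)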
@@ -111,6 +116,16 @@ class TestDTypeALU(unittest.TestCase):
   def test_bfloat16(self, a, b, op):
     universal_test(from_storage_scalar(a, dtypes.bfloat16), from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)

+  @unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), f"no fp8e4m3 on {Device.DEFAULT}")
+  @given(ht.fp8e4m3, ht.fp8e4m3, strat.sampled_from(binary_operations))
+  def test_fp8e4m3(self, a, b, op):
+    universal_test(from_storage_scalar(a, dtypes.fp8e4m3), from_storage_scalar(b, dtypes.fp8e4m3), dtypes.fp8e4m3, op)
+
+  @unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), f"no fp8e5m2 on {Device.DEFAULT}")
+  @given(ht.fp8e5m2, ht.fp8e5m2, strat.sampled_from(binary_operations))
+  def test_fp8e5m2(self, a, b, op):
+    universal_test(from_storage_scalar(a, dtypes.fp8e5m2), from_storage_scalar(b, dtypes.fp8e5m2), dtypes.fp8e5m2, op)
+
   @given(ht.float32, strat.sampled_from(unary_operations))
   def test_float32_unary(self, a, op): universal_test_unary(a, dtypes.float32, op)

@@ -122,6 +137,18 @@ class TestDTypeALU(unittest.TestCase):
   @given(ht.bfloat16, strat.sampled_from(unary_operations))
   def test_bfloat16_unary(self, a, op): universal_test_unary(from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)

+  @unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), f"no fp8e4m3 on {Device.DEFAULT}")
+  @given(ht.fp8e4m3, strat.sampled_from(unary_operations))
+  def test_fp8e4m3_unary(self, a, op):
+    if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e4m3) != 0.0)
+    universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e4m3), dtypes.fp8e4m3, op)
+
+  @unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), f"no fp8e5m2 on {Device.DEFAULT}")
+  @given(ht.fp8e5m2, strat.sampled_from(unary_operations))
+  def test_fp8e5m2_unary(self, a, op):
+    if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e5m2) != 0.0)
+    universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e5m2), dtypes.fp8e5m2, op)
+
   @given(ht.uint8, ht.uint8, strat.sampled_from(integer_binary_operations))
   def test_uint8(self, a, b, op): universal_test(a, b, dtypes.uint8, op)

@@ -21,7 +21,9 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float
   if DEBUG >= 2: print(tensor.numpy())
   try:
     assert tensor.dtype == target_dtype
-    np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2}.get(target_dtype, tol_target_dtype))
+    np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2,
+                                                             dtypes.fp8e4m3:1e-1, dtypes.fp8e5m2:5e-1}.get(target_dtype, tol_target_dtype))
+
   except AssertionError as e:
     raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e

@@ -576,10 +578,10 @@ class TestAutoCastType(unittest.TestCase):
   def test_gradient_dtype(self):
     old_default_float = dtypes.default_float

-    for default_dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
+    for default_dtype in dtypes.floats:
       if not is_dtype_supported(default_dtype): continue
       dtypes.default_float = default_dtype
-      for dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
+      for dtype in dtypes.floats:
         if not is_dtype_supported(dtype): continue
         if DEBUG >= 2:
           print(f"testing {default_dtype=}, {dtype=}")
@@ -328,9 +328,7 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
     if device in {"CUDA", "NV"}: return not CI and not getenv(f"{device}_PTX")
     if device in {"CPU"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"}
     return device in {"AMD", "PYTHON"}
-  if dtype in dtypes.fp8s:
-    # not supported yet - in progress
-    return False
+  if dtype in dtypes.fp8s: return device == "PYTHON"
   if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short,
                                           dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half]
   # for CI GPU and OSX, cl_khr_fp16 isn't supported
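With this change fp8 dtypes should report as supported only on the PYTHON backend; other devices still return False. A quick check (expected results shown as comments, assuming no earlier branch in is_dtype_supported intercepts fp8 dtypes):

from tinygrad import dtypes
from tinygrad.device import is_dtype_supported

print(is_dtype_supported(dtypes.fp8e4m3, "PYTHON"))  # expected: True
print(is_dtype_supported(dtypes.fp8e5m2, "CPU"))     # expected: False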
@@ -322,7 +322,7 @@ truncate: dict[DType, Callable] = {dtypes.bool: bool,

 def _to_np_dtype(dtype:DType) -> type|None:
   import numpy as np
-  if dtype == dtypes.bfloat16: return np.float32
+  if dtype in { dtypes.bfloat16, *dtypes.fp8s }: return np.float32
   return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
 def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
   import numpy as np
@@ -333,6 +333,7 @@ def _to_torch_dtype(dtype:DType) -> 'torch.dtype'|None: # type: ignore [name-de
   import numpy as np, torch
   if dtype == dtypes.uint64: return torch.uint64
   if dtype == dtypes.bfloat16: return torch.bfloat16
+  if dtype in dtypes.fp8s: return torch.uint8
   # NOTE: torch doesn't expose this mapping with a stable API
   try: return torch.from_numpy(np.array([], dtype=_to_np_dtype(dtype))).dtype
   except TypeError: return None
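Mapping fp8 to torch.uint8 lets torch act as a plain byte container when the tests build bitcast expectations, while _to_torch_storage_type keeps the values as float32. The same view trick shown with a dtype torch does ship natively, purely to illustrate the mechanism (this is not tinygrad code):

import torch

# Viewing a float16 tensor as uint8 exposes its raw little-endian bytes;
# the tests do the same with fp8 data and then decode each byte via fp8_to_float.
t = torch.tensor([1.0], dtype=torch.float16)
print(t.view(torch.uint8).tolist())  # [0, 60] -> 0x3C00, the float16 encoding of 1.0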
@@ -4,21 +4,23 @@
 # this is the (living) definition of uops
 from typing import Any, TYPE_CHECKING, cast
 import pickle, base64, itertools, time, struct, sys
-from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, float_to_bf16
+from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, float_to_bf16, float_to_fp8, fp8_to_float
 from tinygrad.helpers import all_same, getenv, flatten, get_single_element, EMULATE
 from tinygrad.device import Compiled, Compiler, Allocator
 from tinygrad.codegen.opt import tc
 from tinygrad.uop.ops import exec_alu, python_alu, Ops, UOp, GroupOp
 from tinygrad.renderer import Renderer

-def storage_fmt_for_dtype(dtype: DType): return 'H' if dtype == dtypes.bfloat16 else dtype.fmt
+def storage_fmt_for_dtype(dtype: DType): return 'H' if dtype == dtypes.bfloat16 else 'B' if dtype in dtypes.fp8s else dtype.fmt

 def to_storage_scalar(x, dtype: DType):
   if dtype == dtypes.bfloat16: return (struct.unpack('I', struct.pack('f', float_to_bf16(x)))[0] >> 16) & 0xFFFF
+  if dtype in dtypes.fp8s: return float_to_fp8(float(x), dtype)
   return x

 def from_storage_scalar(x, dtype: DType):
   if dtype == dtypes.bfloat16: return struct.unpack('f', struct.pack('I', (x & 0xFFFF) << 16))[0]
+  if dtype in dtypes.fp8s: return fp8_to_float(int(x), dtype)
   return x

 def _load(m, i, dtype: DType):
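float_to_fp8 and fp8_to_float carry the scalar conversions wired up above. As a worked reference for what the e4m3/e5m2 bit layouts mean, here is a minimal standalone decoder (an illustration assuming the OCP FP8 formats; it skips inf/NaN handling and is not tinygrad's fp8_to_float implementation):

def decode_fp8(byte: int, exp_bits: int, man_bits: int, bias: int) -> float:
  # Layout per byte: sign | exponent | mantissa, most significant bit is the sign.
  sign = -1.0 if byte & 0x80 else 1.0
  exp = (byte >> man_bits) & ((1 << exp_bits) - 1)
  man = byte & ((1 << man_bits) - 1)
  if exp == 0:  # subnormal: no implicit leading one
    return sign * man * 2.0 ** (1 - bias - man_bits)
  return sign * (1 + man / (1 << man_bits)) * 2.0 ** (exp - bias)

print(decode_fp8(0x40, exp_bits=4, man_bits=3, bias=7))   # e4m3: 2.0
print(decode_fp8(0x7B, exp_bits=5, man_bits=2, bias=15))  # e5m2: 57344.0, the largest finite value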