python bfloat16 (#11912)

* python bf16

* _to_torch_storage_type

---------

Co-authored-by: b1tg <b1tg@users.noreply.github.com>
Authored by b1tg on 2025-08-30 03:18:02 +08:00, committed by GitHub
parent afad7d0cd1
commit b2cc06218a
4 changed files with 48 additions and 27 deletions

@@ -4,9 +4,8 @@ import torch
 from typing import Any, List
 from tinygrad.device import is_dtype_supported
 from tinygrad.helpers import getenv, DEBUG, CI
-from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8
+from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8, _to_np_dtype, _to_torch_dtype
 from tinygrad import Device, Tensor, dtypes
-from tinygrad.tensor import _to_np_dtype
 from hypothesis import assume, given, settings, strategies as strat
 from test.helpers import rand_for_dtype
 from test.unit.test_dtype_spec import _assert_eq, core_dtypes, dtype_ints, dtype_floats, FP8E4M3_MAX, FP8E5M2_MAX
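
For context, a small standalone sketch of the two dtype-mapping helpers now imported from tinygrad.dtype; it assumes a tinygrad checkout that includes this commit and a torch install:

    from tinygrad import dtypes
    from tinygrad.dtype import _to_np_dtype, _to_torch_dtype

    # map a tinygrad DType to the matching numpy scalar type / torch dtype
    print(_to_np_dtype(dtypes.float32))
    print(_to_torch_dtype(dtypes.float16))
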
@@ -24,6 +23,10 @@ def get_available_cast_dtypes(dtype: DType) -> List[DType]:
   # dont cast internal dtypes
   return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")]
 
+def _to_torch_storage_type(dtype:DType):
+  if dtype == dtypes.bfloat16: return torch.float32
+  return _to_torch_dtype(dtype)
+
 def _test_to_np(a:Tensor, np_dtype, target):
   if DEBUG >= 2: print(a)
   na = a.numpy()
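
The new helper keeps bfloat16 data in a float32 container on the torch side. For background, a standalone numpy/ml_dtypes sketch (not part of this commit) of the bit-level relationship that makes float32 a convenient stand-in: a bfloat16 value is, up to rounding, the upper 16 bits of the same float32 pattern.

    import numpy as np
    import ml_dtypes

    x = np.array([1.0, -2.5, 3.14159], dtype=np.float32)
    bf = x.astype(ml_dtypes.bfloat16)                    # round float32 -> bfloat16
    top16 = (x.view(np.uint32) >> 16).astype(np.uint16)  # upper 16 bits of each f32 pattern
    print(bf.view(np.uint16))  # bf16 bit patterns
    print(top16)               # matches the line above up to rounding of the low mantissa bits
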
@@ -46,10 +49,10 @@ def _test_cast(a:Tensor, target_dtype:DType):
   _test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(_to_np_dtype(target_dtype))))
 
 def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
-  if target_dtype == dtypes.bfloat16: raise unittest.SkipTest("no test for bf16 bitcast yet")
   if getenv("PTX") and a.dtype == dtypes.int8 and target_dtype.itemsize != a.dtype.itemsize:
     raise unittest.SkipTest("shape changing bitcast of int8 broken on PTX")
-  _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(_to_np_dtype(target_dtype)).tolist())
+  expected = torch.tensor(a.tolist(), dtype=_to_torch_storage_type(a.dtype)).view(_to_torch_dtype(target_dtype))
+  _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or expected.tolist())
 
 class TestDType(unittest.TestCase):
   DTYPE: Any = None
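
The expected value for bitcasts now comes from torch rather than numpy, presumably because numpy has no native bfloat16 type. For reference, a minimal standalone sketch of the torch view-as-dtype reinterpretation the new expected value relies on (plain torch, independent of tinygrad):

    import torch

    t = torch.tensor([1.0, -2.0, 0.5, 4.0], dtype=torch.float32)
    print(t.view(torch.int32))    # same 16 bytes reinterpreted as four int32 bit patterns
    print(t.view(torch.float16))  # same 16 bytes reinterpreted as eight float16 values
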
@@ -126,7 +129,7 @@ class TestDType(unittest.TestCase):
 
   def test_finfo(self):
     if self.DTYPE not in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]: return
-    info = np.finfo(_to_np_dtype(self.DTYPE))
+    info = ml_dtypes.finfo(ml_dtypes.bfloat16 if self.DTYPE is dtypes.bfloat16 else _to_np_dtype(self.DTYPE))
     assert info.bits == self.DTYPE.itemsize*8
     assert info.nexp == dtypes.finfo(self.DTYPE)[0]
     assert info.nmant == dtypes.finfo(self.DTYPE)[1]
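
A quick standalone check of what the switch to ml_dtypes.finfo reports for bfloat16 (the numbers follow from the format itself: 8 exponent bits and 7 explicit mantissa bits in 16 bits total):

    import ml_dtypes

    info = ml_dtypes.finfo(ml_dtypes.bfloat16)
    print(info.bits, info.nexp, info.nmant)  # 16 8 7
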
@@ -299,10 +302,10 @@ class TestBitCast(unittest.TestCase):
   @given(strat.sampled_from(dtype_ints + dtype_floats), strat.sampled_from(dtype_ints + dtype_floats))
   def test_shape_change_bitcast(self, dt1, dt2):
     # NOTE: this has to be assume to prevent hypothesis from skipping all samples
-    assume(dt2 != dtypes.bfloat16 and dt1 != dtypes.bfloat16) # no test for bf16 bitcast yet
     assume(not (getenv("PTX") and dt1 == dtypes.int8)) # TODO: bitcasting int8 fails in PTX
     data = rand_for_dtype(dt1, 32).reshape(2, 2, 8)
-    _test_op(lambda: Tensor(data, dtype=dt1).bitcast(dt2), dt2, data.view(_to_np_dtype(dt2)).tolist())
+    expected = torch.tensor(data.tolist(), dtype=_to_torch_storage_type(dt1)).view(_to_torch_dtype(dt2))
+    _test_op(lambda: Tensor(data, dtype=dt1).bitcast(dt2), dt2, expected.tolist())
 
   def test_shape_change_bitcast_exceptions(self):
     with self.assertRaises(RuntimeError):
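
For background, a standalone numpy illustration (not taken from the test) of how a shape-changing bitcast affects the last axis, using the same (2, 2, 8) layout as the data above:

    import numpy as np

    data = np.arange(32, dtype=np.float32).reshape(2, 2, 8)
    print(data.view(np.int16).shape)    # (2, 2, 16): same bytes, half the itemsize, last axis doubles
    print(data.view(np.float64).shape)  # (2, 2, 4): same bytes, twice the itemsize, last axis halves
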
@@ -342,6 +345,9 @@ class TestUint64DType(TestDType):
 
 class TestBoolDType(TestDType): DTYPE = dtypes.bool
 
+@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}")
+class TestBFloat16Type(TestDType): DTYPE = dtypes.bfloat16
+
 class TestPtrDType(unittest.TestCase):
   def test_vec_double(self):
     dt1 = dtypes.float.vec(4).ptr().vec(4)
@@ -414,7 +420,7 @@ class TestDtypeUsage(unittest.TestCase):
       t = Tensor([[1, 2], [3, 4]], dtype=d)
       (t*t).max().item()
 
-@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16) or Device.DEFAULT == "PYTHON", f"no bfloat16 on {Device.DEFAULT}")
+@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}")
 class TestOpsBFloat16(unittest.TestCase):
   def test_cast(self):
     # TODO: helper_test_op breaks in unrelated part