mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
python bfloat16 (#11912)
* python bf16
* _to_torch_storage_type
---------
Co-authored-by: b1tg <b1tg@users.noreply.github.com>
This commit is contained in:
@@ -4,9 +4,8 @@ import torch
|
||||
from typing import Any, List
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import getenv, DEBUG, CI
|
||||
from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8
|
||||
from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8, _to_np_dtype, _to_torch_dtype
|
||||
from tinygrad import Device, Tensor, dtypes
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from hypothesis import assume, given, settings, strategies as strat
|
||||
from test.helpers import rand_for_dtype
|
||||
from test.unit.test_dtype_spec import _assert_eq, core_dtypes, dtype_ints, dtype_floats, FP8E4M3_MAX, FP8E5M2_MAX
|
||||
@@ -24,6 +23,10 @@ def get_available_cast_dtypes(dtype: DType) -> List[DType]:
|
||||
# don't cast internal dtypes
|
||||
return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")]
|
||||
|
||||
def _to_torch_storage_type(dtype:DType):
  """Pick the torch dtype passed to `torch.tensor(..., dtype=...)` for values of `dtype`.

  bfloat16 is mapped to torch.float32; every other dtype is forwarded to
  `_to_torch_dtype` unchanged.
  """
  return torch.float32 if dtype == dtypes.bfloat16 else _to_torch_dtype(dtype)
|
||||
|
||||
def _test_to_np(a:Tensor, np_dtype, target):
|
||||
if DEBUG >= 2: print(a)
|
||||
na = a.numpy()
|
||||
@@ -46,10 +49,10 @@ def _test_cast(a:Tensor, target_dtype:DType):
|
||||
|
||||
_test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(_to_np_dtype(target_dtype))))
|
||||
def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
|
||||
if target_dtype == dtypes.bfloat16: raise unittest.SkipTest("no test for bf16 bitcast yet")
|
||||
if getenv("PTX") and a.dtype == dtypes.int8 and target_dtype.itemsize != a.dtype.itemsize:
|
||||
raise unittest.SkipTest("shape changing bitcast of int8 broken on PTX")
|
||||
_test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(_to_np_dtype(target_dtype)).tolist())
|
||||
expected = torch.tensor(a.tolist(), dtype=_to_torch_storage_type(a.dtype)).view(_to_torch_dtype(target_dtype))
|
||||
_test_op(lambda: a.bitcast(target_dtype), target_dtype, target or expected.tolist())
|
||||
|
||||
class TestDType(unittest.TestCase):
|
||||
DTYPE: Any = None
|
||||
@@ -126,7 +129,7 @@ class TestDType(unittest.TestCase):
|
||||
|
||||
def test_finfo(self):
|
||||
if self.DTYPE not in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]: return
|
||||
info = np.finfo(_to_np_dtype(self.DTYPE))
|
||||
info = ml_dtypes.finfo(ml_dtypes.bfloat16 if self.DTYPE is dtypes.bfloat16 else _to_np_dtype(self.DTYPE))
|
||||
assert info.bits == self.DTYPE.itemsize*8
|
||||
assert info.nexp == dtypes.finfo(self.DTYPE)[0]
|
||||
assert info.nmant == dtypes.finfo(self.DTYPE)[1]
|
||||
@@ -299,10 +302,10 @@ class TestBitCast(unittest.TestCase):
|
||||
@given(strat.sampled_from(dtype_ints + dtype_floats), strat.sampled_from(dtype_ints + dtype_floats))
|
||||
def test_shape_change_bitcast(self, dt1, dt2):
|
||||
# NOTE: this has to be an `assume` to prevent hypothesis from skipping all samples
|
||||
assume(dt2 != dtypes.bfloat16 and dt1 != dtypes.bfloat16) # no test for bf16 bitcast yet
|
||||
assume(not (getenv("PTX") and dt1 == dtypes.int8)) # TODO: bitcasting int8 fails in PTX
|
||||
data = rand_for_dtype(dt1, 32).reshape(2, 2, 8)
|
||||
_test_op(lambda: Tensor(data, dtype=dt1).bitcast(dt2), dt2, data.view(_to_np_dtype(dt2)).tolist())
|
||||
expected = torch.tensor(data.tolist(), dtype=_to_torch_storage_type(dt1)).view(_to_torch_dtype(dt2))
|
||||
_test_op(lambda: Tensor(data, dtype=dt1).bitcast(dt2), dt2, expected.tolist())
|
||||
|
||||
def test_shape_change_bitcast_exceptions(self):
|
||||
with self.assertRaises(RuntimeError):
|
||||
@@ -342,6 +345,9 @@ class TestUint64DType(TestDType):
|
||||
|
||||
class TestBoolDType(TestDType): DTYPE = dtypes.bool
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}")
class TestBFloat16Type(TestDType):
  # Subclass of TestDType with bfloat16 as the dtype under test; the decorator
  # skips the whole class when is_dtype_supported reports no bfloat16 support.
  DTYPE = dtypes.bfloat16
|
||||
|
||||
class TestPtrDType(unittest.TestCase):
|
||||
def test_vec_double(self):
|
||||
dt1 = dtypes.float.vec(4).ptr().vec(4)
|
||||
@@ -414,7 +420,7 @@ class TestDtypeUsage(unittest.TestCase):
|
||||
t = Tensor([[1, 2], [3, 4]], dtype=d)
|
||||
(t*t).max().item()
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16) or Device.DEFAULT == "PYTHON", f"no bfloat16 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}")
|
||||
class TestOpsBFloat16(unittest.TestCase):
|
||||
def test_cast(self):
|
||||
# TODO: helper_test_op breaks in unrelated part
|
||||
|
||||
Reference in New Issue
Block a user