diff --git a/test/test_dtype.py b/test/test_dtype.py
index a83edb95f7..69b9704781 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -4,9 +4,8 @@ import torch
 from typing import Any, List
 from tinygrad.device import is_dtype_supported
 from tinygrad.helpers import getenv, DEBUG, CI
-from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8
+from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8, _to_np_dtype, _to_torch_dtype
 from tinygrad import Device, Tensor, dtypes
-from tinygrad.tensor import _to_np_dtype
 from hypothesis import assume, given, settings, strategies as strat
 from test.helpers import rand_for_dtype
 from test.unit.test_dtype_spec import _assert_eq, core_dtypes, dtype_ints, dtype_floats, FP8E4M3_MAX, FP8E5M2_MAX
@@ -24,6 +23,10 @@ def get_available_cast_dtypes(dtype: DType) -> List[DType]:
   # dont cast internal dtypes
   return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")]
 
+def _to_torch_storage_type(dtype:DType):
+  if dtype == dtypes.bfloat16: return torch.float32
+  return _to_torch_dtype(dtype)
+
 def _test_to_np(a:Tensor, np_dtype, target):
   if DEBUG >= 2: print(a)
   na = a.numpy()
@@ -46,10 +49,10 @@ def _test_cast(a:Tensor, target_dtype:DType):
   _test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(_to_np_dtype(target_dtype))))
 
 def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
-  if target_dtype == dtypes.bfloat16: raise unittest.SkipTest("no test for bf16 bitcast yet")
   if getenv("PTX") and a.dtype == dtypes.int8 and target_dtype.itemsize != a.dtype.itemsize:
     raise unittest.SkipTest("shape changing bitcast of int8 broken on PTX")
-  _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(_to_np_dtype(target_dtype)).tolist())
+  expected = torch.tensor(a.tolist(), dtype=_to_torch_storage_type(a.dtype)).view(_to_torch_dtype(target_dtype))
+  _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or expected.tolist())
 
 class TestDType(unittest.TestCase):
   DTYPE: Any = None
@@ -126,7 +129,7 @@ class TestDType(unittest.TestCase):
 
   def test_finfo(self):
     if self.DTYPE not in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]: return
-    info = np.finfo(_to_np_dtype(self.DTYPE))
+    info = ml_dtypes.finfo(ml_dtypes.bfloat16 if self.DTYPE is dtypes.bfloat16 else _to_np_dtype(self.DTYPE))
     assert info.bits == self.DTYPE.itemsize*8
     assert info.nexp == dtypes.finfo(self.DTYPE)[0]
     assert info.nmant == dtypes.finfo(self.DTYPE)[1]
@@ -299,10 +302,10 @@ class TestBitCast(unittest.TestCase):
   @given(strat.sampled_from(dtype_ints + dtype_floats), strat.sampled_from(dtype_ints + dtype_floats))
   def test_shape_change_bitcast(self, dt1, dt2):
     # NOTE: this has to be assume to prevent hypothesis from skipping all samples
-    assume(dt2 != dtypes.bfloat16 and dt1 != dtypes.bfloat16) # no test for bf16 bitcast yet
     assume(not (getenv("PTX") and dt1 == dtypes.int8)) # TODO: bitcasting int8 fails in PTX
     data = rand_for_dtype(dt1, 32).reshape(2, 2, 8)
-    _test_op(lambda: Tensor(data, dtype=dt1).bitcast(dt2), dt2, data.view(_to_np_dtype(dt2)).tolist())
+    expected = torch.tensor(data.tolist(), dtype=_to_torch_storage_type(dt1)).view(_to_torch_dtype(dt2))
+    _test_op(lambda: Tensor(data, dtype=dt1).bitcast(dt2), dt2, expected.tolist())
 
   def test_shape_change_bitcast_exceptions(self):
     with self.assertRaises(RuntimeError):
@@ -342,6 +345,9 @@ class TestUint64DType(TestDType):
 
 class TestBoolDType(TestDType): DTYPE = dtypes.bool
 
+@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}")
+class TestBFloat16Type(TestDType): DTYPE = dtypes.bfloat16
+
 class TestPtrDType(unittest.TestCase):
   def test_vec_double(self):
     dt1 = dtypes.float.vec(4).ptr().vec(4)
@@ -414,7 +420,7 @@ class TestDtypeUsage(unittest.TestCase):
       t = Tensor([[1, 2], [3, 4]], dtype=d)
       (t*t).max().item()
 
-@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16) or Device.DEFAULT == "PYTHON", f"no bfloat16 on {Device.DEFAULT}")
+@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}")
 class TestOpsBFloat16(unittest.TestCase):
   def test_cast(self):
     # TODO: helper_test_op breaks in unrelated part
diff --git a/tinygrad/device.py b/tinygrad/device.py
index c7e5cabdf6..6380660010 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -304,7 +304,7 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
     if device == "METAL": return not CI
     if device in {"CUDA", "NV"}: return not CI and not getenv("PTX")
     if device in {"CPU", "LLVM"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"}
-    return device == "AMD"
+    return device in {"AMD", "PYTHON"}
   if dtype in dtypes.fp8s:
     # not supported yet - in progress
     return False
diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py
index 999ceff4ac..cd2ab98651 100644
--- a/tinygrad/dtype.py
+++ b/tinygrad/dtype.py
@@ -298,6 +298,7 @@ truncate: dict[DType, Callable] = {dtypes.bool: bool,
 
 def _to_np_dtype(dtype:DType) -> type|None:
   import numpy as np
+  if dtype == dtypes.bfloat16: return np.float32
   return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
 def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
   import numpy as np
@@ -306,6 +307,8 @@ def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] #
 @functools.cache
 def _to_torch_dtype(dtype:DType) -> 'torch.dtype'|None: # type: ignore [name-defined] # noqa: F821
   import numpy as np, torch
+  if dtype == dtypes.uint64: return torch.uint64
+  if dtype == dtypes.bfloat16: return torch.bfloat16
   # NOTE: torch doesn't expose this mapping with a stable API
   try: return torch.from_numpy(np.array([], dtype=_to_np_dtype(dtype))).dtype
   except TypeError: return None
diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py
index 991a338785..cd72722f05 100644
--- a/tinygrad/runtime/ops_python.py
+++ b/tinygrad/runtime/ops_python.py
@@ -4,25 +4,35 @@
 # this is the (living) definition of uops
 from typing import Any, TYPE_CHECKING
 import pickle, base64, itertools, time, struct, sys
-from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate
+from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, float_to_bf16
 from tinygrad.helpers import all_same, getenv, flatten, get_single_element
 from tinygrad.device import Compiled, Compiler, Allocator
 from tinygrad.codegen.opt import tc
 from tinygrad.uop.ops import exec_alu, Ops, UOp, GroupOp
 from tinygrad.renderer import Renderer
 
-def _load(m, i):
+def storage_fmt_for_dtype(dtype: DType): return 'H' if dtype == dtypes.bfloat16 else dtype.fmt
+
+def to_storage_scalar(x, dtype: DType):
+  if dtype == dtypes.bfloat16: return (struct.unpack('I', struct.pack('f', float_to_bf16(x)))[0] >> 16) & 0xFFFF
+  return x
+
+def from_storage_scalar(x, dtype: DType):
+  if dtype == dtypes.bfloat16: return struct.unpack('f', struct.pack('I', (x & 0xFFFF) << 16))[0]
+  return x
+
+def _load(m, i, dtype: DType):
   if i is None: return 0.0
   if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
-  return m[i]
+  return from_storage_scalar(m[i], dtype)
 
-def load(inp, j=0):
-  if len(inp) == 2: return [_load(m, x+j if x is not None else None) if gate else default for (m,x,gate),default in zip(*inp)]
-  return [_load(m, x+j if x is not None else None) for m,x,_ in inp[0]]
+def load(inp, j, dtype: DType):
+  if len(inp) == 2: return [_load(m, x+j if x is not None else None, dtype) if gate else default for (m,x,gate),default in zip(*inp)]
+  return [_load(m, x+j if x is not None else None, dtype) for m,x,_ in inp[0]]
 
-def _store(m, i, v):
+def _store(m, i, v, dtype: DType):
   if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
-  m[i] = v
+  m[i] = to_storage_scalar(v, dtype)
 
 class PythonProgram:
   def __init__(self, name:str, lib:bytes):
@@ -57,19 +67,20 @@ class PythonProgram:
         if uop is Ops.STORE:
           for j,val in enumerate(inp[1] if dtp[1].count > 1 else [inp[1]]):
             for (m,o,g),v in zip(inp[0], val):
-              if g: _store(m, o+j, v)
+              if g: _store(m, o+j, v, dtp[1].scalar())
           i += 1
           continue
         if uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
          assert isinstance(dtype, PtrDType), dtype
-          if dtype.fmt is None: raise RuntimeError(f"{dtype=} is not supported")
-          if TYPE_CHECKING or sys.version_info < (3, 12): assert dtype.fmt != "e"
+          storage_fmt = storage_fmt_for_dtype(dtype.base.scalar())
+          if storage_fmt is None: raise RuntimeError(f"{dtype=} is not supported")
+          if TYPE_CHECKING or sys.version_info < (3, 12): assert storage_fmt != "e"
          if uop is Ops.DEFINE_REG:
            # REGs are per thread
-            ul[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(dtype.fmt) for _ in range(warp_size)]
+            ul[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
          else:
            buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is not Ops.DEFINE_GLOBAL else pbufs.pop(0)
-            ul[i] = [buf.cast(dtype.fmt)] * warp_size
+            ul[i] = [buf.cast(storage_fmt)] * warp_size
        elif uop is Ops.DEFINE_VAR:
          ul[i] = [pvals.pop(0)] * warp_size
        elif uop is Ops.SPECIAL:
@@ -98,16 +109,17 @@ class PythonProgram:
          continue
        elif uop is Ops.VECTORIZE: ul[i] = inp
        elif uop is Ops.BITCAST:
-          assert dtp[0].fmt and dtype.fmt
-          pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
-          ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
+          packed = struct.pack(str(warp_size) + storage_fmt_for_dtype(dtp[0].scalar()), *[to_storage_scalar(x, dtp[0].scalar()) for x in inp[0]])
+          ul[i] = list(struct.unpack(str(warp_size) + storage_fmt_for_dtype(dtype.scalar()), packed))
+          ul[i] = [from_storage_scalar(x, dtype.scalar()) for x in ul[i]]
        elif uop is Ops.CAST:
          ul[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in inp[0]]
        elif uop is Ops.LOAD:
          if dtype.count > 1:
-            ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j) for j in range(dtype.count)]
+            ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j, dtype.scalar()) \
+              for j in range(dtype.count)]
          else:
-            ul[i] = load(inp)
+            ul[i] = load(inp, 0, dtype)
        elif uop is Ops.GEP: ul[i] = inp[0][get_single_element(arg)]
        elif uop is Ops.WMMA:
          # here are the models for the WMMA instruction on the different hardware
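
The ops_python.py changes above keep bfloat16 values in buffers as uint16 bit patterns ('H' format) and convert on every load/store. Below is a minimal standalone sketch of that round-trip using only struct; the helper names are illustrative (not from the diff), and the real to_storage_scalar additionally rounds through tinygrad's float_to_bf16 before truncating, which this sketch skips.

```python
import struct

def bf16_bits_from_float(x: float) -> int:
  # keep the top 16 bits of the float32 representation (plain truncation;
  # the diff's to_storage_scalar rounds via float_to_bf16 first)
  return (struct.unpack('I', struct.pack('f', x))[0] >> 16) & 0xFFFF

def float_from_bf16_bits(bits: int) -> float:
  # zero-fill the low 16 bits to expand the stored pattern back to a float32
  return struct.unpack('f', struct.pack('I', (bits & 0xFFFF) << 16))[0]

# 3.140625 is exactly representable in bfloat16, so it survives the round-trip
print(float_from_bf16_bits(bf16_bits_from_float(3.140625)))  # 3.140625
```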
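On the test side, the rewritten bitcast expectations use torch's Tensor.view(dtype) instead of numpy's ndarray.view. A small hedged example of that pattern (the input values here are arbitrary, not taken from the tests):

```python
import torch

# reinterpret float32 bits as int32 without copying, the same view(dtype)
# trick the updated _test_bitcast / test_shape_change_bitcast expectations use
data = torch.tensor([1.0, -2.0, 0.5, 3.25], dtype=torch.float32)
print(data.view(torch.int32).tolist())  # 1.0 -> 1065353216 (0x3F800000), etc.
```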