diff --git a/test/test_dtype.py b/test/test_dtype.py index 5b34d11989..f0337de3d7 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -43,9 +43,6 @@ def _test_cast(a:Tensor, target_dtype:DType): if a.is_floating_point() and dtypes.is_unsigned(target_dtype): # converting negative float to unsigned integer is undefined a = a.abs() - if target_dtype == dtypes.half and Device.DEFAULT == "PYTHON": - # TODO: struct.pack cannot pack value > 65504 (max of half) into e format - a = (a > 65504).where(65504, a) expected = list(a.numpy().astype(_to_np_dtype(target_dtype))) if target_dtype in dtypes.fp8s: expected = [truncate[target_dtype](x) for x in expected] diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index aed667fb2b..52827615d5 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -4,7 +4,7 @@ # this is the (living) definition of uops from typing import Any, TYPE_CHECKING, cast import pickle, base64, itertools, time, struct, sys, functools -from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, float_to_bf16, float_to_fp8, fp8_to_float +from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, float_to_fp16, float_to_bf16, float_to_fp8, fp8_to_float from tinygrad.helpers import all_same, getenv, flatten, get_single_element, EMULATE from tinygrad.device import Compiled, Compiler, Allocator, CompilerSet, CompilerPair from tinygrad.codegen.opt import tc @@ -14,6 +14,7 @@ from tinygrad.renderer import Renderer def storage_fmt_for_dtype(dtype: DType): return 'H' if dtype == dtypes.bfloat16 else 'B' if dtype in dtypes.fp8s else dtype.fmt def to_storage_scalar(x, dtype: DType): + if dtype == dtypes.half: return float_to_fp16(x) if dtype == dtypes.bfloat16: return (struct.unpack('I', struct.pack('f', float_to_bf16(x)))[0] >> 16) & 0xFFFF if dtype in dtypes.fp8s: return float_to_fp8(float(x), dtype) return x