fix cast HALF with PYTHON backend (#14058)

This commit is contained in:
chenyu
2026-01-07 16:52:05 -05:00
committed by GitHub
parent 5f1ede7f7e
commit 3caa1e2c98
2 changed files with 2 additions and 4 deletions

View File

@@ -43,9 +43,6 @@ def _test_cast(a:Tensor, target_dtype:DType):
if a.is_floating_point() and dtypes.is_unsigned(target_dtype):
# converting negative float to unsigned integer is undefined
a = a.abs()
if target_dtype == dtypes.half and Device.DEFAULT == "PYTHON":
# TODO: struct.pack cannot pack value > 65504 (max of half) into e format
a = (a > 65504).where(65504, a)
expected = list(a.numpy().astype(_to_np_dtype(target_dtype)))
if target_dtype in dtypes.fp8s: expected = [truncate[target_dtype](x) for x in expected]

View File

@@ -4,7 +4,7 @@
# this is the (living) definition of uops
from typing import Any, TYPE_CHECKING, cast
import pickle, base64, itertools, time, struct, sys, functools
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, float_to_bf16, float_to_fp8, fp8_to_float
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, float_to_fp16, float_to_bf16, float_to_fp8, fp8_to_float
from tinygrad.helpers import all_same, getenv, flatten, get_single_element, EMULATE
from tinygrad.device import Compiled, Compiler, Allocator, CompilerSet, CompilerPair
from tinygrad.codegen.opt import tc
@@ -14,6 +14,7 @@ from tinygrad.renderer import Renderer
def storage_fmt_for_dtype(dtype: DType): return 'H' if dtype == dtypes.bfloat16 else 'B' if dtype in dtypes.fp8s else dtype.fmt
def to_storage_scalar(x, dtype: DType):
if dtype == dtypes.half: return float_to_fp16(x)
if dtype == dtypes.bfloat16: return (struct.unpack('I', struct.pack('f', float_to_bf16(x)))[0] >> 16) & 0xFFFF
if dtype in dtypes.fp8s: return float_to_fp8(float(x), dtype)
return x