mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix cast HALF with PYTHON backend (#14058)
This commit is contained in:
@@ -43,9 +43,6 @@ def _test_cast(a:Tensor, target_dtype:DType):
|
||||
if a.is_floating_point() and dtypes.is_unsigned(target_dtype):
|
||||
# converting negative float to unsigned integer is undefined
|
||||
a = a.abs()
|
||||
if target_dtype == dtypes.half and Device.DEFAULT == "PYTHON":
|
||||
# TODO: struct.pack cannot pack value > 65504 (max of half) into e format
|
||||
a = (a > 65504).where(65504, a)
|
||||
|
||||
expected = list(a.numpy().astype(_to_np_dtype(target_dtype)))
|
||||
if target_dtype in dtypes.fp8s: expected = [truncate[target_dtype](x) for x in expected]
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# this is the (living) definition of uops
|
||||
from typing import Any, TYPE_CHECKING, cast
|
||||
import pickle, base64, itertools, time, struct, sys, functools
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, float_to_bf16, float_to_fp8, fp8_to_float
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, float_to_fp16, float_to_bf16, float_to_fp8, fp8_to_float
|
||||
from tinygrad.helpers import all_same, getenv, flatten, get_single_element, EMULATE
|
||||
from tinygrad.device import Compiled, Compiler, Allocator, CompilerSet, CompilerPair
|
||||
from tinygrad.codegen.opt import tc
|
||||
@@ -14,6 +14,7 @@ from tinygrad.renderer import Renderer
|
||||
def storage_fmt_for_dtype(dtype: DType): return 'H' if dtype == dtypes.bfloat16 else 'B' if dtype in dtypes.fp8s else dtype.fmt
|
||||
|
||||
def to_storage_scalar(x, dtype: DType):
|
||||
if dtype == dtypes.half: return float_to_fp16(x)
|
||||
if dtype == dtypes.bfloat16: return (struct.unpack('I', struct.pack('f', float_to_bf16(x)))[0] >> 16) & 0xFFFF
|
||||
if dtype in dtypes.fp8s: return float_to_fp8(float(x), dtype)
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user