mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 23:18:04 -05:00)
handle float16 overflow in PYTHON (#5022)

* handle float16 overflow in PYTHON: use `truncate` when constructing a tensor from a list, to make sure all values are packable (might be slow, but should be correct); add `truncate_fp16` to cast overflowed values to inf/-inf.
* all valid fmts support truncate
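A minimal sketch of the failure this fixes (values taken from the new test; numpy used only for comparison): CPython's struct raises OverflowError when a float rounds past the half-precision maximum, instead of saturating to inf the way numpy does.

import struct
import numpy as np

print(np.array(65520, dtype=np.float16))  # inf: numpy saturates on overflow
try:
  struct.pack("@e", 65520.0)              # "@e" packs IEEE 754 half precision
except OverflowError as e:
  print("struct raises:", e)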
test/test_dtype.py
@@ -544,7 +544,6 @@ class TestAutoCastType(unittest.TestCase):
     assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32
     assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64
 
-  @unittest.skipIf(Device.DEFAULT == "PYTHON", "TODO: support inf to half in PYTHON backend")
   @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16")
   def test_sum_acc_dtype(self):
     t = Tensor([40000, 40000], dtype=dtypes.float16)
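Context for the removed skip: test_sum_acc_dtype sums in a wider accumulator, and 40000 + 40000 = 80000 exceeds the float16 maximum of 65504, so the result overflows to inf when cast back to half; the PYTHON backend previously couldn't produce that inf. A quick numpy illustration of the expected behavior (numpy assumed only for comparison):

import numpy as np

acc = np.float32(40000) + np.float32(40000)  # 80000.0, exact in float32
print(np.float16(acc))                       # inf: 80000 > 65504 overflows on the downcast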
test/test_tensor.py
@@ -1,10 +1,11 @@
 import numpy as np
 import torch
-import unittest, copy, mmap, random
+import unittest, copy, mmap, random, math
 from tinygrad import Tensor, Device, dtypes
 from tinygrad.helpers import getenv, temp, CI
 from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
 from hypothesis import given, settings, strategies as strat
+from test.helpers import is_dtype_supported
 
 settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
 settings.load_profile("my_profile")
test/test_tensor.py
@@ -302,6 +303,22 @@ class TestTinygrad(unittest.TestCase):
       data = _generate_data(depth)
       np.testing.assert_allclose(Tensor(data).numpy(), np.array(data))
 
+  def test_tensor_list_special_values(self):
+    if is_dtype_supported(dtypes.float16):
+      data = [math.nan, -math.inf, 65504, 65519, 65519.999, 65520, 65520.1]
+      data = data + [-x for x in data]
+      np.testing.assert_allclose(Tensor(data, dtype=dtypes.float16).numpy(), np.array(data, dtype=np.float16))
+
+    # uint32
+    data = [1 << 33, 1 << 32, 1 << 32 - 1, 1]
+    data = data + [-x for x in data]
+    np.testing.assert_allclose(Tensor(data, dtype=dtypes.uint32).numpy(), np.array(data, dtype=np.uint32))
+
+    # int32
+    data = [1 << 33, 1 << 32, 1 << 32 - 1, 1]
+    data = data + [-x for x in data]
+    np.testing.assert_allclose(Tensor(data, dtype=dtypes.int32).numpy(), np.array(data, dtype=np.int32))
+
   def test_tensor_bytes(self):
     data = b"abc123"
     t = Tensor(data)
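The float16 test values sit around the rounding cutoff: 65504 is the largest finite half and the next representable value would be 65536, so under round-to-nearest packing starts to overflow at the midpoint 65520. A small illustration of that boundary (numpy shown only for comparison):

import struct
import numpy as np

for v in [65504.0, 65519.0, 65519.999, 65520.0, 65520.1]:
  try:
    struct.pack("@e", v)
    result = "packs (rounds to a finite half)"
  except OverflowError:
    result = "overflows"
  print(f"{v}: struct {result}, numpy -> {np.float16(v)}")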
tinygrad/ops.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from typing import Union, Tuple, Any, List, Dict, Callable
-import functools, hashlib, math, operator, ctypes
+import functools, hashlib, math, operator, ctypes, struct
 from enum import Enum, auto
 from dataclasses import dataclass
 from tinygrad.helpers import prod, dedup
tinygrad/ops.py
@@ -128,9 +128,16 @@ python_alu = {
   BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x, y: int(x/y) if y != 0 else x*math.inf,
   TernaryOps.WHERE: lambda x,y,z: y if x else z}
 
+def truncate_fp16(x):
+  try:
+    x = float(x)
+    struct.pack("@e", x)
+    return x
+  except OverflowError: return x * math.inf
+
 truncate: Dict[DType, Callable] = {dtypes.bool: bool,
-  # TODO: float16 and bfloat16?
-  dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
+  # TODO: bfloat16
+  dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
   dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
   dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
   dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value,
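Note that truncate_fp16 does not itself round to half precision: when struct.pack("@e", ...) succeeds it returns the original float unchanged (the later pack does the actual rounding), and only true overflows become signed infinity via x * math.inf. A quick sanity check of those semantics:

from tinygrad.ops import truncate_fp16  # as defined in the hunk above

print(truncate_fp16(65519.999))     # 65519.999: packable, passed through unchanged
print(truncate_fp16(65520.0))       # inf: overflow becomes +inf
print(truncate_fp16(-65520.1))      # -inf: sign is preserved by x * math.inf
print(truncate_fp16(float("nan")))  # nan: packs fine, passes through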
tinygrad/runtime/ops_python.py
@@ -7,7 +7,7 @@ from tinygrad.dtype import DType, dtypes, ImageDType
 from tinygrad.helpers import all_same, getenv, flatten
 from tinygrad.device import Compiled, Compiler, Allocator
 from tinygrad.codegen.uops import UOpGraph, UOps
-from tinygrad.ops import BinaryOps, TernaryOps, exec_alu
+from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, truncate
 from tinygrad.renderer import Renderer
 from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer
tinygrad/runtime/ops_python.py
@@ -110,6 +110,8 @@ class PythonProgram:
         if dtypes.is_int(dtype):
           overflow_adjust = 2**(dtype.itemsize*8 - 1) if not dtypes.is_unsigned(dtype) else 0
           casted = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) for x in casted]
+        elif dtypes.is_float(dtype):
+          casted = [truncate.get(dtype, lambda dt: dt)(x) for x in casted]
         ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *casted)))
       elif uop is UOps.LOAD:
         if isinstance(dtp[0], ImageDType):
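For comparison, the existing integer branch above implements two's-complement wraparound in pure Python: shift signed values into the unsigned range, reduce mod 2**bits, then shift back. A standalone check of that formula for int8 (the wrap helper is illustrative only):

bits = 8
overflow_adjust = 2**(bits - 1)  # 128 for signed int8; 0 for unsigned types
wrap = lambda x: (x + overflow_adjust) % 2**bits - overflow_adjust

print(wrap(127))   # 127: in range, unchanged
print(wrap(128))   # -128: wraps past the top
print(wrap(-129))  # 127: wraps past the bottom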
tinygrad/tensor.py
@@ -11,7 +11,7 @@ from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up
 from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY
 from tinygrad.lazy import LazyBuffer
 from tinygrad.multi import MultiLazyBuffer
-from tinygrad.ops import LoadOps
+from tinygrad.ops import LoadOps, truncate
 from tinygrad.device import Device, Buffer, BufferOptions
 from tinygrad.shape.symbolic import sint, Variable, MulNode, SumNode, NumNode, Node
 from tinygrad.engine.realize import run_schedule
tinygrad/tensor.py
@@ -51,7 +51,8 @@ def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
   else:
     ret = LazyBuffer.loadop(LoadOps.EMPTY, get_shape(x), dtype, "PYTHON")
     assert dtype.fmt is not None, f"{dtype=} has None fmt"
-    data = struct.pack(f"@{ret.size}{dtype.fmt}", *fully_flatten(x))
+    truncate_function = truncate[dtype]
+    data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
   # fake realize
   ret.buffer.allocate(memoryview(data))
   del ret.srcs
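Taken together, list construction should now match numpy's overflow behavior end to end. A minimal usage check mirroring the new test (requires a tree with this patch applied):

import numpy as np
from tinygrad import Tensor, dtypes

t = Tensor([65504, 65519.999, 65520, -65520.1], dtype=dtypes.float16)
print(t.numpy())  # expected: [65504. 65504. inf -inf], matching the line below
print(np.array([65504, 65519.999, 65520, -65520.1], dtype=np.float16))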