some unary functions cast int input into float (#2740)

* some unary functions cast int input into float

* precision

* image dtype
This commit is contained in:
chenyu
2023-12-13 00:10:29 -05:00
committed by GitHub
parent 3e778fcc52
commit 2ef33abd20
3 changed files with 30 additions and 8 deletions

View File

@@ -1,11 +1,12 @@
# ruff: noqa: E501
import unittest
import numpy as np
import torch
from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX, temp, least_upper_dtype
from tinygrad import Device
from tinygrad.tensor import Tensor, dtypes
from typing import Any, List
from hypothesis import given, strategies as st
from hypothesis import given, settings, strategies as st
def is_dtype_supported(dtype: DType):
# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
@@ -267,6 +268,25 @@ class TestTypePromotion(unittest.TestCase):
assert least_upper_dtype(dtypes.float16, dtypes.int64) == dtypes.float16
assert least_upper_dtype(dtypes.float16, dtypes.uint64) == dtypes.float16
class TestAutoCastType(unittest.TestCase):
@given(st.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_int(d) and is_dtype_supported(d)]))
@settings(deadline=None)
def test_int_to_float_unary_func(self, dtype):
  """Unary float-valued funcs on an int tensor should match torch, which promotes int input to float."""
  # ops that are expected to work today; the commented-out ones need backend
  # support (MUL/DIV/SUB or a dependent op) before they can be enabled
  ops = [
    lambda t: t.exp(),
    # lambda t: t.exp2(), # requires MUL
    lambda t: t.log(),
    lambda t: t.log2(),
    lambda t: t.sqrt(),
    # lambda t: t.rsqrt(), # requires DIV
    lambda t: t.sin(),
    # lambda t: t.cos(), # requires SUB
    # lambda t: t.tan(), # requires .cos() to work
    lambda t: t.sigmoid(),
  ]
  vals = [2, 3, 4]
  for op in ops:
    # loose tolerances: backends may compute these transcendentals at float16/float32
    np.testing.assert_allclose(op(Tensor(vals, dtype=dtype)).numpy(), op(torch.tensor(vals)), rtol=1e-4, atol=1e-4)
# standard unittest entry point so this test file can be run directly
if __name__ == '__main__':
  unittest.main()