mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-07 13:15:01 -05:00
some unary functions cast int input into float (#2740)
* some unary functions cast int input into float * precision * image dtype
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
# ruff: noqa: E501
|
||||
import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX, temp, least_upper_dtype
|
||||
from tinygrad import Device
|
||||
from tinygrad.tensor import Tensor, dtypes
|
||||
from typing import Any, List
|
||||
from hypothesis import given, strategies as st
|
||||
from hypothesis import given, settings, strategies as st
|
||||
|
||||
def is_dtype_supported(dtype: DType):
|
||||
# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
|
||||
@@ -267,6 +268,25 @@ class TestTypePromotion(unittest.TestCase):
|
||||
assert least_upper_dtype(dtypes.float16, dtypes.int64) == dtypes.float16
|
||||
assert least_upper_dtype(dtypes.float16, dtypes.uint64) == dtypes.float16
|
||||
|
||||
class TestAutoCastType(unittest.TestCase):
|
||||
@given(st.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_int(d) and is_dtype_supported(d)]))
|
||||
@settings(deadline=None)
|
||||
def test_int_to_float_unary_func(self, dtype):
|
||||
for func in [
|
||||
lambda t: t.exp(),
|
||||
# lambda t: t.exp2(), # requires MUL
|
||||
lambda t: t.log(),
|
||||
lambda t: t.log2(),
|
||||
lambda t: t.sqrt(),
|
||||
# lambda t: t.rsqrt(), # requires DIV
|
||||
lambda t: t.sin(),
|
||||
# lambda t: t.cos(), # requires SUB
|
||||
# lambda t: t.tan(), # requires .cos() to work
|
||||
lambda t: t.sigmoid(),
|
||||
]:
|
||||
a = [2, 3, 4]
|
||||
np.testing.assert_allclose(func(Tensor(a, dtype=dtype)).numpy(), func(torch.tensor(a)), rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user