add unit tests for to_dtype (#8217)

* add unit test for to_dtype

* add unit test for to_dtype

---------

Co-authored-by: pkotzbach <pawkotz@gmail.com>
This commit is contained in:
pkotzbach
2024-12-13 22:21:02 +01:00
committed by GitHub
parent 8a50868264
commit c1b79c118f

View File

@@ -4,7 +4,7 @@ import torch
from typing import Any, List
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import getenv, DEBUG, CI
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype, truncate_fp16
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype, truncate_fp16, to_dtype
from tinygrad import Device, Tensor, dtypes
from tinygrad.tensor import _to_np_dtype
from hypothesis import given, settings, strategies as strat
@@ -854,5 +854,18 @@ class TestDtypeUsage(unittest.TestCase):
t = Tensor([[1, 2], [3, 4]], dtype=d)
(t*t).max().item()
class TestToDtype(unittest.TestCase):
def test_dtype_to_dtype(self):
dtype = dtypes.int32
res = to_dtype(dtype)
self.assertIsInstance(res, DType)
self.assertEqual(res, dtypes.int32)
def test_str_to_dtype(self):
dtype = "int32"
res = to_dtype(dtype)
self.assertIsInstance(res, DType)
self.assertEqual(res, dtypes.int32)
if __name__ == '__main__':
unittest.main()