mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add unit tests for to_dtype (#8217)
* add unit test for to_dtype * add unit test for to_dtype --------- Co-authored-by: pkotzbach <pawkotz@gmail.com>
This commit is contained in:
@@ -4,7 +4,7 @@ import torch
|
||||
from typing import Any, List
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import getenv, DEBUG, CI
|
||||
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype, truncate_fp16
|
||||
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype, truncate_fp16, to_dtype
|
||||
from tinygrad import Device, Tensor, dtypes
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
@@ -854,5 +854,18 @@ class TestDtypeUsage(unittest.TestCase):
|
||||
t = Tensor([[1, 2], [3, 4]], dtype=d)
|
||||
(t*t).max().item()
|
||||
|
||||
class TestToDtype(unittest.TestCase):
|
||||
def test_dtype_to_dtype(self):
|
||||
dtype = dtypes.int32
|
||||
res = to_dtype(dtype)
|
||||
self.assertIsInstance(res, DType)
|
||||
self.assertEqual(res, dtypes.int32)
|
||||
|
||||
def test_str_to_dtype(self):
|
||||
dtype = "int32"
|
||||
res = to_dtype(dtype)
|
||||
self.assertIsInstance(res, DType)
|
||||
self.assertEqual(res, dtypes.int32)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user