From e22e5da9a53d635f3d7491237aec454d5f1ba517 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 2 Aug 2025 12:25:00 -0700 Subject: [PATCH] move some test_dtype tests to unit (#11479) --- test/test_dtype.py | 39 +------------------------------------- test/unit/test_dtype.py | 42 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 38 deletions(-) create mode 100644 test/unit/test_dtype.py diff --git a/test/test_dtype.py b/test/test_dtype.py index a39fd9927e..6a3d7d5a3a 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -4,7 +4,7 @@ import torch from typing import Any, List from tinygrad.device import is_dtype_supported from tinygrad.helpers import getenv, DEBUG, CI -from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_dtype, to_dtype, fp8_to_float, float_to_fp8 +from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8 from tinygrad import Device, Tensor, dtypes from tinygrad.tensor import _to_np_dtype from hypothesis import assume, given, settings, strategies as strat @@ -384,30 +384,6 @@ class TestPtrDType(unittest.TestCase): self.assertEqual(dt.v, 4) self.assertEqual(dt.count, 4) -class TestImageDType(unittest.TestCase): - def test_image_scalar(self): - assert dtypes.imagef((10,10)).base.scalar() == dtypes.float32 - assert dtypes.imageh((10,10)).base.scalar() == dtypes.float32 - def test_image_vec(self): - assert dtypes.imagef((10,10)).base.vec(4) == dtypes.float32.vec(4) - assert dtypes.imageh((10,10)).base.vec(4) == dtypes.float32.vec(4) - -class TestEqStrDType(unittest.TestCase): - def test_image_ne(self): - if ImageDType is None: raise unittest.SkipTest("no ImageDType support") - assert dtypes.float == dtypes.float32, "float doesn't match?" - assert dtypes.imagef((1,2,4)) != dtypes.imageh((1,2,4)), "different image dtype doesn't match" - assert dtypes.imageh((1,2,4)) != dtypes.imageh((1,4,2)), "different shape doesn't match" - assert dtypes.imageh((1,2,4)) == dtypes.imageh((1,2,4)), "same shape matches" - assert isinstance(dtypes.imageh((1,2,4)), ImageDType) - def test_ptr_eq(self): - assert dtypes.float32.ptr() == dtypes.float32.ptr() - assert not (dtypes.float32.ptr() != dtypes.float32.ptr()) - def test_strs(self): - if PtrDType is None: raise unittest.SkipTest("no PtrDType support") - self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))") - self.assertEqual(str(dtypes.float32.ptr(16)), "dtypes.float.ptr(16)") - class TestImplicitFunctionTypeChange(unittest.TestCase): def test_functions(self): result = [] @@ -438,19 +414,6 @@ class TestDtypeUsage(unittest.TestCase): t = Tensor([[1, 2], [3, 4]], dtype=d) (t*t).max().item() -class TestToDtype(unittest.TestCase): - def test_dtype_to_dtype(self): - dtype = dtypes.int32 - res = to_dtype(dtype) - self.assertIsInstance(res, DType) - self.assertEqual(res, dtypes.int32) - - def test_str_to_dtype(self): - dtype = "int32" - res = to_dtype(dtype) - self.assertIsInstance(res, DType) - self.assertEqual(res, dtypes.int32) - @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}") class TestOpsBFloat16(unittest.TestCase): def test_cast(self): diff --git a/test/unit/test_dtype.py b/test/unit/test_dtype.py new file mode 100644 index 0000000000..0d9de293cc --- /dev/null +++ b/test/unit/test_dtype.py @@ -0,0 +1,42 @@ +import unittest +from tinygrad.dtype import dtypes, DType, ImageDType, PtrDType, to_dtype + +class TestImageDType(unittest.TestCase): + def test_image_scalar(self): + assert dtypes.imagef((10,10)).base.scalar() == dtypes.float32 + assert dtypes.imageh((10,10)).base.scalar() == dtypes.float32 + def test_image_vec(self): + assert dtypes.imagef((10,10)).base.vec(4) == dtypes.float32.vec(4) + assert dtypes.imageh((10,10)).base.vec(4) == dtypes.float32.vec(4) + +class TestEqStrDType(unittest.TestCase): + def test_image_ne(self): + if ImageDType is None: raise unittest.SkipTest("no ImageDType support") + assert dtypes.float == dtypes.float32, "float doesn't match?" + assert dtypes.imagef((1,2,4)) != dtypes.imageh((1,2,4)), "different image dtype doesn't match" + assert dtypes.imageh((1,2,4)) != dtypes.imageh((1,4,2)), "different shape doesn't match" + assert dtypes.imageh((1,2,4)) == dtypes.imageh((1,2,4)), "same shape matches" + assert isinstance(dtypes.imageh((1,2,4)), ImageDType) + def test_ptr_eq(self): + assert dtypes.float32.ptr() == dtypes.float32.ptr() + assert not (dtypes.float32.ptr() != dtypes.float32.ptr()) + def test_strs(self): + if PtrDType is None: raise unittest.SkipTest("no PtrDType support") + self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))") + self.assertEqual(str(dtypes.float32.ptr(16)), "dtypes.float.ptr(16)") + +class TestToDtype(unittest.TestCase): + def test_dtype_to_dtype(self): + dtype = dtypes.int32 + res = to_dtype(dtype) + self.assertIsInstance(res, DType) + self.assertEqual(res, dtypes.int32) + + def test_str_to_dtype(self): + dtype = "int32" + res = to_dtype(dtype) + self.assertIsInstance(res, DType) + self.assertEqual(res, dtypes.int32) + +if __name__ == "__main__": + unittest.main()