mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
move some test_dtype tests to unit (#11479)
This commit is contained in:
@@ -4,7 +4,7 @@ import torch
|
||||
from typing import Any, List
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import getenv, DEBUG, CI
|
||||
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_dtype, to_dtype, fp8_to_float, float_to_fp8
|
||||
from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8
|
||||
from tinygrad import Device, Tensor, dtypes
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from hypothesis import assume, given, settings, strategies as strat
|
||||
@@ -384,30 +384,6 @@ class TestPtrDType(unittest.TestCase):
|
||||
self.assertEqual(dt.v, 4)
|
||||
self.assertEqual(dt.count, 4)
|
||||
|
||||
class TestImageDType(unittest.TestCase):
|
||||
def test_image_scalar(self):
|
||||
assert dtypes.imagef((10,10)).base.scalar() == dtypes.float32
|
||||
assert dtypes.imageh((10,10)).base.scalar() == dtypes.float32
|
||||
def test_image_vec(self):
|
||||
assert dtypes.imagef((10,10)).base.vec(4) == dtypes.float32.vec(4)
|
||||
assert dtypes.imageh((10,10)).base.vec(4) == dtypes.float32.vec(4)
|
||||
|
||||
class TestEqStrDType(unittest.TestCase):
|
||||
def test_image_ne(self):
|
||||
if ImageDType is None: raise unittest.SkipTest("no ImageDType support")
|
||||
assert dtypes.float == dtypes.float32, "float doesn't match?"
|
||||
assert dtypes.imagef((1,2,4)) != dtypes.imageh((1,2,4)), "different image dtype doesn't match"
|
||||
assert dtypes.imageh((1,2,4)) != dtypes.imageh((1,4,2)), "different shape doesn't match"
|
||||
assert dtypes.imageh((1,2,4)) == dtypes.imageh((1,2,4)), "same shape matches"
|
||||
assert isinstance(dtypes.imageh((1,2,4)), ImageDType)
|
||||
def test_ptr_eq(self):
|
||||
assert dtypes.float32.ptr() == dtypes.float32.ptr()
|
||||
assert not (dtypes.float32.ptr() != dtypes.float32.ptr())
|
||||
def test_strs(self):
|
||||
if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
|
||||
self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
|
||||
self.assertEqual(str(dtypes.float32.ptr(16)), "dtypes.float.ptr(16)")
|
||||
|
||||
class TestImplicitFunctionTypeChange(unittest.TestCase):
|
||||
def test_functions(self):
|
||||
result = []
|
||||
@@ -438,19 +414,6 @@ class TestDtypeUsage(unittest.TestCase):
|
||||
t = Tensor([[1, 2], [3, 4]], dtype=d)
|
||||
(t*t).max().item()
|
||||
|
||||
class TestToDtype(unittest.TestCase):
|
||||
def test_dtype_to_dtype(self):
|
||||
dtype = dtypes.int32
|
||||
res = to_dtype(dtype)
|
||||
self.assertIsInstance(res, DType)
|
||||
self.assertEqual(res, dtypes.int32)
|
||||
|
||||
def test_str_to_dtype(self):
|
||||
dtype = "int32"
|
||||
res = to_dtype(dtype)
|
||||
self.assertIsInstance(res, DType)
|
||||
self.assertEqual(res, dtypes.int32)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}")
|
||||
class TestOpsBFloat16(unittest.TestCase):
|
||||
def test_cast(self):
|
||||
|
||||
42
test/unit/test_dtype.py
Normal file
42
test/unit/test_dtype.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import unittest
|
||||
from tinygrad.dtype import dtypes, DType, ImageDType, PtrDType, to_dtype
|
||||
|
||||
class TestImageDType(unittest.TestCase):
|
||||
def test_image_scalar(self):
|
||||
assert dtypes.imagef((10,10)).base.scalar() == dtypes.float32
|
||||
assert dtypes.imageh((10,10)).base.scalar() == dtypes.float32
|
||||
def test_image_vec(self):
|
||||
assert dtypes.imagef((10,10)).base.vec(4) == dtypes.float32.vec(4)
|
||||
assert dtypes.imageh((10,10)).base.vec(4) == dtypes.float32.vec(4)
|
||||
|
||||
class TestEqStrDType(unittest.TestCase):
|
||||
def test_image_ne(self):
|
||||
if ImageDType is None: raise unittest.SkipTest("no ImageDType support")
|
||||
assert dtypes.float == dtypes.float32, "float doesn't match?"
|
||||
assert dtypes.imagef((1,2,4)) != dtypes.imageh((1,2,4)), "different image dtype doesn't match"
|
||||
assert dtypes.imageh((1,2,4)) != dtypes.imageh((1,4,2)), "different shape doesn't match"
|
||||
assert dtypes.imageh((1,2,4)) == dtypes.imageh((1,2,4)), "same shape matches"
|
||||
assert isinstance(dtypes.imageh((1,2,4)), ImageDType)
|
||||
def test_ptr_eq(self):
|
||||
assert dtypes.float32.ptr() == dtypes.float32.ptr()
|
||||
assert not (dtypes.float32.ptr() != dtypes.float32.ptr())
|
||||
def test_strs(self):
|
||||
if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
|
||||
self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
|
||||
self.assertEqual(str(dtypes.float32.ptr(16)), "dtypes.float.ptr(16)")
|
||||
|
||||
class TestToDtype(unittest.TestCase):
|
||||
def test_dtype_to_dtype(self):
|
||||
dtype = dtypes.int32
|
||||
res = to_dtype(dtype)
|
||||
self.assertIsInstance(res, DType)
|
||||
self.assertEqual(res, dtypes.int32)
|
||||
|
||||
def test_str_to_dtype(self):
|
||||
dtype = "int32"
|
||||
res = to_dtype(dtype)
|
||||
self.assertIsInstance(res, DType)
|
||||
self.assertEqual(res, dtypes.int32)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user