fix dtypes helpers for integers (#2716)

* scalar

* maybe do this instead

* Revert "scalar"

everything is a scalar

* add tests in test_dtype

* fuzz testing + fix unsigned ints

* fuzz everything
This commit is contained in:
qazal
2023-12-11 19:28:19 +02:00
committed by GitHub
parent bc3c4ce50b
commit a43bc78804
2 changed files with 41 additions and 5 deletions

View File

@@ -4,6 +4,7 @@ from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType,
from tinygrad import Device
from tinygrad.tensor import Tensor, dtypes
from typing import Any, List
from hypothesis import given, strategies as st
def is_dtype_supported(dtype: DType):
# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
@@ -49,7 +50,7 @@ class TestDType(unittest.TestCase):
DATA: Any = None
@classmethod
def setUpClass(cls):
if not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported")
if not cls.DTYPE or not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported")
cls.DATA = np.random.randint(0, 100, size=10, dtype=cls.DTYPE.np).tolist() if dtypes.is_int(cls.DTYPE) else np.random.choice([True, False], size=10).tolist() if cls.DTYPE == dtypes.bool else np.random.uniform(0, 1, size=10).tolist()
def setUp(self):
if self.DTYPE is None: raise unittest.SkipTest("base class")
@@ -189,5 +190,33 @@ class TestEqStrDType(unittest.TestCase):
self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
self.assertEqual(str(PtrDType(dtypes.float32)), "ptr.dtypes.float")
class TestHelpers(unittest.TestCase):
signed_ints = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)
uints = (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
floats = (dtypes.float16, dtypes.float32, dtypes.float64)
@given(st.sampled_from(signed_ints+uints), st.integers(min_value=1, max_value=8))
def test_is_int(self, dtype, amt):
assert dtypes.is_int(dtype.vec(amt) if amt > 1 else dtype)
assert not dtypes.is_float(dtype.vec(amt) if amt > 1 else dtype)
@given(st.sampled_from(uints), st.integers(min_value=1, max_value=8))
def test_is_unsigned_uints(self, dtype, amt):
assert dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype)
@given(st.sampled_from(signed_ints), st.integers(min_value=1, max_value=8))
def test_is_unsigned_signed_ints(self, dtype, amt):
assert not dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype)
@given(st.sampled_from(floats), st.integers(min_value=1, max_value=8))
def test_is_float(self, dtype, amt):
assert dtypes.is_float(dtype.vec(amt) if amt > 1 else dtype)
assert not dtypes.is_int(dtype.vec(amt) if amt > 1 else dtype)
assert not dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype)
@given(st.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_float(d) or dtypes.is_int(d)]), st.integers(min_value=2, max_value=8))
def test_scalar(self, dtype, amt):
assert dtype.vec(amt).scalar() == dtype
if __name__ == '__main__':
unittest.main()