mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
fix dtypes helpers for integers (#2716)
* scalar * maybe do this instead * Revert "scalar" everything is a scalar * add tests in test_dtype * fuzz testing + fix unsigned ints * fuzz everything
This commit is contained in:
@@ -4,6 +4,7 @@ from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType,
|
||||
from tinygrad import Device
|
||||
from tinygrad.tensor import Tensor, dtypes
|
||||
from typing import Any, List
|
||||
from hypothesis import given, strategies as st
|
||||
|
||||
def is_dtype_supported(dtype: DType):
|
||||
# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
|
||||
@@ -49,7 +50,7 @@ class TestDType(unittest.TestCase):
|
||||
DATA: Any = None
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported")
|
||||
if not cls.DTYPE or not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported")
|
||||
cls.DATA = np.random.randint(0, 100, size=10, dtype=cls.DTYPE.np).tolist() if dtypes.is_int(cls.DTYPE) else np.random.choice([True, False], size=10).tolist() if cls.DTYPE == dtypes.bool else np.random.uniform(0, 1, size=10).tolist()
|
||||
def setUp(self):
|
||||
if self.DTYPE is None: raise unittest.SkipTest("base class")
|
||||
@@ -189,5 +190,33 @@ class TestEqStrDType(unittest.TestCase):
|
||||
self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
|
||||
self.assertEqual(str(PtrDType(dtypes.float32)), "ptr.dtypes.float")
|
||||
|
||||
class TestHelpers(unittest.TestCase):
|
||||
signed_ints = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)
|
||||
uints = (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
|
||||
floats = (dtypes.float16, dtypes.float32, dtypes.float64)
|
||||
|
||||
@given(st.sampled_from(signed_ints+uints), st.integers(min_value=1, max_value=8))
|
||||
def test_is_int(self, dtype, amt):
|
||||
assert dtypes.is_int(dtype.vec(amt) if amt > 1 else dtype)
|
||||
assert not dtypes.is_float(dtype.vec(amt) if amt > 1 else dtype)
|
||||
|
||||
@given(st.sampled_from(uints), st.integers(min_value=1, max_value=8))
|
||||
def test_is_unsigned_uints(self, dtype, amt):
|
||||
assert dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype)
|
||||
|
||||
@given(st.sampled_from(signed_ints), st.integers(min_value=1, max_value=8))
|
||||
def test_is_unsigned_signed_ints(self, dtype, amt):
|
||||
assert not dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype)
|
||||
|
||||
@given(st.sampled_from(floats), st.integers(min_value=1, max_value=8))
|
||||
def test_is_float(self, dtype, amt):
|
||||
assert dtypes.is_float(dtype.vec(amt) if amt > 1 else dtype)
|
||||
assert not dtypes.is_int(dtype.vec(amt) if amt > 1 else dtype)
|
||||
assert not dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype)
|
||||
|
||||
@given(st.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_float(d) or dtypes.is_int(d)]), st.integers(min_value=2, max_value=8))
|
||||
def test_scalar(self, dtype, amt):
|
||||
assert dtype.vec(amt).scalar() == dtype
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user