mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
Use is_dtype_supported in more places in tests (#7529)
This commit is contained in:
@@ -454,6 +454,7 @@ class TestTypeSpec(unittest.TestCase):
|
||||
subprocess.run(['DEFAULT_FLOAT=TYPO python3 -c "from tinygrad import dtypes"'],
|
||||
shell=True, check=True)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.int8), f"no int8 on {Device.DEFAULT}")
|
||||
def test_dtype_str_arg(self):
|
||||
n = np.random.normal(0, 1, (10, 10)).astype(np.float32)
|
||||
tested = 0
|
||||
@@ -484,7 +485,8 @@ class TestTypeSpec(unittest.TestCase):
|
||||
|
||||
_assert_eq(Tensor.eye(0), dtypes.default_float, np.eye(0))
|
||||
_assert_eq(Tensor.eye(3), dtypes.default_float, np.eye(3))
|
||||
_assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.eye(3, dtype=dtypes.float16), dtypes.float16, np.eye(3))
|
||||
|
||||
@@ -493,20 +495,23 @@ class TestTypeSpec(unittest.TestCase):
|
||||
dtypes.default_int, dtypes.default_float = default_int, default_float
|
||||
|
||||
_assert_eq(Tensor.zeros((2, 3)), dtypes.default_float, np.zeros((2, 3)))
|
||||
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.int64), dtypes.int64, np.zeros((2, 3)))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.int64), dtypes.int64, np.zeros((2, 3)))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.float16), dtypes.float16, np.zeros((2, 3)))
|
||||
|
||||
_assert_eq(Tensor.ones((2, 3)), dtypes.default_float, np.ones((2, 3)))
|
||||
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.int64), dtypes.int64, np.ones((2, 3)))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.int64), dtypes.int64, np.ones((2, 3)))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.float16), dtypes.float16, np.ones((2, 3)))
|
||||
|
||||
_assert_eq(Tensor.full((2, 3), 3.0), dtypes.default_float, np.full((2, 3), 3.0))
|
||||
_assert_eq(Tensor.full((2, 3), 3), dtypes.default_int, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), True), dtypes.bool, np.full((2, 3), True))
|
||||
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
|
||||
@@ -526,8 +531,10 @@ class TestTypeSpec(unittest.TestCase):
|
||||
_assert_eq(Tensor.arange(5), dtypes.default_int, np.arange(5))
|
||||
_assert_eq(Tensor.arange(120), dtypes.default_int, np.arange(120))
|
||||
_assert_eq(Tensor.arange(5.0), dtypes.default_float, np.arange(5))
|
||||
_assert_eq(Tensor.arange(5, dtype=dtypes.int16), dtypes.int16, np.arange(5))
|
||||
_assert_eq(Tensor.arange(5, dtype=dtypes.int64), dtypes.int64, np.arange(5))
|
||||
if is_dtype_supported(dtypes.int16):
|
||||
_assert_eq(Tensor.arange(5, dtype=dtypes.int16), dtypes.int16, np.arange(5))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.arange(5, dtype=dtypes.int64), dtypes.int64, np.arange(5))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.arange(5, dtype=dtypes.float16), dtypes.float16, np.arange(5))
|
||||
_assert_eq(Tensor.arange(3, 9, 0.7), dtypes.default_float, np.arange(3, 9, 0.7))
|
||||
@@ -839,8 +846,9 @@ class TestTensorMethod(unittest.TestCase):
|
||||
class TestDtypeUsage(unittest.TestCase):
|
||||
def test_max_w_alu(self):
|
||||
for d in dtypes.ints:
|
||||
t = Tensor([[1, 2], [3, 4]], dtype=d)
|
||||
(t*t).max().item()
|
||||
if is_dtype_supported(d):
|
||||
t = Tensor([[1, 2], [3, 4]], dtype=d)
|
||||
(t*t).max().item()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user