clean up test_dtype.py (#2827)

make is_dtype_supported a pure function and clean up long lines
chenyu
2023-12-18 16:06:09 -05:00
committed by GitHub
parent 20ea43b6e7
commit 264fe9c93f

test/test_dtype.py

@@ -1,4 +1,3 @@
-# ruff: noqa: E501
 import unittest
 import numpy as np
 import torch
@@ -8,19 +7,20 @@ from tinygrad.tensor import Tensor, dtypes
 from typing import Any, List
 from hypothesis import given, settings, strategies as st
-def is_dtype_supported(dtype: DType):
-  # for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
+def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
+  # for GPU, cl_khr_fp16 isn't supported
   # for LLVM, it segfaults because it can't link to the casting function
-  if dtype == dtypes.half: return not (CI and Device.DEFAULT in ["GPU", "LLVM"]) and Device.DEFAULT != "WEBGPU" and getenv("CUDACPU") != 1
+  # CUDA in CI uses CUDACPU that does not support half
+  if dtype == dtypes.half: return not (CI and device in ["GPU", "LLVM", "CUDA"]) and device != "WEBGPU"
   if dtype == dtypes.bfloat16: return False # numpy doesn't support bf16, tested separately in TestBFloat16DType
-  if dtype == dtypes.float64: return Device.DEFAULT not in ["WEBGPU", "METAL"] and (not OSX and Device.DEFAULT == "GPU")
-  if dtype in [dtypes.int8, dtypes.uint8]: return Device.DEFAULT not in ["WEBGPU"]
-  if dtype in [dtypes.int16, dtypes.uint16]: return Device.DEFAULT not in ["WEBGPU", "TORCH"]
-  if dtype == dtypes.uint32: return Device.DEFAULT not in ["TORCH"]
-  if dtype in [dtypes.int64, dtypes.uint64]: return Device.DEFAULT not in ["WEBGPU", "TORCH"]
-  if dtype == dtypes.bool:
-    # host-shareablity is a requirement for storage buffers, but 'bool' type is not host-shareable
-    if Device.DEFAULT == "WEBGPU": return False
+  # TODO: is this correct? it reduces to only GPU on non-OSX
+  if dtype == dtypes.float64: return device not in ["WEBGPU", "METAL"] and (not OSX and device == "GPU")
+  if dtype in [dtypes.int8, dtypes.uint8]: return device not in ["WEBGPU"]
+  if dtype in [dtypes.int16, dtypes.uint16]: return device not in ["WEBGPU", "TORCH"]
+  if dtype == dtypes.uint32: return device not in ["TORCH"]
+  if dtype in [dtypes.int64, dtypes.uint64]: return device not in ["WEBGPU", "TORCH"]
+  # for WEBGPU, host-shareablity is a requirement for storage buffers, but 'bool' type is not host-shareable
+  if dtype == dtypes.bool: return device != "WEBGPU"
   return True
 def get_available_cast_dtypes(dtype: DType) -> List[DType]:
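
The change above makes is_dtype_supported usable for any backend by passing the device explicitly instead of reading Device.DEFAULT in every branch. A minimal usage sketch, not part of the commit: the test.test_dtype import path is an assumption for a tinygrad checkout, and the results depend on the CI and OSX environment flags.

from tinygrad.tensor import dtypes
from test.test_dtype import is_dtype_supported  # assumed import path

print(is_dtype_supported(dtypes.half))                   # queries Device.DEFAULT, as before
print(is_dtype_supported(dtypes.half, device="WEBGPU"))  # False: half is excluded on WEBGPU above
print(is_dtype_supported(dtypes.int64, device="TORCH"))  # False: int64/uint64 are excluded on TORCH above
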
@@ -45,9 +45,12 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target):
   except AssertionError as e:
     raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
-def _test_op(fxn, target_dtype:DType, target): _assert_eq(fxn(), target_dtype, target)
-def _test_cast(a:Tensor, target_dtype:DType): _test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(target_dtype.np)))
-def _test_bitcast(a:Tensor, target_dtype:DType, target=None): _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(target_dtype.np).tolist())
+def _test_op(fxn, target_dtype:DType, target):
+  _assert_eq(fxn(), target_dtype, target)
+def _test_cast(a:Tensor, target_dtype:DType):
+  _test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(target_dtype.np)))
+def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
+  _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(target_dtype.np).tolist())
 class TestDType(unittest.TestCase):
   DTYPE: Any = None
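
For reference, the reformatted helpers are driven the same way as before: run the op, then check both the result dtype and the values against a numpy reference. A short sketch under the same assumed import path, on a backend where int8 and bitcast are supported:

from tinygrad.tensor import Tensor, dtypes
from test.test_dtype import _test_cast, _test_bitcast  # assumed import path

_test_cast(Tensor([1, 2, 3, 4], dtype=dtypes.int8), dtypes.float32)              # values survive the up-cast
_test_bitcast(Tensor([1.0], dtype=dtypes.float32), dtypes.uint32, [0x3f800000])  # same bits, reinterpreted
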
@@ -55,7 +58,9 @@ class TestDType(unittest.TestCase):
   @classmethod
   def setUpClass(cls):
     if not cls.DTYPE or not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported")
-    cls.DATA = np.random.randint(0, 100, size=10, dtype=cls.DTYPE.np).tolist() if dtypes.is_int(cls.DTYPE) else np.random.choice([True, False], size=10).tolist() if cls.DTYPE == dtypes.bool else np.random.uniform(0, 1, size=10).tolist()
+    if dtypes.is_int(cls.DTYPE): cls.DATA = np.random.randint(0, 100, size=10, dtype=cls.DTYPE.np).tolist()
+    elif cls.DTYPE == dtypes.bool: cls.DATA = np.random.choice([True, False], size=10).tolist()
+    else: cls.DATA = np.random.uniform(0, 1, size=10).tolist()
   def setUp(self):
     if self.DTYPE is None: raise unittest.SkipTest("base class")
@@ -88,7 +93,8 @@ class TestDType(unittest.TestCase):
   def test_bitcast(self):
     if self.DTYPE == dtypes.bool: raise unittest.SkipTest("no bools in bitcast")
     list(map(
-      lambda dtype: _test_bitcast(Tensor(self.DATA, dtype=self.DTYPE), dtype) if dtype.itemsize == self.DTYPE.itemsize and dtype != dtypes.bool else None,
+      lambda dtype:
+        _test_bitcast(Tensor(self.DATA, dtype=self.DTYPE), dtype) if dtype.itemsize == self.DTYPE.itemsize and dtype != dtypes.bool else None,
      get_available_cast_dtypes(self.DTYPE)
    ))
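
The itemsize guard above reflects what a bitcast is: the bytes are reinterpreted, not converted, so source and target dtypes must have the same width. The numpy reference used by _test_bitcast behaves the same way:

import numpy as np

a = np.array([1.0, 2.0], dtype=np.float32)
print(a.view(np.uint32))  # [1065353216 1073741824] -- same 4-byte payloads, new type
# dtypes with a different itemsize (and bool) are skipped by the test above
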
@@ -101,9 +107,8 @@ def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
   _assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]])
   _assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, 2*Tensor.ones(4,4).numpy())
+@unittest.skipUnless(Device.DEFAULT in ["LLVM", "TORCH"], "bfloat16 not supported")
 class TestBFloat16DType(unittest.TestCase):
-  def setUp(self):
-    if Device.DEFAULT not in ["LLVM", "TORCH"]: raise unittest.SkipTest("bfloat16 not supported")
   def test_bf16_to_float(self):
     with self.assertRaises(AssertionError):
       _test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32)
@@ -142,12 +147,14 @@ class TestDoubleDtype(TestDType): DTYPE = dtypes.double
 class TestInt8Dtype(TestDType):
   DTYPE = dtypes.int8
   @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently")
-  def test_int8_to_uint8_negative(self): _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252])
+  def test_int8_to_uint8_negative(self):
+    _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252])
 class TestUint8Dtype(TestDType):
   DTYPE = dtypes.uint8
   @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently")
-  def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])
+  def test_uint8_to_int8_overflow(self):
+    _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])
 class TestBitCast(unittest.TestCase):
   def test_shape_change_bitcast(self):
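
The expected values in the int8/uint8 tests above are plain two's-complement wraparound (conversion modulo 256), which is also what the numpy reference produces; the skipIf exists because, per its reason string, cuda saturation works differently. A quick check:

import numpy as np

print(np.array([-1, -2, -3, -4], dtype=np.int8).astype(np.uint8))      # [255 254 253 252]
print(np.array([255, 254, 253, 252], dtype=np.uint8).astype(np.int8))  # [-1 -2 -3 -4]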