clean up test_dtype.py (#2827)
make is_dtype_supported a pure function and clean up long lines
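Previously the dtype checks read the global Device.DEFAULT inside the function body; after this change the device is an explicit parameter (defaulting to Device.DEFAULT), so the result depends only on the arguments. A minimal sketch of what that buys, assuming it runs inside test_dtype.py after the diff below is applied, where is_dtype_supported and the file's imports are in scope; the device names are only examples:

# Illustration only: query support for backends other than the one the suite
# is currently running on, without touching Device.DEFAULT.
for dev in ["GPU", "WEBGPU", "TORCH"]:
  print(dev, is_dtype_supported(dtypes.int64, dev))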
@@ -1,4 +1,3 @@
-# ruff: noqa: E501
 import unittest
 import numpy as np
 import torch
@@ -8,19 +7,20 @@ from tinygrad.tensor import Tensor, dtypes
 from typing import Any, List
 from hypothesis import given, settings, strategies as st
 
-def is_dtype_supported(dtype: DType):
-  # for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
+def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
+  # for GPU, cl_khr_fp16 isn't supported
   # for LLVM, it segfaults because it can't link to the casting function
-  if dtype == dtypes.half: return not (CI and Device.DEFAULT in ["GPU", "LLVM"]) and Device.DEFAULT != "WEBGPU" and getenv("CUDACPU") != 1
+  # CUDA in CI uses CUDACPU that does not support half
+  if dtype == dtypes.half: return not (CI and device in ["GPU", "LLVM", "CUDA"]) and device != "WEBGPU"
   if dtype == dtypes.bfloat16: return False # numpy doesn't support bf16, tested separately in TestBFloat16DType
-  if dtype == dtypes.float64: return Device.DEFAULT not in ["WEBGPU", "METAL"] and (not OSX and Device.DEFAULT == "GPU")
-  if dtype in [dtypes.int8, dtypes.uint8]: return Device.DEFAULT not in ["WEBGPU"]
-  if dtype in [dtypes.int16, dtypes.uint16]: return Device.DEFAULT not in ["WEBGPU", "TORCH"]
-  if dtype == dtypes.uint32: return Device.DEFAULT not in ["TORCH"]
-  if dtype in [dtypes.int64, dtypes.uint64]: return Device.DEFAULT not in ["WEBGPU", "TORCH"]
-  if dtype == dtypes.bool:
-    # host-shareablity is a requirement for storage buffers, but 'bool' type is not host-shareable
-    if Device.DEFAULT == "WEBGPU": return False
+  # TODO: is this correct? it reduces to only GPU on non-OSX
+  if dtype == dtypes.float64: return device not in ["WEBGPU", "METAL"] and (not OSX and device == "GPU")
+  if dtype in [dtypes.int8, dtypes.uint8]: return device not in ["WEBGPU"]
+  if dtype in [dtypes.int16, dtypes.uint16]: return device not in ["WEBGPU", "TORCH"]
+  if dtype == dtypes.uint32: return device not in ["TORCH"]
+  if dtype in [dtypes.int64, dtypes.uint64]: return device not in ["WEBGPU", "TORCH"]
+  # for WEBGPU, host-shareablity is a requirement for storage buffers, but 'bool' type is not host-shareable
+  if dtype == dtypes.bool: return device != "WEBGPU"
   return True
 
 def get_available_cast_dtypes(dtype: DType) -> List[DType]:
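With the device made explicit, per-device skips can be written as decorators instead of branching on Device.DEFAULT. A hypothetical example of the pattern (TestHalfOnCUDA is not part of this commit), assuming it sits in this file next to is_dtype_supported:

@unittest.skipUnless(is_dtype_supported(dtypes.half, "CUDA"), "half not supported on the CUDA backend")
class TestHalfOnCUDA(unittest.TestCase):
  def test_placeholder(self): pass  # the decorator, not the body, is the point here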
@@ -45,9 +45,12 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target):
   except AssertionError as e:
     raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
 
-def _test_op(fxn, target_dtype:DType, target): _assert_eq(fxn(), target_dtype, target)
-def _test_cast(a:Tensor, target_dtype:DType): _test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(target_dtype.np)))
-def _test_bitcast(a:Tensor, target_dtype:DType, target=None): _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(target_dtype.np).tolist())
+def _test_op(fxn, target_dtype:DType, target):
+  _assert_eq(fxn(), target_dtype, target)
+def _test_cast(a:Tensor, target_dtype:DType):
+  _test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(target_dtype.np)))
+def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
+  _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(target_dtype.np).tolist())
 
 class TestDType(unittest.TestCase):
   DTYPE: Any = None
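The split helpers make the cast/bitcast distinction easier to see: _test_cast compares against numpy's value-converting astype, while _test_bitcast compares against numpy's byte-reinterpreting view. A standalone numpy illustration of that difference (not part of the commit):

import numpy as np

a = np.array([-1], dtype=np.int8)
print(a.astype(np.uint8))   # cast: value converted (wraps) -> [255]
print(a.view(np.uint8))     # bitcast: the byte 0xff reinterpreted -> [255]

b = np.array([1.0], dtype=np.float32)
print(b.astype(np.int32))   # cast: numeric value kept -> [1]
print(b.view(np.int32))     # bitcast: IEEE-754 bits of 1.0 -> [1065353216]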
@@ -55,7 +58,9 @@ class TestDType(unittest.TestCase):
   @classmethod
   def setUpClass(cls):
     if not cls.DTYPE or not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported")
-    cls.DATA = np.random.randint(0, 100, size=10, dtype=cls.DTYPE.np).tolist() if dtypes.is_int(cls.DTYPE) else np.random.choice([True, False], size=10).tolist() if cls.DTYPE == dtypes.bool else np.random.uniform(0, 1, size=10).tolist()
+    if dtypes.is_int(cls.DTYPE): cls.DATA = np.random.randint(0, 100, size=10, dtype=cls.DTYPE.np).tolist()
+    elif cls.DTYPE == dtypes.bool: cls.DATA = np.random.choice([True, False], size=10).tolist()
+    else: cls.DATA = np.random.uniform(0, 1, size=10).tolist()
   def setUp(self):
     if self.DTYPE is None: raise unittest.SkipTest("base class")
 
@@ -88,7 +93,8 @@ class TestDType(unittest.TestCase):
   def test_bitcast(self):
     if self.DTYPE == dtypes.bool: raise unittest.SkipTest("no bools in bitcast")
     list(map(
-      lambda dtype: _test_bitcast(Tensor(self.DATA, dtype=self.DTYPE), dtype) if dtype.itemsize == self.DTYPE.itemsize and dtype != dtypes.bool else None,
+      lambda dtype:
+        _test_bitcast(Tensor(self.DATA, dtype=self.DTYPE), dtype) if dtype.itemsize == self.DTYPE.itemsize and dtype != dtypes.bool else None,
      get_available_cast_dtypes(self.DTYPE)
     ))
 
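test_bitcast only tries dtypes with the same itemsize because reinterpreting a buffer under a differently sized dtype changes the element count. A quick numpy sketch of why that guard is there (illustration, not commit code):

import numpy as np

a = np.arange(10, dtype=np.float32)  # 10 elements, 40 bytes
print(a.view(np.int32).shape)        # (10,) - same 4-byte itemsize, element count preserved
print(a.view(np.int16).shape)        # (20,) - 2-byte itemsize, element count doubles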
@@ -101,9 +107,8 @@ def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
   _assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]])
   _assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, 2*Tensor.ones(4,4).numpy())
 
+@unittest.skipUnless(Device.DEFAULT in ["LLVM", "TORCH"], "bfloat16 not supported")
 class TestBFloat16DType(unittest.TestCase):
-  def setUp(self):
-    if Device.DEFAULT not in ["LLVM", "TORCH"]: raise unittest.SkipTest("bfloat16 not supported")
   def test_bf16_to_float(self):
     with self.assertRaises(AssertionError):
       _test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32)
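Moving the backend check from setUp into a class-level skipUnless lets unittest mark every test in the class as skipped up front instead of entering each test and raising SkipTest. A minimal standalone example of the same pattern (all names made up):

import unittest

BF16_OK = False  # stand-in for `Device.DEFAULT in ["LLVM", "TORCH"]`

@unittest.skipUnless(BF16_OK, "bfloat16 not supported")
class TestNeedsBF16(unittest.TestCase):
  def test_something(self):
    self.assertTrue(True)  # reported as skipped, never run, when BF16_OK is False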
@@ -142,12 +147,14 @@ class TestDoubleDtype(TestDType): DTYPE = dtypes.double
 class TestInt8Dtype(TestDType):
   DTYPE = dtypes.int8
   @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently")
-  def test_int8_to_uint8_negative(self): _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252])
+  def test_int8_to_uint8_negative(self):
+    _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252])
 
 class TestUint8Dtype(TestDType):
   DTYPE = dtypes.uint8
   @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently")
-  def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])
+  def test_uint8_to_int8_overflow(self):
+    _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])
 
 class TestBitCast(unittest.TestCase):
   def test_shape_change_bitcast(self):
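The expected values in these two tests are ordinary two's-complement wraparound: converting between int8 and uint8 keeps the low 8 bits, so the result is the input taken modulo 256 (CUDA/PTX saturate instead, which is why the skipIf is there). A quick check of the numbers:

# -1 % 256 == 255, -2 % 256 == 254, ...
print([v % 256 for v in (-1, -2, -3, -4)])                          # [255, 254, 253, 252]
# and back: unsigned values >= 128 map to value - 256
print([v - 256 if v >= 128 else v for v in (255, 254, 253, 252)])   # [-1, -2, -3, -4]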