diff --git a/test/test_dtype.py b/test/test_dtype.py
index e1d4e3166b..fb50a1713e 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -1,6 +1,7 @@
 import unittest
 import numpy as np
 import torch
+import operator
 from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX, least_upper_float, temp, least_upper_dtype
 from tinygrad import Device
 from tinygrad.tensor import Tensor, dtypes
@@ -9,7 +10,6 @@ from hypothesis import given, settings, strategies as st
 
 core_dtypes = list(DTYPES_DICT.values())
 floats = [dt for dt in core_dtypes if dtypes.is_float(dt)]
-
 def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
   # for GPU, cl_khr_fp16 isn't supported
   # for LLVM, it segfaults because it can't link to the casting function
@@ -308,6 +308,11 @@ class TestTypeSpec(unittest.TestCase):
     assert Tensor.arange(3, 9, 0.7).dtype == dtypes.default_float
     assert Tensor.arange(3, 8.5, 3).dtype == dtypes.default_float
 
+  @unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU doesn't follow the bool ops spec")
+  @given(st.sampled_from(core_dtypes), st.sampled_from([operator.gt, operator.ge, operator.le, operator.lt, operator.eq, operator.ne]))
+  def test_bool_ops(self, dtype, op):
+    assert op(Tensor.rand(4, 4, dtype=dtype), Tensor.rand(4, 4, dtype=dtype)).dtype == dtypes.bool
+
   @given(st.sampled_from(core_dtypes), st.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), st.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
   def test_functions_return_index(self, dtype, default_int, default_float):
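
For reference (not part of the patch), a minimal standalone sketch of the property the new test_bool_ops exercises, assuming a tinygrad checkout matching the file above where Tensor.rand accepts a dtype keyword: every elementwise comparison should yield a tensor whose dtype is dtypes.bool, independent of the operand dtype.

  # standalone sketch, not part of the diff; mirrors what test_bool_ops asserts
  import operator
  from tinygrad.tensor import Tensor, dtypes

  a, b = Tensor.rand(4, 4, dtype=dtypes.float32), Tensor.rand(4, 4, dtype=dtypes.float32)
  for op in (operator.gt, operator.ge, operator.le, operator.lt, operator.eq, operator.ne):
    # comparisons produce a bool tensor regardless of the inputs' dtype
    assert op(a, b).dtype == dtypes.bool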