diff --git a/test/test_dtype.py b/test/test_dtype.py
index 231dd704c9..051590c24e 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -1,3 +1,4 @@
+import operator
 import unittest
 import numpy as np
 import torch
@@ -7,6 +8,9 @@ from tinygrad.tensor import Tensor, dtypes
 from typing import Any, List
 from hypothesis import given, settings, strategies as st
 
+core_types = list(DTYPES_DICT.values())
+floats = [dt for dt in core_types if dtypes.is_float(dt)]
+
 def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
   # for GPU, cl_khr_fp16 isn't supported
   # for LLVM, it segfaults because it can't link to the casting function
@@ -305,8 +309,10 @@ class TestTypeSpec(unittest.TestCase):
     assert Tensor.arange(3, 9, 0.7).dtype == dtypes.default_float
     assert Tensor.arange(3, 8.5, 3).dtype == dtypes.default_float
 
-core_types = list(DTYPES_DICT.values())
-floats = [dt for dt in core_types if dtypes.is_float(dt)]
+  @given(st.sampled_from(core_types), st.sampled_from([operator.gt, operator.ge, operator.le, operator.lt, operator.eq, operator.ne]))
+  def test_bool_ops(self, dtype, op):
+    assert op(Tensor.rand(4, 4, dtype=dtype), Tensor.rand(4, 4, dtype=dtype)).dtype == dtypes.bool
+
 class TestTypePromotion(unittest.TestCase):
   @given(st.sampled_from(core_types))
   def test_self_promo_to_self(self, dtype):