From dca5e4fe74275f1ff1bc96de5613e02e717aef35 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Mon, 25 Dec 2023 19:38:47 +0200
Subject: [PATCH] tensor == tensor should be bool (#2916)

* return bool

* add tests to the type spec

* fix multinomial

* fix tril

* fix round

* fix NegativeLogLikelihoodLoss

* rm debug

* webgpu

* more webgpu

* bitwise or for adding two bools

* onnx ops dont need to cast anymore

* Revert "bitwise or for adding two bools"

This reverts commit b413babffa4d93c5cc94a252cb7086b9a899a437.

* workaround for metal neg

* just the tests in the type spec
---
 test/test_dtype.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/test/test_dtype.py b/test/test_dtype.py
index e1d4e3166b..fb50a1713e 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -1,6 +1,7 @@
 import unittest
 import numpy as np
 import torch
+import operator
 from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX, least_upper_float, temp, least_upper_dtype
 from tinygrad import Device
 from tinygrad.tensor import Tensor, dtypes
@@ -9,7 +10,6 @@ from hypothesis import given, settings, strategies as st
 
 core_dtypes = list(DTYPES_DICT.values())
 floats = [dt for dt in core_dtypes if dtypes.is_float(dt)]
-
 def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
   # for GPU, cl_khr_fp16 isn't supported
   # for LLVM, it segfaults because it can't link to the casting function
@@ -308,6 +308,11 @@ class TestTypeSpec(unittest.TestCase):
     assert Tensor.arange(3, 9, 0.7).dtype == dtypes.default_float
     assert Tensor.arange(3, 8.5, 3).dtype == dtypes.default_float
 
+  @unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU doesn't follow the bool ops spec")
+  @given(st.sampled_from(core_dtypes), st.sampled_from([operator.gt, operator.ge, operator.le, operator.lt, operator.eq, operator.ne]))
+  def test_bool_ops(self, dtype, op):
+    assert op(Tensor.rand(4, 4, dtype=dtype), Tensor.rand(4, 4, dtype=dtype)).dtype == dtypes.bool
+
   @given(st.sampled_from(core_dtypes), st.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), st.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
   def test_functions_return_index(self, dtype, default_int, default_float):