mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 22:08:08 -05:00
tensor == tensor should be bool (#2916)
* return bool
* add tests to the type spec
* fix multinomial
* fix tril
* fix round
* fix NegativeLogLikelihoodLoss
* rm debug
* webgpu
* more webgpu
* bitwise or for adding two bools
* onnx ops dont need to cast anymore
* Revert "bitwise or for adding two bools"
This reverts commit b413babffa.
* workaround for metal neg
* just the tests in the type spec
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
import operator
|
||||
from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX, least_upper_float, temp, least_upper_dtype
|
||||
from tinygrad import Device
|
||||
from tinygrad.tensor import Tensor, dtypes
|
||||
@@ -9,7 +10,6 @@ from hypothesis import given, settings, strategies as st
|
||||
|
||||
core_dtypes = list(DTYPES_DICT.values())
|
||||
floats = [dt for dt in core_dtypes if dtypes.is_float(dt)]
|
||||
|
||||
def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
|
||||
# for GPU, cl_khr_fp16 isn't supported
|
||||
# for LLVM, it segfaults because it can't link to the casting function
|
||||
@@ -308,6 +308,11 @@ class TestTypeSpec(unittest.TestCase):
|
||||
assert Tensor.arange(3, 9, 0.7).dtype == dtypes.default_float
|
||||
assert Tensor.arange(3, 8.5, 3).dtype == dtypes.default_float
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU doesn't follow the bool ops spec")
|
||||
@given(st.sampled_from(core_dtypes), st.sampled_from([operator.gt, operator.ge, operator.le, operator.lt, operator.eq, operator.ne]))
|
||||
def test_bool_ops(self, dtype, op):
|
||||
assert op(Tensor.rand(4, 4, dtype=dtype), Tensor.rand(4, 4, dtype=dtype)).dtype == dtypes.bool
|
||||
|
||||
@given(st.sampled_from(core_dtypes),
|
||||
st.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), st.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
|
||||
def test_functions_return_index(self, dtype, default_int, default_float):
|
||||
|
||||
Reference in New Issue
Block a user