return bool

This commit is contained in:
qazal
2023-12-22 21:44:09 +02:00
parent 2783e1b50d
commit bbaaa1289f

View File

@@ -819,10 +819,10 @@ class Tensor:
# in webgpu bool cannot be used as a storage buffer type
def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False)).cast(dtypes.float if self.device == "WEBGPU" else dtypes.bool)
def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True)).cast(dtypes.float if self.device == "WEBGPU" else dtypes.bool)
def __ge__(self, x) -> Tensor: return 1.0-(self<x)
def __le__(self, x) -> Tensor: return 1.0-(self>x)
def __ge__(self, x) -> Tensor: return -(self<x)
def __le__(self, x) -> Tensor: return -(self>x)
def __ne__(self, x) -> Tensor: return (self<x) + (self>x) # type: ignore[override]
def __eq__(self, x) -> Tensor: return 1.0-(self != x) # type: ignore[override]
def __eq__(self, x) -> Tensor: return -(self != x) # type: ignore[override]
# ***** functional nn ops *****