just cmplt (#1493)

* just cmplt

* fix maximum

* don't save, there's no backward

* ugh, no slot either

* eq is a scam
This commit is contained in:
George Hotz
2023-08-08 13:58:10 -07:00
committed by GitHub
parent e2cf0f322e
commit d24f936501
16 changed files with 26 additions and 42 deletions

View File

@@ -598,9 +598,8 @@ class Tensor:
return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan)
def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
def maximum(self, x:Union[Tensor, float]) -> Tensor: return self._broadcasted(mlops.Maximum, x)
def maximum(self, x:Union[Tensor, float]) -> Tensor: return (self<x).detach().where(x, (self>x).detach().where(self, (self+x)/2))
def minimum(self, x:Union[Tensor, float]) -> Tensor: return -((-self).maximum(-x))
def eq(self, x) -> Tensor: return self._broadcasted(mlops.Equal, x, False)
# ***** broadcasted trinary mlops *****
@@ -651,12 +650,12 @@ class Tensor:
def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
def __ge__(self, x) -> Tensor: return self.maximum(x).eq(self)
def __le__(self, x) -> Tensor: return self.maximum(x).eq(x)
def __lt__(self, x) -> Tensor: return 1.0-(self>=x)
def __gt__(self, x) -> Tensor: return 1.0-(self<=x)
def __eq__(self, x) -> Tensor: return self.eq(x) # type: ignore # mypy things this should be a bool
def __ne__(self, x) -> Tensor: return 1.0-self.eq(x) # type: ignore
def __lt__(self, x) -> Tensor: return self._broadcasted(mlops.Less, x, False)
def __gt__(self, x) -> Tensor: return self._broadcasted(mlops.Less, x, True)
def __ge__(self, x) -> Tensor: return 1.0-(self<x)
def __le__(self, x) -> Tensor: return 1.0-(self>x)
def __ne__(self, x) -> Tensor: return (self<x) + (self>x) # type: ignore
def __eq__(self, x) -> Tensor: return 1.0-(self != x) # type: ignore
# ***** functional nn ops *****