mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-18 10:31:41 -05:00
just cmplt (#1493)
* just cmplt * fix maximum * don't save, there's no backward * ugh, no slot either * eq is a scam
This commit is contained in:
@@ -598,9 +598,8 @@ class Tensor:
|
||||
return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan)
|
||||
def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
|
||||
|
||||
def maximum(self, x:Union[Tensor, float]) -> Tensor: return self._broadcasted(mlops.Maximum, x)
|
||||
def maximum(self, x:Union[Tensor, float]) -> Tensor: return (self<x).detach().where(x, (self>x).detach().where(self, (self+x)/2))
|
||||
def minimum(self, x:Union[Tensor, float]) -> Tensor: return -((-self).maximum(-x))
|
||||
def eq(self, x) -> Tensor: return self._broadcasted(mlops.Equal, x, False)
|
||||
|
||||
# ***** broadcasted trinary mlops *****
|
||||
|
||||
@@ -651,12 +650,12 @@ class Tensor:
|
||||
def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
|
||||
def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
|
||||
|
||||
def __ge__(self, x) -> Tensor: return self.maximum(x).eq(self)
|
||||
def __le__(self, x) -> Tensor: return self.maximum(x).eq(x)
|
||||
def __lt__(self, x) -> Tensor: return 1.0-(self>=x)
|
||||
def __gt__(self, x) -> Tensor: return 1.0-(self<=x)
|
||||
def __eq__(self, x) -> Tensor: return self.eq(x) # type: ignore # mypy things this should be a bool
|
||||
def __ne__(self, x) -> Tensor: return 1.0-self.eq(x) # type: ignore
|
||||
def __lt__(self, x) -> Tensor: return self._broadcasted(mlops.Less, x, False)
|
||||
def __gt__(self, x) -> Tensor: return self._broadcasted(mlops.Less, x, True)
|
||||
def __ge__(self, x) -> Tensor: return 1.0-(self<x)
|
||||
def __le__(self, x) -> Tensor: return 1.0-(self>x)
|
||||
def __ne__(self, x) -> Tensor: return (self<x) + (self>x) # type: ignore
|
||||
def __eq__(self, x) -> Tensor: return 1.0-(self != x) # type: ignore
|
||||
|
||||
# ***** functional nn ops *****
|
||||
|
||||
|
||||
Reference in New Issue
Block a user