do replace div->mul for non-floats (#1644)

This commit is contained in:
nimlgen
2023-08-23 17:34:31 +03:00
committed by GitHub
parent da694d4241
commit a65ae1198b
2 changed files with 3 additions and 1 deletions

View File

@@ -588,7 +588,7 @@ class Tensor:
def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Add.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else self
def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Sub.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x or reverse else self
def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Mul.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x != 1.0 else self
def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x else self.mul(1/x)
def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) else self.mul(1/x)
def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor:
if x.__class__ is not Tensor and not reverse:
# simple pow identities