function.Div -> function.IDiv [run_process_replay] (#6188)

float div is equivalent to mul a reciprocal
This commit is contained in:
chenyu
2024-08-19 14:59:41 -04:00
committed by GitHub
parent ee5fe12630
commit 705b8066ab
2 changed files with 3 additions and 9 deletions

View File

@@ -131,14 +131,8 @@ class Mul(Function):
return self.y.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \
self.x.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None
class Div(Function):
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
self.x, self.y = x, y
return x.e(BinaryOps.MUL, y.e(UnaryOps.RECIP)) if not dtypes.is_int(x.dtype) else x.e(BinaryOps.IDIV, y)
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
return grad_output.e(BinaryOps.MUL, self.y.e(UnaryOps.RECIP)) if self.needs_input_grad[0] else None, \
grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.MUL, self.y.e(BinaryOps.MUL, self.y).e(UnaryOps.RECIP)) if self.needs_input_grad[1] else None # noqa: E501
class IDiv(Function):
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.IDIV, y)
# ************* ternary ops *************

View File

@@ -2631,7 +2631,7 @@ class Tensor:
"""
numerator, denominator = self._broadcasted(x, reverse)
if upcast: numerator, denominator = numerator.cast(least_upper_float(numerator.dtype)), denominator.cast(least_upper_float(denominator.dtype))
return F.Div.apply(numerator, denominator)
return (numerator * denominator.reciprocal()) if dtypes.is_float(numerator.dtype) else F.IDiv.apply(numerator, denominator)
def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
"""