mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
function.Div -> function.IDiv [run_process_replay] (#6188)
float div is equivalent to mul a reciprocal
This commit is contained in:
@@ -131,14 +131,8 @@ class Mul(Function):
|
||||
return self.y.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \
|
||||
self.x.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None
|
||||
|
||||
class Div(Function):
|
||||
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
|
||||
self.x, self.y = x, y
|
||||
return x.e(BinaryOps.MUL, y.e(UnaryOps.RECIP)) if not dtypes.is_int(x.dtype) else x.e(BinaryOps.IDIV, y)
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
|
||||
return grad_output.e(BinaryOps.MUL, self.y.e(UnaryOps.RECIP)) if self.needs_input_grad[0] else None, \
|
||||
grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.MUL, self.y.e(BinaryOps.MUL, self.y).e(UnaryOps.RECIP)) if self.needs_input_grad[1] else None # noqa: E501
|
||||
class IDiv(Function):
|
||||
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.IDIV, y)
|
||||
|
||||
# ************* ternary ops *************
|
||||
|
||||
|
||||
@@ -2631,7 +2631,7 @@ class Tensor:
|
||||
"""
|
||||
numerator, denominator = self._broadcasted(x, reverse)
|
||||
if upcast: numerator, denominator = numerator.cast(least_upper_float(numerator.dtype)), denominator.cast(least_upper_float(denominator.dtype))
|
||||
return F.Div.apply(numerator, denominator)
|
||||
return (numerator * denominator.reciprocal()) if dtypes.is_float(numerator.dtype) else F.IDiv.apply(numerator, denominator)
|
||||
|
||||
def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user