mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
line count
This commit is contained in:
@@ -94,9 +94,8 @@ class Mul(Function):
|
||||
return x.binary_op(BinaryOps.MUL, y)
|
||||
|
||||
def backward(self, grad_output):
|
||||
grad_x = self.saved_tensors[1].binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None
|
||||
grad_y = self.saved_tensors[0].binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None
|
||||
return grad_x, grad_y
|
||||
return self.saved_tensors[1].binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \
|
||||
self.saved_tensors[0].binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None
|
||||
|
||||
class Pow(Function):
|
||||
def forward(self, x, y):
|
||||
@@ -106,14 +105,10 @@ class Pow(Function):
|
||||
|
||||
def backward(self, grad_output):
|
||||
x,y,powxy = self.saved_tensors
|
||||
grad_x, grad_y = None, None
|
||||
if self.needs_input_grad[0]:
|
||||
tmp = y.binary_op(BinaryOps.MUL, powxy.binary_op(BinaryOps.DIV, x)) # y * (pow(x,y)/x)
|
||||
grad_x = grad_output.binary_op(BinaryOps.MUL, tmp)
|
||||
if self.needs_input_grad[1]:
|
||||
tmp = x.unary_op(UnaryOps.LOG).binary_op(BinaryOps.MUL, powxy) # log(x) * pow(x,y)
|
||||
grad_y = grad_output.binary_op(BinaryOps.MUL, tmp)
|
||||
return grad_x, grad_y
|
||||
# grad_x = grad_output * y * (pow(x,y)/x)
|
||||
# grad_y = grad_output * log(x) * pow(x,y)
|
||||
return grad_output.binary_op(BinaryOps.MUL, y.binary_op(BinaryOps.MUL, powxy.binary_op(BinaryOps.DIV, x))) if self.needs_input_grad[0] else None, \
|
||||
grad_output.binary_op(BinaryOps.MUL, x.unary_op(UnaryOps.LOG).binary_op(BinaryOps.MUL, powxy)) if self.needs_input_grad[1] else None
|
||||
|
||||
# ************* movement ops *************
|
||||
|
||||
|
||||
Reference in New Issue
Block a user