line count

This commit is contained in:
George Hotz
2022-08-30 15:23:35 -07:00
parent 33ac355bcd
commit db56297011

View File

@@ -94,9 +94,8 @@ class Mul(Function):
return x.binary_op(BinaryOps.MUL, y)
def backward(self, grad_output):
grad_x = self.saved_tensors[1].binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None
grad_y = self.saved_tensors[0].binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None
return grad_x, grad_y
return self.saved_tensors[1].binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \
self.saved_tensors[0].binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None
class Pow(Function):
def forward(self, x, y):
@@ -106,14 +105,10 @@ class Pow(Function):
def backward(self, grad_output):
x,y,powxy = self.saved_tensors
grad_x, grad_y = None, None
if self.needs_input_grad[0]:
tmp = y.binary_op(BinaryOps.MUL, powxy.binary_op(BinaryOps.DIV, x)) # y * (pow(x,y)/x)
grad_x = grad_output.binary_op(BinaryOps.MUL, tmp)
if self.needs_input_grad[1]:
tmp = x.unary_op(UnaryOps.LOG).binary_op(BinaryOps.MUL, powxy) # log(x) * pow(x,y)
grad_y = grad_output.binary_op(BinaryOps.MUL, tmp)
return grad_x, grad_y
# grad_x = grad_output * y * (pow(x,y)/x)
# grad_y = grad_output * log(x) * pow(x,y)
return grad_output.binary_op(BinaryOps.MUL, y.binary_op(BinaryOps.MUL, powxy.binary_op(BinaryOps.DIV, x))) if self.needs_input_grad[0] else None, \
grad_output.binary_op(BinaryOps.MUL, x.unary_op(UnaryOps.LOG).binary_op(BinaryOps.MUL, powxy)) if self.needs_input_grad[1] else None
# ************* movement ops *************