mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
cleanups, remove np
This commit is contained in:
@@ -44,7 +44,7 @@ class Log(Function):
|
||||
|
||||
class Exp(Function):
|
||||
def forward(ctx, input):
|
||||
ret = np.clip(input, -88, 88).exp()
|
||||
ret = input.clip(-88, 88).exp()
|
||||
ctx.save_for_backward(ret)
|
||||
return ret
|
||||
|
||||
@@ -52,7 +52,7 @@ class Exp(Function):
|
||||
ret, = ctx.saved_tensors
|
||||
return grad_output * ret
|
||||
|
||||
# ************* reduce ops *************
|
||||
# ************* reduce ops (with keepdims=True) *************
|
||||
|
||||
class Sum(Function):
|
||||
def forward(ctx, input, axis):
|
||||
@@ -60,7 +60,8 @@ class Sum(Function):
|
||||
return input.sum(axis, keepdims=True)
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output.expand(ctx.saved_tensors[0])
|
||||
shape_input, = ctx.saved_tensors
|
||||
return grad_output.expand(shape_input)
|
||||
|
||||
class Max(Function):
|
||||
def forward(ctx, inp, axis):
|
||||
@@ -74,7 +75,7 @@ class Max(Function):
|
||||
div = ret2.sum(axis=tuple(axis), keepdims=True)
|
||||
return ret2*grad_output/div.type(input.dtype)
|
||||
|
||||
# ************* binary ops *************
|
||||
# ************* binary ops (with broadcasting) *************
|
||||
|
||||
def unbroadcast(out, in_sh):
|
||||
# adjoint operation to broadcast is sum. Need to sum all axis with 1 = in_sh[i] < out.shape[i]
|
||||
|
||||
Reference in New Issue
Block a user