reduce before binary because of unbroadcasting
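
In other words: un-broadcasting a gradient back to its pre-broadcast input shape is itself a reduction, a sum over the broadcast axes, so the reduce ops now sit ahead of the binary ops whose backward passes un-broadcast. A minimal numpy sketch of the idea (`unbroadcast_np` is a hypothetical stand-in, not the helper in this commit):

```
import numpy as np

# hypothetical sketch: un-broadcasting a gradient is a sum over the axes
# that broadcasting expanded, followed by a reshape to the input shape
def unbroadcast_np(out, in_sh):
  lead = len(out.shape) - len(in_sh)  # axes prepended by broadcasting
  sum_axis = tuple(i for i in range(len(out.shape))
                   if i < lead or in_sh[i - lead] == 1)
  return out.sum(axis=sum_axis).reshape(in_sh)

x_shape = (3, 1)
grad = np.ones((2, 3, 4))                   # grad of x broadcast to (2, 3, 4)
print(unbroadcast_np(grad, x_shape).shape)  # (3, 1), each entry summed to 8.0
```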

George Hotz
2020-12-31 09:49:52 -05:00
parent 4291002881
commit 92abe43683
4 changed files with 73 additions and 74 deletions

View File

@@ -109,8 +109,8 @@ You need to support 14 first class ops:
```
Relu, Log, Exp            # unary ops
-Add, Sub, Mul, Pow        # binary ops (with broadcasting)
Sum, Max                  # reduce ops (with axis argument)
+Add, Sub, Mul, Pow        # binary ops (with broadcasting)
Reshape, Transpose, Slice # movement ops
Matmul, Conv2D            # processing ops
```
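
For a rough feel of the reduce ops' axis argument, a sketch assuming the 2020-era tinygrad `Tensor` API, where `.data` is the backing numpy array:

```
import numpy as np
from tinygrad.tensor import Tensor

x = Tensor(np.array([[1., 2.], [3., 4.]], dtype=np.float32))
print(x.sum(axis=1).data)  # [3. 7.]
print(x.max(axis=0).data)  # [3. 4.]
print(x.sum().data)        # [10.], no axis reduces to a single element
```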

View File

@@ -39,6 +39,39 @@ class Exp(Function):
    ret, = ctx.saved_tensors
    return grad_output * ret

+# ************* reduce ops *************
+
+class Sum(Function):
+  @staticmethod
+  def forward(ctx, input, axis=None):
+    ctx.save_for_backward(input, axis)
+    return np.array([input.sum()]) if axis is None else input.sum(axis=axis)
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, axis = ctx.saved_tensors
+    axis = [axis] if type(axis) is int else axis
+    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
+    return grad_output.reshape(shape) + np.zeros_like(input)
+
+class Max(Function):
+  @staticmethod
+  def forward(ctx, inp, axis=None):
+    axis = [axis] if type(axis) == int else axis
+    ret = np.amax(inp, axis=None if axis is None else tuple(axis), keepdims=True)
+    ctx.save_for_backward(inp, axis, ret)
+    if axis is not None:
+      ret = ret.reshape([inp.shape[i] for i in range(len(inp.shape)) if i not in axis])
+    return ret
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, axis, ret = ctx.saved_tensors
+    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
+    ret2 = (input==ret.reshape(shape))
+    div = ret2.sum(axis=None if axis is None else tuple(axis), keepdims=True)
+    return ret2*grad_output.reshape(shape)/div
+
# ************* binary ops *************

def unbroadcast(out, in_sh):
@@ -91,39 +124,6 @@ class Pow(Function):
    return unbroadcast(y * (x**(y-1.0)) * grad_output, x.shape), \
      unbroadcast((x**y) * np.log(x) * grad_output, y.shape)

-# ************* reduce ops *************
-
-class Sum(Function):
-  @staticmethod
-  def forward(ctx, input, axis=None):
-    ctx.save_for_backward(input, axis)
-    return np.array([input.sum()]) if axis is None else input.sum(axis=axis)
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    input, axis = ctx.saved_tensors
-    axis = [axis] if type(axis) is int else axis
-    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
-    return grad_output.reshape(shape) + np.zeros_like(input)
-
-class Max(Function):
-  @staticmethod
-  def forward(ctx, inp, axis=None):
-    axis = [axis] if type(axis) == int else axis
-    ret = np.amax(inp, axis=None if axis is None else tuple(axis), keepdims=True)
-    ctx.save_for_backward(inp, axis, ret)
-    if axis is not None:
-      ret = ret.reshape([inp.shape[i] for i in range(len(inp.shape)) if i not in axis])
-    return ret
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    input, axis, ret = ctx.saved_tensors
-    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
-    ret2 = (input==ret.reshape(shape))
-    div = ret2.sum(axis=None if axis is None else tuple(axis), keepdims=True)
-    return ret2*grad_output.reshape(shape)/div
-
# ************* movement ops *************

def inner_slice(x, arg):
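
A quick numpy sanity check of what the CPU backward passes above compute: `Sum.backward` broadcasts the incoming gradient back over the reduced axes (the `+ np.zeros_like(input)` forces the expansion), while `Max.backward` masks the argmax positions and splits the gradient evenly among ties, which is what the division by `div` does. A sketch, assuming a reduce over `axis=1`:

```
import numpy as np

x = np.array([[1., 5., 5.],
              [2., 0., 3.]])
g = np.array([1., 1.])                 # upstream gradient after reducing axis=1

# Sum.backward: reshape grad to the keepdims shape, then broadcast it
sum_grad = g.reshape(2, 1) + np.zeros_like(x)   # every input element gets 1.0

# Max.backward: mask of max positions, gradient split among ties
ret = x.max(axis=1, keepdims=True)
mask = (x == ret)
max_grad = mask * g.reshape(2, 1) / mask.sum(axis=1, keepdims=True)
print(max_grad)   # row 0 splits its 1.0 as 0.5/0.5 across the two tied fives
```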

View File

@@ -178,6 +178,44 @@ class Exp(Function):
    ret, = ctx.saved_tensors
    return binary_op(ctx, 'a * b', grad_output, ret)

+# ************* reduce ops *************
+
+class Sum(Function):
+  @staticmethod
+  def forward(ctx, input, axis=None):
+    axis = [axis] if type(axis) == int else axis
+    ctx.save_for_backward(input, axis)
+    ret = reduce_op(ctx, "out += a", "out", input, axis=axis)
+    if axis is not None:
+      ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
+    return ret
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, axis = ctx.saved_tensors
+    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
+    output = GPUBuffer(shape, hostbuf=grad_output)
+    return binary_op(ctx, 'a+b', output, buffer_new(ctx, input.shape, zero=True))
+
+class Max(Function):
+  @staticmethod
+  def forward(ctx, input, axis=None):
+    axis = [axis] if type(axis) == int else axis
+    ret = reduce_op(ctx, "out = max(a,out)", "out", input, axis=axis)
+    ctx.save_for_backward(input, axis, ret)
+    if axis is not None:
+      ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
+    return ret
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, axis, ret = ctx.saved_tensors
+    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
+    ret2 = binary_op(ctx, "1.0*(a==b)", input, GPUBuffer(shape, ret))
+    div = reduce_op(ctx, "out += a", "out+1e-10", ret2, axis=axis)
+    ret3 = binary_op(ctx, "a/b", ret2, GPUBuffer(shape, div))
+    return binary_op(ctx, 'a*b', ret3, GPUBuffer(shape, grad_output))
+
# ************* binary ops *************

def unbroadcast(ctx, out, in_sh):
@@ -236,44 +274,6 @@ class Pow(Function):
                     binary_op(ctx, 'pow(a, (float)b) * log(a);', x, y))
    return unbroadcast(ctx, grad_x, x.shape), unbroadcast(ctx, grad_y, y.shape),

-# ************* reduce ops *************
-
-class Sum(Function):
-  @staticmethod
-  def forward(ctx, input, axis=None):
-    axis = [axis] if type(axis) == int else axis
-    ctx.save_for_backward(input, axis)
-    ret = reduce_op(ctx, "out += a", "out", input, axis=axis)
-    if axis is not None:
-      ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
-    return ret
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    input, axis = ctx.saved_tensors
-    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
-    output = GPUBuffer(shape, hostbuf=grad_output)
-    return binary_op(ctx, 'a+b', output, buffer_new(ctx, input.shape, zero=True))
-
-class Max(Function):
-  @staticmethod
-  def forward(ctx, input, axis=None):
-    axis = [axis] if type(axis) == int else axis
-    ret = reduce_op(ctx, "out = max(a,out)", "out", input, axis=axis)
-    ctx.save_for_backward(input, axis, ret)
-    if axis is not None:
-      ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
-    return ret
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    input, axis, ret = ctx.saved_tensors
-    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
-    ret2 = binary_op(ctx, "1.0*(a==b)", input, GPUBuffer(shape, ret))
-    div = reduce_op(ctx, "out += a", "out+1e-10", ret2, axis=axis)
-    ret3 = binary_op(ctx, "a/b", ret2, GPUBuffer(shape, div))
-    return binary_op(ctx, 'a*b', ret3, GPUBuffer(shape, grad_output))
-
# ************* movement ops *************

def inner_slice(ctx, x, arg):
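
The GPU versions funnel through `reduce_op(ctx, accum_code, result_code, input, axis=axis)`, where, judging from the call sites above, the first code fragment folds each element into `out` and the second post-processes the accumulator before it is stored; the `out+1e-10` in `Max.backward` presumably pads the divisor so the following `a/b` cannot hit a 0/0 if the float comparison matches nothing. A rough numpy model of that contract, with hypothetical names:

```
import numpy as np

# hypothetical numpy model of reduce_op's contract, inferred from its call sites
def reduce_op_model(accum, post, inp, axis, init=0.0):
  out = np.full(tuple(s for i, s in enumerate(inp.shape) if i not in axis), init)
  it = np.nditer(inp, flags=['multi_index'])
  for a in it:                       # accum folds each element into out
    idx = tuple(ix for i, ix in enumerate(it.multi_index) if i not in axis)
    out[idx] = accum(out[idx], a)
  return post(out)                   # post-process before storing

x = np.arange(6.).reshape(2, 3)
print(reduce_op_model(lambda out, a: out + a, lambda out: out, x, axis=[1]))
# [ 3. 12.]  -- the "out += a" / "out" sum over axis 1
```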

View File

@@ -265,8 +265,7 @@ class Tensor:
    return self._pool2d(*kernel_size).mean(axis=(3,5))

  def max_pool2d(self, kernel_size=(2,2)):
-    # TODO: support tuples in max and avoid a copy
-    return self._pool2d(*kernel_size).max(axis=5).max(axis=3)
+    return self._pool2d(*kernel_size).max(axis=(3,5))

# An instantiation of the Function is the Context
class Function:
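
The `max_pool2d` change relies on `Max` now accepting a tuple axis, so both kernel dimensions reduce in one call instead of two chained reductions with an intermediate copy. A quick numpy check of the equivalence; the `(N, C, Y, ky, X, kx)` layout for `_pool2d`'s output is an assumption:

```
import numpy as np

x = np.random.rand(1, 3, 4, 2, 4, 2)  # assumed pooled layout: (N, C, Y, ky, X, kx)
chained = x.max(axis=5).max(axis=3)   # old form: two reductions and a copy
fused = x.max(axis=(3, 5))            # new form: one tuple-axis reduction
assert np.allclose(chained, fused)
```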