Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-08 14:43:57 -05:00
reduce before binary because of unbroadcasting
@@ -109,8 +109,8 @@ You need to support 14 first class ops:
 ```
 Relu, Log, Exp # unary ops
-Add, Sub, Mul, Pow # binary ops (with broadcasting)
 Sum, Max # reduce ops (with axis argument)
+Add, Sub, Mul, Pow # binary ops (with broadcasting)
 Reshape, Transpose, Slice # movement ops
 Matmul, Conv2D # processing ops
 ```
@@ -39,6 +39,39 @@ class Exp(Function):
     ret, = ctx.saved_tensors
     return grad_output * ret
 
+# ************* reduce ops *************
+
+class Sum(Function):
+  @staticmethod
+  def forward(ctx, input, axis=None):
+    ctx.save_for_backward(input, axis)
+    return np.array([input.sum()]) if axis is None else input.sum(axis=axis)
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, axis = ctx.saved_tensors
+    axis = [axis] if type(axis) is int else axis
+    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
+    return grad_output.reshape(shape) + np.zeros_like(input)
+
+class Max(Function):
+  @staticmethod
+  def forward(ctx, inp, axis=None):
+    axis = [axis] if type(axis) == int else axis
+    ret = np.amax(inp, axis=None if axis is None else tuple(axis), keepdims=True)
+    ctx.save_for_backward(inp, axis, ret)
+    if axis is not None:
+      ret = ret.reshape([inp.shape[i] for i in range(len(inp.shape)) if i not in axis])
+    return ret
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, axis, ret = ctx.saved_tensors
+    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
+    ret2 = (input==ret.reshape(shape))
+    div = ret2.sum(axis=None if axis is None else tuple(axis), keepdims=True)
+    return ret2*grad_output.reshape(shape)/div
+
 # ************* binary ops *************
 
 def unbroadcast(out, in_sh):
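Note on the CPU Sum.backward added above: it builds a keepdims-style shape (1 on every reduced axis), reshapes grad_output to it, and lets numpy broadcasting expand the gradient back to the input shape. A standalone numpy sketch of that step (illustrative only, not part of the commit; the values are made up):

```
import numpy as np

x = np.arange(6, dtype=np.float32).reshape(2, 3)
axis = [1]

out = x.sum(axis=tuple(axis))        # forward: shape (2,)
grad_output = np.ones_like(out)      # incoming gradient

# backward, as in Sum.backward: keepdims shape here is [2, 1]
shape = [1 if i in axis else x.shape[i] for i in range(len(x.shape))]
grad_input = grad_output.reshape(shape) + np.zeros_like(x)

assert grad_input.shape == x.shape   # every summed element gets gradient 1
```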
@@ -91,39 +124,6 @@ class Pow(Function):
     return unbroadcast(y * (x**(y-1.0)) * grad_output, x.shape), \
       unbroadcast((x**y) * np.log(x) * grad_output, y.shape)
 
-# ************* reduce ops *************
-
-class Sum(Function):
-  @staticmethod
-  def forward(ctx, input, axis=None):
-    ctx.save_for_backward(input, axis)
-    return np.array([input.sum()]) if axis is None else input.sum(axis=axis)
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    input, axis = ctx.saved_tensors
-    axis = [axis] if type(axis) is int else axis
-    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
-    return grad_output.reshape(shape) + np.zeros_like(input)
-
-class Max(Function):
-  @staticmethod
-  def forward(ctx, inp, axis=None):
-    axis = [axis] if type(axis) == int else axis
-    ret = np.amax(inp, axis=None if axis is None else tuple(axis), keepdims=True)
-    ctx.save_for_backward(inp, axis, ret)
-    if axis is not None:
-      ret = ret.reshape([inp.shape[i] for i in range(len(inp.shape)) if i not in axis])
-    return ret
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    input, axis, ret = ctx.saved_tensors
-    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
-    ret2 = (input==ret.reshape(shape))
-    div = ret2.sum(axis=None if axis is None else tuple(axis), keepdims=True)
-    return ret2*grad_output.reshape(shape)/div
-
 # ************* movement ops *************
 
 def inner_slice(x, arg):
@@ -178,6 +178,44 @@ class Exp(Function):
     ret, = ctx.saved_tensors
     return binary_op(ctx, 'a * b', grad_output, ret)
 
+# ************* reduce ops *************
+
+class Sum(Function):
+  @staticmethod
+  def forward(ctx, input, axis=None):
+    axis = [axis] if type(axis) == int else axis
+    ctx.save_for_backward(input, axis)
+    ret = reduce_op(ctx, "out += a", "out", input, axis=axis)
+    if axis is not None:
+      ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
+    return ret
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, axis = ctx.saved_tensors
+    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
+    output = GPUBuffer(shape, hostbuf=grad_output)
+    return binary_op(ctx, 'a+b', output, buffer_new(ctx, input.shape, zero=True))
+
+class Max(Function):
+  @staticmethod
+  def forward(ctx, input, axis=None):
+    axis = [axis] if type(axis) == int else axis
+    ret = reduce_op(ctx, "out = max(a,out)", "out", input, axis=axis)
+    ctx.save_for_backward(input, axis, ret)
+    if axis is not None:
+      ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
+    return ret
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, axis, ret = ctx.saved_tensors
+    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
+    ret2 = binary_op(ctx, "1.0*(a==b)", input, GPUBuffer(shape, ret))
+    div = reduce_op(ctx, "out += a", "out+1e-10", ret2, axis=axis)
+    ret3 = binary_op(ctx, "a/b", ret2, GPUBuffer(shape, div))
+    return binary_op(ctx, 'a*b', ret3, GPUBuffer(shape, grad_output))
+
 # ************* binary ops *************
 
 def unbroadcast(ctx, out, in_sh):
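Side note on the GPU kernels added above: Sum.backward broadcasts grad_output against a zeroed buffer of the input shape (the GPU analogue of the numpy `+ np.zeros_like(input)` trick), and Max.backward builds a 0/1 mask where the input equals the broadcast max, normalizes by the tie count (the "out+1e-10" seed keeps the divisor strictly positive), and scales by the incoming gradient. A numpy rendering of those three Max.backward kernel steps, for illustration only (the real path goes through the binary_op/reduce_op OpenCL kernels):

```
import numpy as np

x = np.array([[1., 3., 3.],
              [2., 0., 2.]], dtype=np.float32)
axis = (1,)

ret = np.amax(x, axis=axis, keepdims=True)        # forward max, keepdims shape (2, 1)
grad_output = np.ones((2, 1), dtype=np.float32)   # incoming gradient, already in that shape

ret2 = 1.0 * (x == ret)                           # "1.0*(a==b)": mask of max locations
div = ret2.sum(axis=axis, keepdims=True) + 1e-10  # "out += a" reduction seeded with "out+1e-10"
grad_input = (ret2 / div) * grad_output           # "a/b" followed by "a*b"

# ties split the gradient: row 0 has two maxima, so each receives ~0.5
```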
@@ -236,44 +274,6 @@ class Pow(Function):
                        binary_op(ctx, 'pow(a, (float)b) * log(a);', x, y))
     return unbroadcast(ctx, grad_x, x.shape), unbroadcast(ctx, grad_y, y.shape),
 
-# ************* reduce ops *************
-
-class Sum(Function):
-  @staticmethod
-  def forward(ctx, input, axis=None):
-    axis = [axis] if type(axis) == int else axis
-    ctx.save_for_backward(input, axis)
-    ret = reduce_op(ctx, "out += a", "out", input, axis=axis)
-    if axis is not None:
-      ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
-    return ret
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    input, axis = ctx.saved_tensors
-    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
-    output = GPUBuffer(shape, hostbuf=grad_output)
-    return binary_op(ctx, 'a+b', output, buffer_new(ctx, input.shape, zero=True))
-
-class Max(Function):
-  @staticmethod
-  def forward(ctx, input, axis=None):
-    axis = [axis] if type(axis) == int else axis
-    ret = reduce_op(ctx, "out = max(a,out)", "out", input, axis=axis)
-    ctx.save_for_backward(input, axis, ret)
-    if axis is not None:
-      ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
-    return ret
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    input, axis, ret = ctx.saved_tensors
-    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
-    ret2 = binary_op(ctx, "1.0*(a==b)", input, GPUBuffer(shape, ret))
-    div = reduce_op(ctx, "out += a", "out+1e-10", ret2, axis=axis)
-    ret3 = binary_op(ctx, "a/b", ret2, GPUBuffer(shape, div))
-    return binary_op(ctx, 'a*b', ret3, GPUBuffer(shape, grad_output))
-
 # ************* movement ops *************
 
 def inner_slice(ctx, x, arg):
@@ -265,8 +265,7 @@ class Tensor:
     return self._pool2d(*kernel_size).mean(axis=(3,5))
 
   def max_pool2d(self, kernel_size=(2,2)):
-    # TODO: support tuples in max and avoid a copy
-    return self._pool2d(*kernel_size).max(axis=5).max(axis=3)
+    return self._pool2d(*kernel_size).max(axis=(3,5))
 
 # An instantiation of the Function is the Context
 class Function:
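With tuple-axis Max in place, max_pool2d can reduce both pooling axes in one call, matching avg_pool2d's mean(axis=(3,5)). A rough numpy sketch of the shape bookkeeping, assuming _pool2d lays an NCHW tensor out as (N, C, H//py, py, W//px, px) (that layout is an assumption here, not spelled out in this diff):

```
import numpy as np

x = np.arange(1*1*4*4, dtype=np.float32).reshape(1, 1, 4, 4)  # toy NCHW input
py, px = 2, 2

pooled = x.reshape(1, 1, 4//py, py, 4//px, px)  # assumed _pool2d-style 6D layout

old = pooled.max(axis=5).max(axis=3)            # previous two-step reduction
new = pooled.max(axis=(3, 5))                   # new single tuple-axis reduction
assert np.array_equal(old, new)                 # same result, one fewer pass
```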