From 92abe436838fa09f29de1bc2faf76a7189415f18 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Thu, 31 Dec 2020 09:49:52 -0500
Subject: [PATCH] reduce before binary because of unbroadcasting

---
 README.md           |  2 +-
 tinygrad/ops_cpu.py | 66 +++++++++++++++++++--------------------
 tinygrad/ops_gpu.py | 76 ++++++++++++++++++++++-----------------------
 tinygrad/tensor.py  |  3 +-
 4 files changed, 73 insertions(+), 74 deletions(-)

diff --git a/README.md b/README.md
index b7a4187244..e7b333c54f 100644
--- a/README.md
+++ b/README.md
@@ -109,8 +109,8 @@ You need to support 14 first class ops:
 
 ```
 Relu, Log, Exp                  # unary ops
-Add, Sub, Mul, Pow              # binary ops (with broadcasting)
 Sum, Max                        # reduce ops (with axis argument)
+Add, Sub, Mul, Pow              # binary ops (with broadcasting)
 Reshape, Transpose, Slice       # movement ops
 Matmul, Conv2D                  # processing ops
 ```
diff --git a/tinygrad/ops_cpu.py b/tinygrad/ops_cpu.py
index f89784e646..5cd2e2daf4 100644
--- a/tinygrad/ops_cpu.py
+++ b/tinygrad/ops_cpu.py
@@ -39,6 +39,39 @@ class Exp(Function):
     ret, = ctx.saved_tensors
     return grad_output * ret
 
+# ************* reduce ops *************
+
+class Sum(Function):
+  @staticmethod
+  def forward(ctx, input, axis=None):
+    ctx.save_for_backward(input, axis)
+    return np.array([input.sum()]) if axis is None else input.sum(axis=axis)
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, axis = ctx.saved_tensors
+    axis = [axis] if type(axis) is int else axis
+    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
+    return grad_output.reshape(shape) + np.zeros_like(input)
+
+class Max(Function):
+  @staticmethod
+  def forward(ctx, inp, axis=None):
+    axis = [axis] if type(axis) == int else axis
+    ret = np.amax(inp, axis=None if axis is None else tuple(axis), keepdims=True)
+    ctx.save_for_backward(inp, axis, ret)
+    if axis is not None:
+      ret = ret.reshape([inp.shape[i] for i in range(len(inp.shape)) if i not in axis])
+    return ret
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, axis, ret = ctx.saved_tensors
+    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
+    ret2 = (input==ret.reshape(shape))
+    div = ret2.sum(axis=None if axis is None else tuple(axis), keepdims=True)
+    return ret2*grad_output.reshape(shape)/div
+
 # ************* binary ops *************
 
 def unbroadcast(out, in_sh):
@@ -91,39 +124,6 @@ class Pow(Function):
     return unbroadcast(y * (x**(y-1.0)) * grad_output, x.shape), \
            unbroadcast((x**y) * np.log(x) * grad_output, y.shape)
 
-# ************* reduce ops *************
-
-class Sum(Function):
-  @staticmethod
-  def forward(ctx, input, axis=None):
-    ctx.save_for_backward(input, axis)
-    return np.array([input.sum()]) if axis is None else input.sum(axis=axis)
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    input, axis = ctx.saved_tensors
-    axis = [axis] if type(axis) is int else axis
-    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
-    return grad_output.reshape(shape) + np.zeros_like(input)
-
-class Max(Function):
-  @staticmethod
-  def forward(ctx, inp, axis=None):
-    axis = [axis] if type(axis) == int else axis
-    ret = np.amax(inp, axis=None if axis is None else tuple(axis), keepdims=True)
-    ctx.save_for_backward(inp, axis, ret)
-    if axis is not None:
-      ret = ret.reshape([inp.shape[i] for i in range(len(inp.shape)) if i not in axis])
-    return ret
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    input, axis, ret = ctx.saved_tensors
-    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
-    ret2 = (input==ret.reshape(shape))
-    div = ret2.sum(axis=None if axis is None else tuple(axis), keepdims=True)
-    return ret2*grad_output.reshape(shape)/div
-
 # ************* movement ops *************
 
 def inner_slice(x, arg):
diff --git a/tinygrad/ops_gpu.py b/tinygrad/ops_gpu.py
index 84766cc15f..022e7dfa4b 100644
--- a/tinygrad/ops_gpu.py
+++ b/tinygrad/ops_gpu.py
@@ -178,6 +178,44 @@ class Exp(Function):
     ret, = ctx.saved_tensors
     return binary_op(ctx, 'a * b', grad_output, ret)
 
+# ************* reduce ops *************
+
+class Sum(Function):
+  @staticmethod
+  def forward(ctx, input, axis=None):
+    axis = [axis] if type(axis) == int else axis
+    ctx.save_for_backward(input, axis)
+    ret = reduce_op(ctx, "out += a", "out", input, axis=axis)
+    if axis is not None:
+      ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
+    return ret
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, axis = ctx.saved_tensors
+    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
+    output = GPUBuffer(shape, hostbuf=grad_output)
+    return binary_op(ctx, 'a+b', output, buffer_new(ctx, input.shape, zero=True))
+
+class Max(Function):
+  @staticmethod
+  def forward(ctx, input, axis=None):
+    axis = [axis] if type(axis) == int else axis
+    ret = reduce_op(ctx, "out = max(a,out)", "out", input, axis=axis)
+    ctx.save_for_backward(input, axis, ret)
+    if axis is not None:
+      ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
+    return ret
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, axis, ret = ctx.saved_tensors
+    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
+    ret2 = binary_op(ctx, "1.0*(a==b)", input, GPUBuffer(shape, ret))
+    div = reduce_op(ctx, "out += a", "out+1e-10", ret2, axis=axis)
+    ret3 = binary_op(ctx, "a/b", ret2, GPUBuffer(shape, div))
+    return binary_op(ctx, 'a*b', ret3, GPUBuffer(shape, grad_output))
+
 # ************* binary ops *************
 
 def unbroadcast(ctx, out, in_sh):
@@ -236,44 +274,6 @@ class Pow(Function):
                       binary_op(ctx, 'pow(a, (float)b) * log(a);', x, y))
     return unbroadcast(ctx, grad_x, x.shape), unbroadcast(ctx, grad_y, y.shape),
 
-# ************* reduce ops *************
-
-class Sum(Function):
-  @staticmethod
-  def forward(ctx, input, axis=None):
-    axis = [axis] if type(axis) == int else axis
-    ctx.save_for_backward(input, axis)
-    ret = reduce_op(ctx, "out += a", "out", input, axis=axis)
-    if axis is not None:
-      ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
-    return ret
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    input, axis = ctx.saved_tensors
-    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
-    output = GPUBuffer(shape, hostbuf=grad_output)
-    return binary_op(ctx, 'a+b', output, buffer_new(ctx, input.shape, zero=True))
-
-class Max(Function):
-  @staticmethod
-  def forward(ctx, input, axis=None):
-    axis = [axis] if type(axis) == int else axis
-    ret = reduce_op(ctx, "out = max(a,out)", "out", input, axis=axis)
-    ctx.save_for_backward(input, axis, ret)
-    if axis is not None:
-      ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
-    return ret
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    input, axis, ret = ctx.saved_tensors
-    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
-    ret2 = binary_op(ctx, "1.0*(a==b)", input, GPUBuffer(shape, ret))
-    div = reduce_op(ctx, "out += a", "out+1e-10", ret2, axis=axis)
-    ret3 = binary_op(ctx, "a/b", ret2, GPUBuffer(shape, div))
-    return binary_op(ctx, 'a*b', ret3, GPUBuffer(shape, grad_output))
-
 # ************* movement ops *************
 
 def inner_slice(ctx, x, arg):
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 6dd73a3dd9..0631381699 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -265,8 +265,7 @@ class Tensor:
     return self._pool2d(*kernel_size).mean(axis=(3,5))
 
   def max_pool2d(self, kernel_size=(2,2)):
-    # TODO: support tuples in max and avoid a copy
-    return self._pool2d(*kernel_size).max(axis=5).max(axis=3)
+    return self._pool2d(*kernel_size).max(axis=(3,5))
 
 # An instantiation of the Function is the Context
 class Function: