saving 50 LOC with automatic @staticmethod for forward and backward (#252)

* automatic @staticmethod for forward and backward

* triggering unit tests
Author: ziofil
Date: 2021-04-25 21:04:16 -04:00
Committed by: GitHub
Parent: f0cc2b66f8
Commit: 155ec1f18e
4 changed files with 5 additions and 57 deletions
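The mechanism: the new Function.__new__ hook in the last file below wraps forward and backward in staticmethod when a Function subclass is instantiated, so every op class can drop its explicit decorator. An illustrative before/after for one CPU op, excerpted from the diff (not a standalone file; Function and np come from the surrounding module):

# before: every op repeats the decorator
class ReLU(Function):
  @staticmethod
  def forward(ctx, input):
    ctx.save_for_backward(input)
    return np.maximum(input, 0)

# after: no decorator needed; Function.__new__ applies it automatically
class ReLU(Function):
  def forward(ctx, input):
    ctx.save_for_backward(input)
    return np.maximum(input, 0)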

View File

@@ -28,7 +28,6 @@ def compile_relu(ane, sz):
   return compile_wrapper(ane, bytes(dat))
 class ReLU(Function):
-  @staticmethod
   def forward(ctx, input):
     ret = ctx.ane.tensor(input.shape)
     ctx.ane.run(compile_relu(ctx.ane, input.sz), input, ret)

View File

@@ -4,35 +4,29 @@ from .tensor import Function, register
 # ************* unary ops *************
 class ReLU(Function):
-  @staticmethod
   def forward(ctx, input):
     ctx.save_for_backward(input)
     return np.maximum(input, 0)
-  @staticmethod
   def backward(ctx, grad_output):
     input, = ctx.saved_tensors
     return grad_output * (input >= 0)
 class Log(Function):
-  @staticmethod
   def forward(ctx, input):
     ctx.save_for_backward(input)
     return np.log(input)
-  @staticmethod
   def backward(ctx, grad_output):
     input, = ctx.saved_tensors
     return grad_output / input
 class Exp(Function):
-  @staticmethod
   def forward(ctx, input):
     ret = np.exp(input)
     ctx.save_for_backward(ret)
     return ret
-  @staticmethod
   def backward(ctx, grad_output):
     ret, = ctx.saved_tensors
     return grad_output * ret
@@ -40,12 +34,10 @@ class Exp(Function):
 # ************* reduce ops *************
 class Sum(Function):
-  @staticmethod
   def forward(ctx, input, axis=None):
     ctx.save_for_backward(input, axis)
     return np.array([input.sum()]) if axis is None else input.sum(axis=axis)
-  @staticmethod
   def backward(ctx, grad_output):
     input, axis = ctx.saved_tensors
     axis = [axis] if type(axis) is int else axis
@@ -53,7 +45,6 @@ class Sum(Function):
     return grad_output.reshape(shape) + np.zeros_like(input)
 class Max(Function):
-  @staticmethod
   def forward(ctx, inp, axis=None):
     axis = [axis] if type(axis) == int else axis
     ret = np.amax(inp, axis=None if axis is None else tuple(axis), keepdims=True)
@@ -62,7 +53,6 @@ class Max(Function):
       ret = ret.reshape([inp.shape[i] for i in range(len(inp.shape)) if i not in axis])
     return ret
-  @staticmethod
   def backward(ctx, grad_output):
     input, axis, ret = ctx.saved_tensors
     shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
@@ -78,45 +68,37 @@ def unbroadcast(out, in_sh):
   return out.sum(axis=sum_axis).reshape(in_sh)
 class Add(Function):
-  @staticmethod
   def forward(ctx, x, y):
     ctx.save_for_backward(x.shape, y.shape)
     return x+y
-  @staticmethod
   def backward(ctx, grad_output):
     shape_x, shape_y = ctx.saved_tensors
     return unbroadcast(grad_output, shape_x), unbroadcast(grad_output, shape_y)
 class Sub(Function):
-  @staticmethod
   def forward(ctx, x, y):
     ctx.save_for_backward(x.shape, y.shape)
     return x-y
-  @staticmethod
   def backward(ctx, grad_output):
     shape_x, shape_y = ctx.saved_tensors
     return unbroadcast(grad_output, shape_x), unbroadcast(-grad_output, shape_y)
 class Mul(Function):
-  @staticmethod
   def forward(ctx, x, y):
     ctx.save_for_backward(x, y)
     return x*y
-  @staticmethod
   def backward(ctx, grad_output):
     x,y = ctx.saved_tensors
     return unbroadcast(y*grad_output, x.shape), unbroadcast(x*grad_output, y.shape)
 class Pow(Function):
-  @staticmethod
   def forward(ctx, x, y):
     ctx.save_for_backward(x, y)
     return x ** y
-  @staticmethod
   def backward(ctx, grad_output):
     x,y = ctx.saved_tensors
     return unbroadcast(y * (x**(y-1.0)) * grad_output, x.shape), \
@@ -125,23 +107,19 @@ class Pow(Function):
 # ************* movement ops *************
 class Reshape(Function):
-  @staticmethod
   def forward(ctx, x, shape):
     ctx.save_for_backward(x.shape)
     return x.reshape(shape)
-  @staticmethod
   def backward(ctx, grad_output):
     in_shape, = ctx.saved_tensors
     return grad_output.reshape(in_shape)
 class Transpose(Function):
-  @staticmethod
   def forward(ctx, x, order):
     ctx.save_for_backward(order)
     return np.transpose(x, order)
-  @staticmethod
   def backward(ctx, x):
     return np.transpose(x, np.argsort(ctx.order))
@@ -152,12 +130,10 @@ def inner_slice(x, arg):
   return x[tuple([slice(x[0], x[1], None) for x in slicee])]
 class Slice(Function):
-  @staticmethod
   def forward(ctx, x, arg=None):
     ctx.save_for_backward(x.shape)
     return inner_slice(x, arg)
-  @staticmethod
   def backward(ctx, grad_output):
     shape, = ctx.saved_tensors
     narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(ctx.arg)]
@@ -166,12 +142,10 @@ class Slice(Function):
 # ************* processing ops *************
 class Matmul(Function):
-  @staticmethod
   def forward(ctx, input, weight):
     ctx.save_for_backward(input, weight)
     return input @ weight
-  @staticmethod
   def backward(ctx, grad_output):
     input, weight = ctx.saved_tensors
     grad_input = grad_output @ np.swapaxes(weight, -2, -1)
@@ -179,7 +153,6 @@ class Matmul(Function):
     return grad_input, grad_weight
 class Conv2D(Function):
-  @staticmethod
   def forward(ctx, x, w, stride=1, groups=1):
     if type(ctx.stride) == int:
       ctx.stride = (ctx.stride, ctx.stride)
@@ -206,7 +179,6 @@ class Conv2D(Function):
       ret[:,g] += np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3)))
     return np.moveaxis(ret,4,2).reshape(bs, cout, oy, ox)
-  @staticmethod
   def backward(ctx, grad_output):
     bs,_,oy,ox = grad_output.shape
     tx, tw, x_shape = ctx.saved_tensors

View File

@@ -31,35 +31,29 @@ def unary_op(ctx, code, x):
   return ret
 class ReLU(Function):
-  @staticmethod
   def forward(ctx, input):
     ctx.save_for_backward(input)
     return unary_op(ctx, 'max(a, (float)0.)', input)
-  @staticmethod
   def backward(ctx, grad_output):
     input, = ctx.saved_tensors
     return binary_op(ctx, 'a * (b >= 0)', grad_output, input)
 class Log(Function):
-  @staticmethod
   def forward(ctx, input):
     ctx.save_for_backward(input)
     return unary_op(ctx, 'log(a)', input)
-  @staticmethod
   def backward(ctx, grad_output):
     input, = ctx.saved_tensors
     return binary_op(ctx, 'a / b', grad_output, input)
 class Exp(Function):
-  @staticmethod
   def forward(ctx, input):
     ret = unary_op(ctx, 'exp(a)', input)
     ctx.save_for_backward(ret)
     return ret
-  @staticmethod
   def backward(ctx, grad_output):
     ret, = ctx.saved_tensors
     return binary_op(ctx, 'a * b', grad_output, ret)
@@ -111,7 +105,6 @@ def reduce_op(ctx, code, code2, inp, axis=None, start="0.0"):
   return ret
 class Sum(Function):
-  @staticmethod
   def forward(ctx, input, axis=None):
     axis = [axis] if type(axis) == int else axis
     ctx.save_for_backward(input, axis)
@@ -120,7 +113,6 @@ class Sum(Function):
       ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
     return ret
-  @staticmethod
   def backward(ctx, grad_output):
     input, axis = ctx.saved_tensors
     shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
@@ -128,7 +120,6 @@ class Sum(Function):
     return binary_op(ctx, 'a+b', output, buffer_new(ctx, input.shape, zero=True))
 class Max(Function):
-  @staticmethod
   def forward(ctx, input, axis=None):
     axis = [axis] if type(axis) == int else axis
     ret = reduce_op(ctx, "out = max(a,out)", "out", input, axis=axis, start="-INFINITY")
@@ -137,7 +128,6 @@ class Max(Function):
       ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
     return ret
-  @staticmethod
   def backward(ctx, grad_output):
     input, axis, ret = ctx.saved_tensors
     shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
@@ -195,36 +185,30 @@ def unbroadcast(ctx, out, in_sh):
   return reduce_op(ctx, "out += a", "out", out, sum_axis)
 class Add(Function):
-  @staticmethod
   def forward(ctx, x, y):
     ctx.save_for_backward(x.shape, y.shape)
     return binary_op(ctx, 'a+b', x, y)
-  @staticmethod
   def backward(ctx, grad_output):
     grad_x, grad_y = grad_output, grad_output
     shape_x, shape_y = ctx.saved_tensors
     return unbroadcast(ctx, grad_x, shape_x), unbroadcast(ctx, grad_y, shape_y),
 class Sub(Function):
-  @staticmethod
   def forward(ctx, x, y):
     ctx.save_for_backward(x.shape, y.shape)
     return binary_op(ctx, 'a-b', x, y)
-  @staticmethod
   def backward(ctx, grad_output):
     grad_x, grad_y = grad_output, unary_op(ctx, '-a', grad_output)
     shape_x, shape_y = ctx.saved_tensors
     return unbroadcast(ctx, grad_x, shape_x), unbroadcast(ctx, grad_y, shape_y),
 class Mul(Function):
-  @staticmethod
   def forward(ctx, x, y):
     ctx.save_for_backward(x, y)
     return binary_op(ctx, 'a*b', x, y)
-  @staticmethod
   def backward(ctx, grad_output):
     x,y = ctx.saved_tensors
     grad_x = binary_op(ctx, 'a*b', y, grad_output)
@@ -232,12 +216,10 @@ class Mul(Function):
     return unbroadcast(ctx, grad_x, x.shape), unbroadcast(ctx, grad_y, y.shape),
 class Pow(Function):
-  @staticmethod
   def forward(ctx, x, y):
     ctx.save_for_backward(x, y)
     return binary_op(ctx, 'pow(a,b)', x, y)
-  @staticmethod
   def backward(ctx, grad_output):
     x,y = ctx.saved_tensors
     grad_x = binary_op(ctx, 'a*b', grad_output,
@@ -249,7 +231,6 @@ class Pow(Function):
 # ************* movement ops *************
 class Reshape(Function):
-  @staticmethod
   def forward(ctx, x, shape):
     ctx.save_for_backward(x.shape)
     shape = tuple(-np.prod(x.shape) // np.prod(shape) if s == -1 else s for s in shape)
@@ -257,7 +238,6 @@ class Reshape(Function):
     assert np.prod(x.shape) == np.prod(r.shape)
     return r
-  @staticmethod
   def backward(ctx, grad_output):
     in_shape, = ctx.saved_tensors
     return GPUBuffer(in_shape, hostbuf=grad_output)
@@ -285,12 +265,10 @@ def perm_axis(ctx, inp, order):
   return ret
 class Transpose(Function):
-  @staticmethod
   def forward(ctx, x, order=(1,0)):
     ctx.save_for_backward(order)
     return perm_axis(ctx, x, order)
-  @staticmethod
   def backward(ctx, grad_output):
     return perm_axis(ctx, grad_output, np.argsort(ctx.order))
@@ -322,12 +300,10 @@ def inner_slice(ctx, x, arg):
   return ret
 class Slice(Function):
-  @staticmethod
   def forward(ctx, x, arg=None):
     ctx.save_for_backward(x.shape)
     return inner_slice(ctx, x, arg)
-  @staticmethod
   def backward(ctx, grad_output):
     shape, = ctx.saved_tensors
     narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(ctx.arg)]
@@ -336,7 +312,6 @@ class Slice(Function):
 # ************* processing ops *************
 class Matmul(Function):
-  @staticmethod
   def forward(ctx, input, weight):
     assert input.shape[-1] == weight.shape[-2]
     cnt = np.prod(input.shape[0:-2]) if len(input.shape) > 2 else 1
@@ -369,7 +344,6 @@ class Matmul(Function):
       msize, i32(1), msize, i32(1), osize, osize)
     return ret
-  @staticmethod
   def backward(ctx, grad_output):
     input, weight, matmul, cnt = ctx.saved_tensors
     isize, msize, osize = i32(input.shape[-2]), i32(input.shape[-1]), i32(weight.shape[-1])
@@ -390,7 +364,6 @@ class Matmul(Function):
     return grad_input, grad_weight
 class Conv2D(Function):
-  @staticmethod
   def forward(ctx, x, w, stride=1, groups=1):
     if type(ctx.stride) == int:
       ctx.stride = (ctx.stride, ctx.stride)
@@ -443,7 +416,6 @@ class Conv2D(Function):
     )
     return ret
-  @staticmethod
   def backward(ctx, grad_output):
     bs,_,oy,ox = grad_output.shape
     x, w = ctx.saved_tensors

View File

@@ -293,6 +293,11 @@ class Tensor:
 # An instantiation of the Function is the Context
 class Function:
+  def __new__(cls, *args, **kwargs):
+    cls.forward = staticmethod(cls.forward)
+    cls.backward = staticmethod(cls.backward)
+    return super().__new__(cls)
   def __init__(self, *tensors):
     self.parents = tensors
     self.saved_tensors = []
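A minimal self-contained sketch of why this works (toy classes, not the tinygrad API: the bare-bones Function, the Mul op, and the example values below are illustrative only). Rewrapping cls.forward and cls.backward as staticmethods in __new__ keeps both class-level calls (op.forward(ctx, ...)) and instance-level calls (ctx.backward(ctx, ...)) unbound, so the ctx argument is never shadowed by an implicit self:

class Function:
  def __new__(cls, *args, **kwargs):
    # rewrap forward/backward as staticmethods when a subclass is instantiated;
    # rewrapping on later instantiations is a harmless no-op
    cls.forward = staticmethod(cls.forward)
    cls.backward = staticmethod(cls.backward)
    return super().__new__(cls)

  def __init__(self, *tensors):
    self.parents = tensors
    self.saved_tensors = []

  def save_for_backward(self, *x):
    self.saved_tensors.extend(x)

class Mul(Function):
  # no @staticmethod needed anymore
  def forward(ctx, x, y):
    ctx.save_for_backward(x, y)
    return x * y

  def backward(ctx, grad_output):
    x, y = ctx.saved_tensors
    return y * grad_output, x * grad_output

ctx = Mul()                        # __new__ wraps forward/backward here
print(Mul.forward(ctx, 3.0, 4.0))  # class-level call -> 12.0
print(ctx.backward(ctx, 1.0))      # instance-level call stays unbound -> (4.0, 3.0)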