mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 23:18:04 -05:00
saving 50 LOC with automatic @staticmethod for forward and backward (#252)
* automatic @staticmethod for forward and backward
* triggering unit tests
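To see what the change buys, here is a minimal, self-contained sketch of the pattern (a simplification, not the library's exact code: the Function helpers are condensed from the diff below, and the direct ReLU()/ctx.forward(...) calls at the end are for demonstration only, since in the library these ops run through Tensor's dispatch). The base class's __new__ wraps forward and backward in staticmethod, so every op subclass can drop the decorator:

import numpy as np

# Simplified sketch: the base class converts forward/backward into staticmethods
# when a subclass is instantiated, so subclasses don't need @staticmethod.
class Function:
  def __new__(cls, *args, **kwargs):
    cls.forward = staticmethod(cls.forward)
    cls.backward = staticmethod(cls.backward)
    return super().__new__(cls)

  def __init__(self, *tensors):
    self.parents = tensors
    self.saved_tensors = []

  def save_for_backward(self, *x):
    self.saved_tensors.extend(x)

# An op written without @staticmethod, as in the diff below.
class ReLU(Function):
  def forward(ctx, input):
    ctx.save_for_backward(input)
    return np.maximum(input, 0)

  def backward(ctx, grad_output):
    input, = ctx.saved_tensors
    return grad_output * (input >= 0)

ctx = ReLU()                                   # __new__ runs here and wraps the methods
out = ctx.forward(ctx, np.array([-1.0, 2.0]))  # no implicit self; ctx is passed explicitly
grad = ctx.backward(ctx, np.ones(2))
print(out, grad)                               # [0. 2.] [0. 1.]

Each removed @staticmethod line across the ANE, CPU, and GPU op files in the hunks below adds up to roughly the 50 lines the commit title mentions.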
@@ -28,7 +28,6 @@ def compile_relu(ane, sz):
   return compile_wrapper(ane, bytes(dat))
 
 class ReLU(Function):
-  @staticmethod
   def forward(ctx, input):
     ret = ctx.ane.tensor(input.shape)
     ctx.ane.run(compile_relu(ctx.ane, input.sz), input, ret)
@@ -4,35 +4,29 @@ from .tensor import Function, register
 # ************* unary ops *************
 
 class ReLU(Function):
-  @staticmethod
   def forward(ctx, input):
     ctx.save_for_backward(input)
     return np.maximum(input, 0)
 
-  @staticmethod
   def backward(ctx, grad_output):
     input, = ctx.saved_tensors
     return grad_output * (input >= 0)
 
 class Log(Function):
-  @staticmethod
   def forward(ctx, input):
     ctx.save_for_backward(input)
     return np.log(input)
 
-  @staticmethod
   def backward(ctx, grad_output):
     input, = ctx.saved_tensors
     return grad_output / input
 
 class Exp(Function):
-  @staticmethod
   def forward(ctx, input):
     ret = np.exp(input)
     ctx.save_for_backward(ret)
     return ret
 
-  @staticmethod
   def backward(ctx, grad_output):
     ret, = ctx.saved_tensors
     return grad_output * ret
@@ -40,12 +34,10 @@ class Exp(Function):
 # ************* reduce ops *************
 
 class Sum(Function):
-  @staticmethod
   def forward(ctx, input, axis=None):
     ctx.save_for_backward(input, axis)
     return np.array([input.sum()]) if axis is None else input.sum(axis=axis)
 
-  @staticmethod
   def backward(ctx, grad_output):
     input, axis = ctx.saved_tensors
     axis = [axis] if type(axis) is int else axis
@@ -53,7 +45,6 @@ class Sum(Function):
     return grad_output.reshape(shape) + np.zeros_like(input)
 
 class Max(Function):
-  @staticmethod
   def forward(ctx, inp, axis=None):
     axis = [axis] if type(axis) == int else axis
     ret = np.amax(inp, axis=None if axis is None else tuple(axis), keepdims=True)
@@ -62,7 +53,6 @@ class Max(Function):
       ret = ret.reshape([inp.shape[i] for i in range(len(inp.shape)) if i not in axis])
     return ret
 
-  @staticmethod
   def backward(ctx, grad_output):
     input, axis, ret = ctx.saved_tensors
     shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
@@ -78,45 +68,37 @@ def unbroadcast(out, in_sh):
   return out.sum(axis=sum_axis).reshape(in_sh)
 
 class Add(Function):
-  @staticmethod
   def forward(ctx, x, y):
     ctx.save_for_backward(x.shape, y.shape)
     return x+y
 
-  @staticmethod
   def backward(ctx, grad_output):
     shape_x, shape_y = ctx.saved_tensors
     return unbroadcast(grad_output, shape_x), unbroadcast(grad_output, shape_y)
 
 class Sub(Function):
-  @staticmethod
   def forward(ctx, x, y):
     ctx.save_for_backward(x.shape, y.shape)
     return x-y
 
-  @staticmethod
   def backward(ctx, grad_output):
     shape_x, shape_y = ctx.saved_tensors
     return unbroadcast(grad_output, shape_x), unbroadcast(-grad_output, shape_y)
 
 class Mul(Function):
-  @staticmethod
   def forward(ctx, x, y):
     ctx.save_for_backward(x, y)
     return x*y
 
-  @staticmethod
   def backward(ctx, grad_output):
     x,y = ctx.saved_tensors
     return unbroadcast(y*grad_output, x.shape), unbroadcast(x*grad_output, y.shape)
 
 class Pow(Function):
-  @staticmethod
   def forward(ctx, x, y):
     ctx.save_for_backward(x, y)
     return x ** y
 
-  @staticmethod
   def backward(ctx, grad_output):
     x,y = ctx.saved_tensors
     return unbroadcast(y * (x**(y-1.0)) * grad_output, x.shape), \
@@ -125,23 +107,19 @@ class Pow(Function):
 # ************* movement ops *************
 
 class Reshape(Function):
-  @staticmethod
   def forward(ctx, x, shape):
     ctx.save_for_backward(x.shape)
     return x.reshape(shape)
 
-  @staticmethod
   def backward(ctx, grad_output):
     in_shape, = ctx.saved_tensors
     return grad_output.reshape(in_shape)
 
 class Transpose(Function):
-  @staticmethod
   def forward(ctx, x, order):
     ctx.save_for_backward(order)
     return np.transpose(x, order)
 
-  @staticmethod
   def backward(ctx, x):
     return np.transpose(x, np.argsort(ctx.order))
 
@@ -152,12 +130,10 @@ def inner_slice(x, arg):
   return x[tuple([slice(x[0], x[1], None) for x in slicee])]
 
 class Slice(Function):
-  @staticmethod
   def forward(ctx, x, arg=None):
     ctx.save_for_backward(x.shape)
     return inner_slice(x, arg)
 
-  @staticmethod
   def backward(ctx, grad_output):
     shape, = ctx.saved_tensors
     narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(ctx.arg)]
@@ -166,12 +142,10 @@ class Slice(Function):
 # ************* processing ops *************
 
 class Matmul(Function):
-  @staticmethod
   def forward(ctx, input, weight):
     ctx.save_for_backward(input, weight)
     return input @ weight
 
-  @staticmethod
   def backward(ctx, grad_output):
     input, weight = ctx.saved_tensors
     grad_input = grad_output @ np.swapaxes(weight, -2, -1)
@@ -179,7 +153,6 @@ class Matmul(Function):
     return grad_input, grad_weight
 
 class Conv2D(Function):
-  @staticmethod
   def forward(ctx, x, w, stride=1, groups=1):
     if type(ctx.stride) == int:
       ctx.stride = (ctx.stride, ctx.stride)
@@ -206,7 +179,6 @@ class Conv2D(Function):
       ret[:,g] += np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3)))
     return np.moveaxis(ret,4,2).reshape(bs, cout, oy, ox)
 
-  @staticmethod
   def backward(ctx, grad_output):
     bs,_,oy,ox = grad_output.shape
     tx, tw, x_shape = ctx.saved_tensors
@@ -31,35 +31,29 @@ def unary_op(ctx, code, x):
   return ret
 
 class ReLU(Function):
-  @staticmethod
   def forward(ctx, input):
     ctx.save_for_backward(input)
     return unary_op(ctx, 'max(a, (float)0.)', input)
 
-  @staticmethod
   def backward(ctx, grad_output):
     input, = ctx.saved_tensors
     return binary_op(ctx, 'a * (b >= 0)', grad_output, input)
 
 class Log(Function):
-  @staticmethod
   def forward(ctx, input):
     ctx.save_for_backward(input)
     return unary_op(ctx, 'log(a)', input)
 
-  @staticmethod
   def backward(ctx, grad_output):
     input, = ctx.saved_tensors
     return binary_op(ctx, 'a / b', grad_output, input)
 
 class Exp(Function):
-  @staticmethod
   def forward(ctx, input):
     ret = unary_op(ctx, 'exp(a)', input)
     ctx.save_for_backward(ret)
     return ret
 
-  @staticmethod
   def backward(ctx, grad_output):
     ret, = ctx.saved_tensors
     return binary_op(ctx, 'a * b', grad_output, ret)
@@ -111,7 +105,6 @@ def reduce_op(ctx, code, code2, inp, axis=None, start="0.0"):
   return ret
 
 class Sum(Function):
-  @staticmethod
   def forward(ctx, input, axis=None):
     axis = [axis] if type(axis) == int else axis
     ctx.save_for_backward(input, axis)
@@ -120,7 +113,6 @@ class Sum(Function):
       ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
     return ret
 
-  @staticmethod
   def backward(ctx, grad_output):
     input, axis = ctx.saved_tensors
     shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
@@ -128,7 +120,6 @@ class Sum(Function):
     return binary_op(ctx, 'a+b', output, buffer_new(ctx, input.shape, zero=True))
 
 class Max(Function):
-  @staticmethod
   def forward(ctx, input, axis=None):
     axis = [axis] if type(axis) == int else axis
     ret = reduce_op(ctx, "out = max(a,out)", "out", input, axis=axis, start="-INFINITY")
@@ -137,7 +128,6 @@ class Max(Function):
       ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
     return ret
 
-  @staticmethod
   def backward(ctx, grad_output):
     input, axis, ret = ctx.saved_tensors
     shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
@@ -195,36 +185,30 @@ def unbroadcast(ctx, out, in_sh):
   return reduce_op(ctx, "out += a", "out", out, sum_axis)
 
 class Add(Function):
-  @staticmethod
   def forward(ctx, x, y):
     ctx.save_for_backward(x.shape, y.shape)
     return binary_op(ctx, 'a+b', x, y)
 
-  @staticmethod
   def backward(ctx, grad_output):
     grad_x, grad_y = grad_output, grad_output
     shape_x, shape_y = ctx.saved_tensors
     return unbroadcast(ctx, grad_x, shape_x), unbroadcast(ctx, grad_y, shape_y),
 
 class Sub(Function):
-  @staticmethod
   def forward(ctx, x, y):
     ctx.save_for_backward(x.shape, y.shape)
     return binary_op(ctx, 'a-b', x, y)
 
-  @staticmethod
   def backward(ctx, grad_output):
     grad_x, grad_y = grad_output, unary_op(ctx, '-a', grad_output)
     shape_x, shape_y = ctx.saved_tensors
     return unbroadcast(ctx, grad_x, shape_x), unbroadcast(ctx, grad_y, shape_y),
 
 class Mul(Function):
-  @staticmethod
   def forward(ctx, x, y):
     ctx.save_for_backward(x, y)
     return binary_op(ctx, 'a*b', x, y)
 
-  @staticmethod
   def backward(ctx, grad_output):
     x,y = ctx.saved_tensors
     grad_x = binary_op(ctx, 'a*b', y, grad_output)
@@ -232,12 +216,10 @@ class Mul(Function):
     return unbroadcast(ctx, grad_x, x.shape), unbroadcast(ctx, grad_y, y.shape),
 
 class Pow(Function):
-  @staticmethod
   def forward(ctx, x, y):
     ctx.save_for_backward(x, y)
     return binary_op(ctx, 'pow(a,b)', x, y)
 
-  @staticmethod
   def backward(ctx, grad_output):
     x,y = ctx.saved_tensors
     grad_x = binary_op(ctx, 'a*b', grad_output,
@@ -249,7 +231,6 @@ class Pow(Function):
 # ************* movement ops *************
 
 class Reshape(Function):
-  @staticmethod
   def forward(ctx, x, shape):
     ctx.save_for_backward(x.shape)
     shape = tuple(-np.prod(x.shape) // np.prod(shape) if s == -1 else s for s in shape)
@@ -257,7 +238,6 @@ class Reshape(Function):
     assert np.prod(x.shape) == np.prod(r.shape)
     return r
 
-  @staticmethod
   def backward(ctx, grad_output):
     in_shape, = ctx.saved_tensors
     return GPUBuffer(in_shape, hostbuf=grad_output)
@@ -285,12 +265,10 @@ def perm_axis(ctx, inp, order):
   return ret
 
 class Transpose(Function):
-  @staticmethod
   def forward(ctx, x, order=(1,0)):
     ctx.save_for_backward(order)
     return perm_axis(ctx, x, order)
 
-  @staticmethod
   def backward(ctx, grad_output):
     return perm_axis(ctx, grad_output, np.argsort(ctx.order))
 
@@ -322,12 +300,10 @@ def inner_slice(ctx, x, arg):
   return ret
 
 class Slice(Function):
-  @staticmethod
   def forward(ctx, x, arg=None):
     ctx.save_for_backward(x.shape)
     return inner_slice(ctx, x, arg)
 
-  @staticmethod
   def backward(ctx, grad_output):
     shape, = ctx.saved_tensors
     narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(ctx.arg)]
@@ -336,7 +312,6 @@ class Slice(Function):
 # ************* processing ops *************
 
 class Matmul(Function):
-  @staticmethod
   def forward(ctx, input, weight):
     assert input.shape[-1] == weight.shape[-2]
     cnt = np.prod(input.shape[0:-2]) if len(input.shape) > 2 else 1
@@ -369,7 +344,6 @@ class Matmul(Function):
       msize, i32(1), msize, i32(1), osize, osize)
     return ret
 
-  @staticmethod
   def backward(ctx, grad_output):
     input, weight, matmul, cnt = ctx.saved_tensors
     isize, msize, osize = i32(input.shape[-2]), i32(input.shape[-1]), i32(weight.shape[-1])
@@ -390,7 +364,6 @@ class Matmul(Function):
     return grad_input, grad_weight
 
 class Conv2D(Function):
-  @staticmethod
   def forward(ctx, x, w, stride=1, groups=1):
     if type(ctx.stride) == int:
       ctx.stride = (ctx.stride, ctx.stride)
@@ -443,7 +416,6 @@ class Conv2D(Function):
     )
     return ret
 
-  @staticmethod
   def backward(ctx, grad_output):
     bs,_,oy,ox = grad_output.shape
     x, w = ctx.saved_tensors
@@ -293,6 +293,11 @@ class Tensor:
 
 # An instantiation of the Function is the Context
 class Function:
+  def __new__(cls, *args, **kwargs):
+    cls.forward = staticmethod(cls.forward)
+    cls.backward = staticmethod(cls.backward)
+    return super().__new__(cls) #
+
   def __init__(self, *tensors):
     self.parents = tensors
     self.saved_tensors = []
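One design note on the hunk above: because the wrapping lives in __new__, the conversion runs when an op class is instantiated rather than at class-definition time; re-instantiating is harmless, since reading cls.forward through an already-installed staticmethod descriptor yields the plain function again before it is re-wrapped. An alternative (a sketch only, not what this commit does) would be to hook subclass creation with __init_subclass__ so the wrapping happens once, as each op class is defined:

class Function:
  def __init_subclass__(cls, **kwargs):
    super().__init_subclass__(**kwargs)
    # Wrap once, when the subclass is defined, instead of on instantiation.
    for name in ('forward', 'backward'):
      if name in cls.__dict__ and not isinstance(cls.__dict__[name], staticmethod):
        setattr(cls, name, staticmethod(cls.__dict__[name]))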