less lines (#197)

This commit is contained in:
James Roberts
2020-12-14 23:53:00 +02:00
committed by GitHub
parent b86bbd2e72
commit 78210b5e40
3 changed files with 23 additions and 38 deletions

View File

@@ -93,9 +93,7 @@ class Pad2D(Function):
@staticmethod
def forward(ctx, x, padding=None):
ctx.save_for_backward(padding)
return np.pad(x,
((0,0), (0,0),
(padding[2], padding[3]), (padding[0], padding[1])))
return np.pad(x, ((0,0), (0,0), tuple(padding[2:4]), tuple(padding[0:2])))
@staticmethod
def backward(ctx, grad_output):
@@ -127,8 +125,7 @@ class ReLU(Function):
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
grad_input = grad_output * (input >= 0)
return grad_input
return grad_output * (input >= 0)
register('relu', ReLU)
class Sigmoid(Function):
@@ -146,8 +143,7 @@ class Sigmoid(Function):
@staticmethod
def backward(ctx, grad_output):
ret, = ctx.saved_tensors
grad_input = grad_output * (ret * (1 - ret))
return grad_input
return grad_output * (ret * (1 - ret))
register('sigmoid', Sigmoid)
class LogSoftmax(Function):
@@ -185,12 +181,10 @@ class Conv2D(Function):
gx = x.reshape(bs,ctx.groups,cin,x.shape[2],x.shape[3])
tx = np.lib.stride_tricks.as_strided(gx,
shape=(bs, ctx.groups, cin, oy, ox, H, W),
strides=(gx.strides[0], gx.strides[1], gx.strides[2],
gx.strides[3]*ys, gx.strides[4]*xs,
gx.strides[3], gx.strides[4]),
writeable=False,
)
shape=(bs, ctx.groups, cin, oy, ox, H, W),
strides=(*gx.strides[0:3], gx.strides[3]*ys, gx.strides[4]*xs, *gx.strides[3:5]),
writeable=False,
)
tw = w.reshape(ctx.groups, rcout, cin, H, W)
ctx.save_for_backward(tx, tw, x.shape)
@@ -258,9 +252,7 @@ class MaxPool2D(Function):
@staticmethod
def backward(ctx, grad_output):
idxs,s = ctx.saved_tensors
return unstack_for_pool(
lambda idx: grad_output * (idxs == idx),
s, *ctx.kernel_size)
return unstack_for_pool(lambda idx: grad_output * (idxs == idx), s, *ctx.kernel_size)
register('max_pool2d', MaxPool2D)
class AvgPool2D(Function):
@@ -274,8 +266,6 @@ class AvgPool2D(Function):
def backward(ctx, grad_output):
s, = ctx.saved_tensors
py, px = ctx.kernel_size
return unstack_for_pool(
lambda idx: grad_output/py/px,
s, py, px)
return unstack_for_pool(lambda idx: grad_output/py/px, s, py, px)
register('avg_pool2d', AvgPool2D)

View File

@@ -346,8 +346,7 @@ class Reshape(Function):
@staticmethod
def backward(ctx, grad_output):
in_shape, = ctx.saved_tensors
grad_output = GPUBuffer(in_shape, hostbuf=grad_output)
return grad_output
return GPUBuffer(in_shape, hostbuf=grad_output)
register('reshape', Reshape, device=Tensor.GPU)
# ************* activation ops *************
@@ -449,6 +448,10 @@ class Conv2D(Function):
# output buffer
ret = buffer_new(ctx, (bs, cout, oy, ox))
# input = (bs, groups, cin, iy, ix)
# weight = (groups, rcout, cin, H, W)
# output = (bs, groups, rcout, oy, ox)
conv = clbuild(ctx.cl_ctx, "conv", """
__kernel void conv(__global const float *input, __global const float *weight, __global float *output,
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs) {
@@ -462,9 +465,6 @@ class Conv2D(Function):
int IY = Y*ys;
int IX = X*xs;
// input = (bs, groups, cin, iy, ix)
// weight = (groups, rcout, cin, H, W)
// output = (bs, groups, rcout, oy, ox)
float acc = 0.0;
for (int ci = 0; ci < cin; ci++) {
for (int y = IY; y < IY+H; y++) {

View File

@@ -8,9 +8,8 @@ from collections import defaultdict
DEBUG = os.getenv("DEBUG", None) is not None
if DEBUG:
import collections, atexit, time
debug_counts = collections.defaultdict(int)
debug_times = collections.defaultdict(float)
import atexit, time
debug_counts, debug_times = defaultdict(int), defaultdict(float)
def print_debug_exit():
for name, _ in sorted(debug_times.items(), key=lambda x: -x[1]):
print(f"{name:>20} : {debug_counts[name]:>6} {debug_times[name]:>10.2f} ms")
@@ -88,9 +87,7 @@ class Tensor:
Tensor.did_float_warning = True
self.device = Tensor.CPU
self.data = data
self.grad = None
self.requires_grad = requires_grad
self.data, self.grad, self.requires_grad = data, None, requires_grad
if gpu:
self.cuda_()
@@ -157,12 +154,11 @@ class Tensor:
if len(t0._ctx.parents) == 1:
grads = [grads]
for t,g in zip(t0._ctx.parents, grads):
if g is None:
continue
assert g.shape == t.shape, \
f"grad shape must match tensor shape in {self._ctx!r}, {g.shape!r} != {t.shape!r}"
gt = Tensor(g, requires_grad=False)
t.grad = gt if t.grad is None else (t.grad + gt)
if g is not None:
assert g.shape == t.shape, \
f"grad shape must match tensor shape in {self._ctx!r}, {g.shape!r} != {t.shape!r}"
gt = Tensor(g, requires_grad=False)
t.grad = gt if t.grad is None else (t.grad + gt)
# ***** tinygrad supports CPU and GPU *****
@@ -197,8 +193,7 @@ class Tensor:
if self.grad:
ret.grad = self.grad.cuda()
return ret
else:
return self
return self
def ane(self):
assert(not self.gpu)