Unpad2D on GPU:

This commit is contained in:
George Hotz
2020-12-29 13:16:14 -05:00
parent 02655c07d5
commit 837aaacfbf
4 changed files with 32 additions and 24 deletions

View File

@@ -113,7 +113,7 @@ Relu, Log, Exp # unary ops
Sum, Max # reduce ops (with axis argument)
Dot, Conv2D # matrix multiplication and conv
Reshape, Transpose # moving things around ops
Unpad2D, Pad2D # stupid slices
Pad2D, Unpad2D # stupid slices
```
## ImageNet inference

View File

@@ -122,7 +122,6 @@ class Unpad2D(Function):
@staticmethod
def forward(ctx, x, padding=None):
return Pad2D.backward(ctx, x)
@staticmethod
def backward(ctx, grad_output):
return Pad2D.forward(ctx, grad_output)

View File

@@ -281,47 +281,56 @@ register('dot', Dot, device=Device.GPU)
# ************* simple ops *************
def get_pad2d_kernel(ctx):
return clbuild(ctx.cl_ctx, "pad2d", """
__kernel void pad2d(__global const float *input, __global float *output,
int ipx, int ipy, int py, int px, int oy, int ox, int iy, int ix) {
int BC = get_global_id(0);
int Y = get_global_id(1);
int X = get_global_id(2);
int iptr = BC*iy*ix + (Y+ipy)*ix + ipx + X;
int optr = BC*oy*ox + (Y+py)*ox + px + X;
output[optr] = input[iptr];
}""")
class Pad2D(Function):
@staticmethod
def forward(ctx, x, padding=None):
bs,cin,iy,ix = x.shape
oy,ox = iy+padding[2]+padding[3], ix+padding[0]+padding[1]
oy,ox = iy+ctx.padding[2]+ctx.padding[3], ix+ctx.padding[0]+ctx.padding[1]
ret = buffer_new(ctx, (bs, cin, oy, ox), zero=True)
pad2d = clbuild(ctx.cl_ctx, "pad2d", """
__kernel void pad2d(__global const float *input, __global float *output,
int ipx, int ipy, int py, int px, int oy, int ox, int iy, int ix) {
int BC = get_global_id(0);
int Y = get_global_id(1);
int X = get_global_id(2);
int iptr = BC*iy*ix + (Y+ipy)*ix + ipx + X;
int optr = BC*oy*ox + (Y+py)*ox + px + X;
output[optr] = input[iptr];
}""")
ctx.save_for_backward(padding, pad2d)
pad2d(ctx.cl_queue, [bs*cin, iy, ix], None,
get_pad2d_kernel(ctx)(ctx.cl_queue, [bs*cin, iy, ix], None,
x.cl, ret.cl,
i32(0), i32(0), i32(padding[2]), i32(padding[0]),
i32(0), i32(0), i32(ctx.padding[2]), i32(ctx.padding[0]),
i32(oy), i32(ox), i32(iy), i32(ix)
)
return ret
@staticmethod
def backward(ctx, grad_output):
padding, pad2d = ctx.saved_tensors
bs, cin, iy, ix = grad_output.shape
oy, ox = iy - padding[2] - padding[3], ix - padding[0] - padding[1]
oy, ox = iy - ctx.padding[2] - ctx.padding[3], ix - ctx.padding[0] - ctx.padding[1]
ret = buffer_new(ctx, (bs, cin, oy, ox))
pad2d(ctx.cl_queue, [bs*cin, oy, ox], None,
get_pad2d_kernel(ctx)(ctx.cl_queue, [bs*cin, oy, ox], None,
grad_output.cl, ret.cl,
i32(padding[2]), i32(padding[0]), i32(0), i32(0),
i32(ctx.padding[2]), i32(ctx.padding[0]), i32(0), i32(0),
i32(oy), i32(ox), i32(iy), i32(ix)
)
return ret
register('pad2d', Pad2D, device=Device.GPU)
# TODO: this is an exact copy from the CPU code
class Unpad2D(Function):
@staticmethod
def forward(ctx, x, padding=None):
return Pad2D.backward(ctx, x)
@staticmethod
def backward(ctx, grad_output):
return Pad2D.forward(ctx, grad_output)
register('unpad2d', Unpad2D, device=Device.GPU)
class Reshape(Function):
@staticmethod
def forward(ctx, x, shape):

View File

@@ -291,7 +291,7 @@ def register(name, fxn, device=Device.CPU):
def dispatch(*x, **kwargs):
tt = [arg for arg in x if isinstance(arg, Tensor)][0]
x = [Tensor(np.array([arg], dtype=tt.dtype), device=tt.device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
f = (Tensor.ops[tt.device])[name]
f = Tensor.ops[tt.device][name]
f.cl_ctx, f.cl_queue, f.ane, f.device = cl_ctx, cl_queue, ane, tt.device
return f.apply(f, *x, **kwargs)
setattr(Tensor, name, dispatch)