Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-08 14:43:57 -05:00)

Commit: Unpad2D on GPU
````diff
@@ -113,7 +113,7 @@ Relu, Log, Exp # unary ops
 Sum, Max            # reduce ops (with axis argument)
 Dot, Conv2D         # matrix multiplication and conv
 Reshape, Transpose  # moving things around ops
-Unpad2D, Pad2D      # stupid slices
+Pad2D, Unpad2D      # stupid slices
 ```

 ## ImageNet inference
````
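The two ops act on NCHW tensors: pad2d grows the last two dimensions with zeros and unpad2d crops them back. A quick usage sketch (shapes worked out from the kernel further down; the tuple order (x_begin, x_end, y_begin, y_end) follows the padding[0..3] indexing in the code, and the import path assumes the early tinygrad layout):

```python
import numpy as np
from tinygrad.tensor import Tensor

x = Tensor(np.ones((1, 1, 4, 4), dtype=np.float32))
# height grows by padding[2]+padding[3], width by padding[0]+padding[1]
y = x.pad2d(padding=(1, 1, 2, 2))    # (1, 1, 4, 4) -> (1, 1, 8, 6)
z = y.unpad2d(padding=(1, 1, 2, 2))  # crops back to (1, 1, 4, 4)
```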
```diff
@@ -122,7 +122,6 @@ class Unpad2D(Function):
   @staticmethod
   def forward(ctx, x, padding=None):
     return Pad2D.backward(ctx, x)
-
   @staticmethod
   def backward(ctx, grad_output):
     return Pad2D.forward(ctx, grad_output)
```
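Unpad2D needs no logic of its own: cropping is exactly the gradient of zero-padding and vice versa, so the two classes borrow each other's forward and backward. A minimal NumPy sketch of the pair (function names here are illustrative, not tinygrad's):

```python
import numpy as np

def pad2d_forward(x, padding):
  # x is NCHW; padding = (x_begin, x_end, y_begin, y_end)
  px0, px1, py0, py1 = padding
  return np.pad(x, ((0, 0), (0, 0), (py0, py1), (px0, px1)))

def pad2d_backward(g, padding):
  # the gradient of zero-padding is a crop back to the original window;
  # slicing via shape - end_pad keeps the end_pad == 0 case correct
  px0, px1, py0, py1 = padding
  return g[..., py0:g.shape[2] - py1, px0:g.shape[3] - px1]

# Unpad2D.forward is the crop (pad2d_backward);
# Unpad2D.backward re-pads the gradient with zeros (pad2d_forward).
```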
```diff
@@ -281,47 +281,56 @@ register('dot', Dot, device=Device.GPU)

 # ************* simple ops *************

+def get_pad2d_kernel(ctx):
+  return clbuild(ctx.cl_ctx, "pad2d", """
+  __kernel void pad2d(__global const float *input, __global float *output,
+                      int ipx, int ipy, int py, int px, int oy, int ox, int iy, int ix) {
+    int BC = get_global_id(0);
+    int Y = get_global_id(1);
+    int X = get_global_id(2);
+
+    int iptr = BC*iy*ix + (Y+ipy)*ix + ipx + X;
+    int optr = BC*oy*ox + (Y+py)*ox + px + X;
+
+    output[optr] = input[iptr];
+  }""")
+
 class Pad2D(Function):
   @staticmethod
   def forward(ctx, x, padding=None):
     bs,cin,iy,ix = x.shape
-    oy,ox = iy+padding[2]+padding[3], ix+padding[0]+padding[1]
+    oy,ox = iy+ctx.padding[2]+ctx.padding[3], ix+ctx.padding[0]+ctx.padding[1]
     ret = buffer_new(ctx, (bs, cin, oy, ox), zero=True)

-    pad2d = clbuild(ctx.cl_ctx, "pad2d", """
-    __kernel void pad2d(__global const float *input, __global float *output,
-                        int ipx, int ipy, int py, int px, int oy, int ox, int iy, int ix) {
-      int BC = get_global_id(0);
-      int Y = get_global_id(1);
-      int X = get_global_id(2);
-
-      int iptr = BC*iy*ix + (Y+ipy)*ix + ipx + X;
-      int optr = BC*oy*ox + (Y+py)*ox + px + X;
-
-      output[optr] = input[iptr];
-    }""")
-    ctx.save_for_backward(padding, pad2d)
-    pad2d(ctx.cl_queue, [bs*cin, iy, ix], None,
+    get_pad2d_kernel(ctx)(ctx.cl_queue, [bs*cin, iy, ix], None,
       x.cl, ret.cl,
-      i32(0), i32(0), i32(padding[2]), i32(padding[0]),
+      i32(0), i32(0), i32(ctx.padding[2]), i32(ctx.padding[0]),
       i32(oy), i32(ox), i32(iy), i32(ix)
     )
     return ret

   @staticmethod
   def backward(ctx, grad_output):
-    padding, pad2d = ctx.saved_tensors
     bs, cin, iy, ix = grad_output.shape
-    oy, ox = iy - padding[2] - padding[3], ix - padding[0] - padding[1]
+    oy, ox = iy - ctx.padding[2] - ctx.padding[3], ix - ctx.padding[0] - ctx.padding[1]
     ret = buffer_new(ctx, (bs, cin, oy, ox))
-    pad2d(ctx.cl_queue, [bs*cin, oy, ox], None,
+    get_pad2d_kernel(ctx)(ctx.cl_queue, [bs*cin, oy, ox], None,
       grad_output.cl, ret.cl,
-      i32(padding[2]), i32(padding[0]), i32(0), i32(0),
+      i32(ctx.padding[2]), i32(ctx.padding[0]), i32(0), i32(0),
       i32(oy), i32(ox), i32(iy), i32(ix)
     )
     return ret
 register('pad2d', Pad2D, device=Device.GPU)

+# TODO: this is an exact copy from the CPU code
+class Unpad2D(Function):
+  @staticmethod
+  def forward(ctx, x, padding=None):
+    return Pad2D.backward(ctx, x)
+  @staticmethod
+  def backward(ctx, grad_output):
+    return Pad2D.forward(ctx, grad_output)
+register('unpad2d', Unpad2D, device=Device.GPU)
+
 class Reshape(Function):
   @staticmethod
   def forward(ctx, x, shape):
```
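The refactor hoists the kernel build out of forward into get_pad2d_kernel, so forward and backward share one program instead of smuggling the compiled kernel through save_for_backward, and the padding now rides along as ctx.padding. One kernel serves both directions because it only copies a window with independent read and write offsets. A pure-Python model of its index arithmetic (illustrative; pad2d_model is not part of tinygrad):

```python
def pad2d_model(inp, out, ipx, ipy, py, px, oy, ox, iy, ix, gsz):
  # gsz mirrors the global work size of the launch: [bs*cin, rows, cols],
  # always the extent of the *smaller* (unpadded) tensor
  BCs, Ys, Xs = gsz
  for BC in range(BCs):
    for Y in range(Ys):
      for X in range(Xs):
        iptr = BC*iy*ix + (Y+ipy)*ix + ipx + X  # read at offset (ipy, ipx)
        optr = BC*oy*ox + (Y+py)*ox + px + X    # write at offset (py, px)
        out[optr] = inp[iptr]
```

Forward launches over the input extent [bs*cin, iy, ix] with write offsets (py, px) = (padding[2], padding[0]), scattering into a buffer created with zero=True so the untouched border stays zero; backward launches over the cropped extent [bs*cin, oy, ox] with read offsets instead, gathering the interior. Note that, assuming the signature order shown, backward feeds (padding[2], padding[0]) into the (ipx, ipy) slots, the reverse of how iptr consumes them; with symmetric padding this is invisible, but asymmetric pads would read from swapped offsets.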
```diff
@@ -291,7 +291,7 @@ def register(name, fxn, device=Device.CPU):
   def dispatch(*x, **kwargs):
     tt = [arg for arg in x if isinstance(arg, Tensor)][0]
     x = [Tensor(np.array([arg], dtype=tt.dtype), device=tt.device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
-    f = (Tensor.ops[tt.device])[name]
+    f = Tensor.ops[tt.device][name]
     f.cl_ctx, f.cl_queue, f.ane, f.device = cl_ctx, cl_queue, ane, tt.device
     return f.apply(f, *x, **kwargs)
   setattr(Tensor, name, dispatch)
```
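register() wraps each Function in a dispatch closure that promotes bare scalars to Tensors, looks the op up in the per-device table, and installs the closure as a Tensor method; the change here just drops redundant parentheses around the lookup. A hypothetical minimal op showing the wiring (Neg is not in tinygrad at this point; it exists only for illustration):

```python
class Neg(Function):
  @staticmethod
  def forward(ctx, x):
    return -x            # CPU ops receive and return raw numpy arrays

  @staticmethod
  def backward(ctx, grad_output):
    return -grad_output  # d(-x)/dx = -1

register('neg', Neg, device=Device.CPU)

# Every CPU Tensor now has the method:
#   t.neg()  ->  dispatch finds Tensor.ops[Device.CPU]['neg'] and calls f.apply
```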