Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
Pad2d backward pass on GPU (#89)
* Pad2d backward pass on GPU
* Faster Pad2D GPU backward pass (no zeroing needed)
* Fix out of bounds error
* Don't save prg

Co-authored-by: holonomicjl <58403584+holonomicjl@users.noreply.github.com>
@@ -316,6 +316,7 @@ class Pad2D(Function):
       }
     }
     """)
+    ctx.save_for_backward(padding)
     prg.pad2d(ctx.cl_queue, [bs, cin, iy], None,
       x, ret,
       np.int32(cin), np.int32(padding[0]), np.int32(padding[2]),
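The only functional change in this first hunk is the added ctx.save_for_backward(padding): the forward kernel copies each input row into the interior of the padded output, and the padding sizes have to be stashed so the backward pass can recover the unpadded shape. For reference, a minimal NumPy sketch of the forward computation, assuming the padding tuple is ordered (top, bottom, left, right) as the index arithmetic in this diff suggests (the helper name pad2d_forward_reference is hypothetical, not part of the commit):

import numpy as np

def pad2d_forward_reference(x, padding):
  # x: (bs, cin, iy, ix); padding assumed to be (top, bottom, left, right),
  # matching oy = iy - padding[0] - padding[1] in the backward hunk below.
  top, bottom, left, right = padding
  # zero-pad only the two spatial axes
  return np.pad(x, ((0, 0), (0, 0), (top, bottom), (left, right)))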
@@ -325,7 +326,34 @@ class Pad2D(Function):
 
   @staticmethod
   def backward(ctx, grad_output):
-    raise Exception("write this")
+    padding, = ctx.saved_tensors
+    bs, cin, iy, ix = grad_output.shape
+    oy, ox = iy - padding[0] - padding[1], ix - padding[2] - padding[3]
+    ret = buffer_new(ctx, (bs, cin, oy, ox))
+    prg = clbuild(ctx.cl_ctx, """
+    __kernel void pad2d(
+      __global const float *input, __global float *output,
+      int cin, int py, int px, int oy, int ox, int iy, int ix
+    )
+    {
+      int B = get_global_id(0);
+      int C = get_global_id(1);
+      int Y = get_global_id(2);
+
+      int iptr = B*cin*iy*ix + C*iy*ix + (Y+py)*ix + px;
+      int optr = B*cin*oy*ox + C*oy*ox + Y*ox;
+
+      for (int x = 0; x < ox; x++) {
+        output[optr+x] = input[iptr+x];
+      }
+    }
+    """)
+    prg.pad2d(ctx.cl_queue, [bs, cin, oy], None,
+      grad_output, ret,
+      np.int32(cin), np.int32(padding[0]), np.int32(padding[2]),
+      np.int32(oy), np.int32(ox), np.int32(iy), np.int32(ix)
+    )
+    return ret
 register('pad2d', Pad2D, gpu=True)
 
 class Reshape(Function):
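The new backward kernel inverts the forward copy: one work-item per (batch, channel, row) reads ox contiguous values from grad_output, offset by py rows and px columns, and writes them into the gradient buffer. Since every element of ret is written exactly once, the buffer never needs to be zeroed first. For reference, the same computation in NumPy is just an interior slice, again assuming (top, bottom, left, right) padding order (the helper name pad2d_backward_reference is hypothetical, not part of the commit):

import numpy as np

def pad2d_backward_reference(grad_output, padding):
  # grad_output: (bs, cin, iy, ix) on the padded shape; the gradient w.r.t.
  # the unpadded input is the interior slice of grad_output.
  top, bottom, left, right = padding
  bs, cin, iy, ix = grad_output.shape
  oy, ox = iy - top - bottom, ix - left - right
  return grad_output[:, :, top:top + oy, left:left + ox]

Comparing this slice against the buffer produced by prg.pad2d is a quick way to sanity-check the kernel's index arithmetic.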