mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
a little faster and cleaner
This commit is contained in:
@@ -48,13 +48,16 @@ class CPUBuffer(np.ndarray):
|
||||
x = x.movement_op(MovementOps.SLICE, ((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
|
||||
gx = x.ravel().reshape(C.bs,C.groups,C.cin,x.shape[2],x.shape[3])
|
||||
tx = np.lib.stride_tricks.as_strided(gx,
|
||||
shape=(C.bs, C.groups, C.cin, C.oy, C.ox, C.H, C.W),
|
||||
strides=(*gx.strides[0:3], gx.strides[3]*C.sy, gx.strides[4]*C.sx, gx.strides[3]*C.dy, gx.strides[4]*C.dx))
|
||||
shape=(C.bs, C.groups, C.cin, C.H, C.W, C.oy, C.ox),
|
||||
strides=(*gx.strides[0:3], gx.strides[3]*C.dy, gx.strides[4]*C.dx, gx.strides[3]*C.sy, gx.strides[4]*C.sx))
|
||||
tw = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W)
|
||||
|
||||
# too bad this doesn't mix with stride_tricks, it can be very slow
|
||||
#out = np.einsum("nGChwHW, GkCHW -> nGkhw", tx, tw)
|
||||
#out = np.einsum("nGCHWhw, GkCHW -> nGkhw", tx, tw)
|
||||
|
||||
tmp = np.empty((C.bs,C.groups,C.oy,C.ox,C.rcout),dtype=x.dtype)
|
||||
for g in range(C.groups): tmp[:,g] = np.tensordot(tx[:,g:g+1], tw[g:g+1], ((1,2,5,6),(0,2,3,4)))
|
||||
return np.moveaxis(tmp,4,2).reshape(C.bs, C.groups*C.rcout, C.oy, C.ox).view(CPUBuffer)
|
||||
# 3 lines is faster than 1
|
||||
tmp = np.empty((C.groups,C.rcout,C.bs,C.oy,C.ox), dtype=x.dtype)
|
||||
for g in range(C.groups): tmp[g] = np.tensordot(tw[g], tx[:,g], ((1,2,3),(1,2,3)))
|
||||
out = np.einsum("Gknhw -> nGkhw", tmp)
|
||||
|
||||
return out.reshape(C.bs, C.groups*C.rcout, C.oy, C.ox).view(CPUBuffer)
|
||||
|
||||
@@ -139,15 +139,13 @@ class GPUBuffer:
|
||||
conv_src = """
|
||||
int B = gid/(groups*rcout); int g = (gid/rcout)%groups; int c = gid % rcout;
|
||||
int Y = get_global_id(1); int X = get_global_id(2); gid = gid*oy*ox + Y*ox + X;
|
||||
for (int ci = 0; ci < cin; ci++) {
|
||||
for (int y = 0; y < H; y++) { for (int x = 0; x < W; x++) {
|
||||
int idx_y = y*dy + Y*sy - py;
|
||||
int idx_x = x*dx + X*sx - px;
|
||||
int valid = (idx_y >= 0 && idx_y < iy && idx_x >= 0 && idx_x < ix);
|
||||
acc += valid * input_g[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + clamp(idx_y, 0, iy-1)*ix + clamp(idx_x, 0, ix-1)] * \
|
||||
weight_g[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
|
||||
} }
|
||||
}"""
|
||||
for (int ci = 0; ci < cin; ci++) { for (int y = 0; y < H; y++) { for (int x = 0; x < W; x++) {
|
||||
int idx_y = y*dy + Y*sy - py;
|
||||
int idx_x = x*dx + X*sx - px;
|
||||
int valid = (idx_y >= 0 && idx_y < iy && idx_x >= 0 && idx_x < ix);
|
||||
acc += valid * input_g[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + clamp(idx_y, 0, iy-1)*ix + clamp(idx_x, 0, ix-1)] * \
|
||||
weight_g[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
|
||||
} } }"""
|
||||
elif ret.shape != bufs[0][1].shape: # this is a reduce
|
||||
# reverse operation of expand, this validates inputs
|
||||
# generate loops with combined adjacent reduce axis
|
||||
|
||||
Reference in New Issue
Block a user