no einsum for now

This commit is contained in:
George Hotz
2022-07-09 00:04:40 -07:00
parent c39a245696
commit 0a36475700

View File

@@ -51,4 +51,10 @@ class CPUBuffer(np.ndarray):
shape=(C.bs, C.groups, C.cin, C.oy, C.ox, C.H, C.W),
strides=(*gx.strides[0:3], gx.strides[3]*C.sy, gx.strides[4]*C.sx, gx.strides[3]*C.dy, gx.strides[4]*C.dx))
tw = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W)
return np.einsum("nGChwHW, GkCHW -> nGkhw", tx, tw).reshape(C.bs, C.groups*C.rcout, C.oy, C.ox).view(CPUBuffer)
# too bad this doesn't mix with stride_tricks, it can be very slow
#out = np.einsum("nGChwHW, GkCHW -> nGkhw", tx, tw)
tmp = np.empty((C.bs,C.groups,C.oy,C.ox,C.rcout),dtype=x.dtype)
for g in range(C.groups): tmp[:,g] = np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3)))
return np.moveaxis(tmp,4,2).reshape(C.bs, C.groups*C.rcout, C.oy, C.ox).view(CPUBuffer)