one less loop

This commit is contained in:
George Hotz
2020-10-19 09:34:55 -07:00
parent 5c2ac48c11
commit bee89a4840

View File

@@ -163,10 +163,9 @@ class Conv2D(Function):
for X in range(ret.shape[3]):
for j in range(H):
for i in range(W):
for c in range(cout):
tx = x[:, :, Y+j, X+i]
tw = w[c, :, j, i]
ret[:, c, Y, X] += tx.dot(tw.reshape(-1, 1)).reshape(-1)
tx = x[:, :, Y+j, X+i]
tw = w[:, :, j, i]
ret[:, :, Y, X] += tx.dot(tw.T)
return ret
@staticmethod