make multidot work on CPU

This commit is contained in:
George Hotz
2020-12-27 17:25:37 -05:00
parent 131e04c90c
commit f15bec6dbc
3 changed files with 12 additions and 6 deletions

View File

@@ -48,14 +48,16 @@ class TransformerBlock:
query = query.transpose(order=(0,2,1,3)) # (bs, num_heads, T, head_size)
key = key.transpose(order=(0,2,3,1)) # (bs, num_heads, head_size, T)
score = query.dot(key)
print(query.shape)
print(key.shape)
print(score.shape)
#score = query.reshape(shape=(-1, self.projection_dim)).dot(
# key.reshape(shape=(-1, self.projection_dim)).transpose(order=(1,0)))
#scaled_score = score * (1/np.sqrt(self.projection_dim))
print(query.shape)
print(key.shape)
#print(value.shape)
#print(scaled_score.shape)