mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
make multidot work on CPU
This commit is contained in:
@@ -48,14 +48,16 @@ class TransformerBlock:
|
||||
|
||||
query = query.transpose(order=(0,2,1,3)) # (bs, num_heads, T, head_size)
|
||||
key = key.transpose(order=(0,2,3,1)) # (bs, num_heads, head_size, T)
|
||||
score = query.dot(key)
|
||||
print(query.shape)
|
||||
print(key.shape)
|
||||
print(score.shape)
|
||||
|
||||
|
||||
#score = query.reshape(shape=(-1, self.projection_dim)).dot(
|
||||
# key.reshape(shape=(-1, self.projection_dim)).transpose(order=(1,0)))
|
||||
#scaled_score = score * (1/np.sqrt(self.projection_dim))
|
||||
|
||||
print(query.shape)
|
||||
print(key.shape)
|
||||
#print(value.shape)
|
||||
#print(scaled_score.shape)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user