mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
hotfix: that repeat_kv belongs outside the if
This commit is contained in:
@@ -82,11 +82,11 @@ class Attention:
|
||||
|
||||
keys = self.cache_kv[0, :, 0:start_pos+seqlen, :, :]
|
||||
values = self.cache_kv[1, :, 0:start_pos+seqlen, :, :]
|
||||
keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
|
||||
else:
|
||||
assert start_pos == 0
|
||||
keys, values = xk, xv
|
||||
|
||||
keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
|
||||
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
|
||||
attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2)
|
||||
attn = attn.reshape(bsz, seqlen, -1)
|
||||
|
||||
Reference in New Issue
Block a user