diff --git a/extra/models/llama.py b/extra/models/llama.py
index 9c0fe2dd9c..56f6c01632 100644
--- a/extra/models/llama.py
+++ b/extra/models/llama.py
@@ -82,11 +82,11 @@ class Attention:
       keys = self.cache_kv[0, :, 0:start_pos+seqlen, :, :]
       values = self.cache_kv[1, :, 0:start_pos+seqlen, :, :]
-      keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     else:
       assert start_pos == 0
       keys, values = xk, xv
+    keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
     attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2)
     attn = attn.reshape(bsz, seqlen, -1)
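
For context: `repeat_kv` broadcasts the grouped-query KV heads up to the full query head count, and moving the call below the `if`/`else` makes it run on the prefill path (`start_pos == 0`, where `keys, values = xk, xv`) as well as on the cached decode path. A minimal sketch of such a helper, assuming the usual `(bsz, seqlen, n_kv_heads, head_dim)` layout (illustrative only, not necessarily the file's exact implementation):

```python
from tinygrad import Tensor

def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
  # sketch: (bsz, seqlen, n_kv_heads, head_dim) -> (bsz, seqlen, n_kv_heads * n_rep, head_dim)
  bs, seqlen, n_kv_heads, head_dim = x.shape
  if n_rep == 1: return x
  # insert a repeat axis next to the KV-head axis, broadcast it, then fold it back in
  return x.unsqueeze(3).expand(bs, seqlen, n_kv_heads, n_rep, head_dim) \
          .reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
```

With the call hoisted out of the branch, both the prefill and decode paths hand `scaled_dot_product_attention` keys/values whose head dimension matches `xq`.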