hotfix: that repeat_kv belongs outside the if

This commit is contained in:
George Hotz
2025-05-11 18:41:57 -07:00
parent 98c84a711d
commit 8864ff894b

View File

@@ -82,11 +82,11 @@ class Attention:
keys = self.cache_kv[0, :, 0:start_pos+seqlen, :, :]
values = self.cache_kv[1, :, 0:start_pos+seqlen, :, :]
-keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
else:
assert start_pos == 0
keys, values = xk, xv
+keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2)
attn = attn.reshape(bsz, seqlen, -1)