diff --git a/extra/models/llama.py b/extra/models/llama.py index 5d44cd7d04..3b7db359de 100644 --- a/extra/models/llama.py +++ b/extra/models/llama.py @@ -55,7 +55,7 @@ class Attention: xqkv = x @ self.wqkv.T xq, xk, xv = xqkv.split([self.wq.weight.shape[0], self.wk.weight.shape[0], self.wv.weight.shape[0]], dim=2) else: - xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + xq, xk, xv = self.wq(x), self.wk(x.contiguous_backward()), self.wv(x) if self.q_norm is not None and self.k_norm is not None: xq = self.q_norm(xq)