From be7787397453f2ac09a76d07fe37e6057fd9785c Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Fri, 6 Feb 2026 00:54:00 -0500
Subject: [PATCH] llama: contig backward for wk / wv matmul backward (#14581)

---
 extra/models/llama.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/extra/models/llama.py b/extra/models/llama.py
index 5d44cd7d04..3b7db359de 100644
--- a/extra/models/llama.py
+++ b/extra/models/llama.py
@@ -55,7 +55,7 @@ class Attention:
       xqkv = x @ self.wqkv.T
       xq, xk, xv = xqkv.split([self.wq.weight.shape[0], self.wk.weight.shape[0], self.wv.weight.shape[0]], dim=2)
     else:
-      xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+      xq, xk, xv = self.wq(x), self.wk(x.contiguous_backward()), self.wv(x)
     if self.q_norm is not None and self.k_norm is not None:
       xq = self.q_norm(xq)