llama: contig backward for wk / wv matmul backward (#14581)

Author: qazal
Date: 2026-02-06 00:54:00 -05:00
Committed by: GitHub
Parent: 15d3344d9e
Commit: be77873974


@@ -55,7 +55,7 @@ class Attention:
       xqkv = x @ self.wqkv.T
       xq, xk, xv = xqkv.split([self.wq.weight.shape[0], self.wk.weight.shape[0], self.wv.weight.shape[0]], dim=2)
     else:
-      xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+      xq, xk, xv = self.wq(x), self.wk(x.contiguous_backward()), self.wv(x)
     if self.q_norm is not None and self.k_norm is not None:
       xq = self.q_norm(xq)
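
For context: Tensor.contiguous_backward() in tinygrad is a no-op in the forward pass; it only marks the tensor so that the gradient flowing back through it is realized as a contiguous buffer rather than fused into the surrounding backward kernel, which is presumably what the title means by making the wk/wv matmul backward contiguous. A minimal sketch of the behavior (shapes and variable names are illustrative, not from this commit):

    from tinygrad import Tensor

    # Illustrative stand-ins for the activation x and a wk-style weight.
    x = Tensor.randn(1, 8, 64, requires_grad=True)
    w = Tensor.randn(64, 64)

    # The forward value is unchanged; only the gradient flowing back through
    # this use of x is forced to be realized as a contiguous buffer.
    y = (x.contiguous_backward() @ w.T).sum()
    y.backward()
    print(x.grad.shape)  # (1, 8, 64)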