Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-29 03:00:14 -04:00)
llama: contig backward for wk / wv matmul backward (#14581)
This commit is contained in:
@@ -55,7 +55,7 @@ class Attention:
   xqkv = x @ self.wqkv.T
   xq, xk, xv = xqkv.split([self.wq.weight.shape[0], self.wk.weight.shape[0], self.wv.weight.shape[0]], dim=2)
 else:
-  xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+  xq, xk, xv = self.wq(x), self.wk(x.contiguous_backward()), self.wv(x)

 if self.q_norm is not None and self.k_norm is not None:
   xq = self.q_norm(xq)
Reference in New Issue
Block a user