feat: llama wqkv (#14841)

wozeparrot
2026-02-17 23:01:33 -08:00
committed by GitHub
parent a3d516c4b5
commit 6d301ad2c4
3 changed files with 11 additions and 6 deletions

View File

@@ -11,6 +11,7 @@ export HK_FLASH_ATTENTION=${HK_FLASH_ATTENTION:-1}
 export ALL2ALL=${ALL2ALL:-1}
 export USE_ATOMICS=${USE_ATOMICS:-1}
 export ASM_GEMM=${ASM_GEMM:-1}
+export WQKV=${WQKV:-0}
 export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
 export DP=${DP:-8} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-2}

View File

@@ -11,6 +11,7 @@ export HK_FLASH_ATTENTION=${HK_FLASH_ATTENTION:-1}
 export ALL2ALL=${ALL2ALL:-1}
 export USE_ATOMICS=${USE_ATOMICS:-1}
 export ASM_GEMM=${ASM_GEMM:-1}
+export WQKV=${WQKV:-0}
 export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
 export DP=${DP:-8} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-2}

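In both training scripts the new WQKV variable defaults to 0 (the ${WQKV:-0} expansion keeps any value already set in the environment), so the fused QKV projection introduced in the Attention changes below stays off unless a run is launched with WQKV=1 exported.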
View File

@@ -41,9 +41,13 @@ class Attention:
     self.n_rep = self.n_heads // self.n_kv_heads
     self.max_context = max_context
-    self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
-    self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
-    self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
+    if getenv("WQKV"):
+      self.wqkv = linear(dim, self.n_heads * self.head_dim + self.n_kv_heads * self.head_dim * 2, bias=False)
+    else:
+      self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
+      self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
+      self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
     self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
     self.q_norm = nn.RMSNorm(dim, qk_norm) if qk_norm is not None else None
@@ -51,9 +55,8 @@ class Attention:
   def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]=None) -> Tensor:
     if getenv("WQKV"):
-      if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
-      xqkv = x @ self.wqkv.T
-      xq, xk, xv = xqkv.split([self.wq.weight.shape[0], self.wk.weight.shape[0], self.wv.weight.shape[0]], dim=2)
+      xqkv = self.wqkv(x)
+      xq, xk, xv = xqkv.split([self.n_heads * self.head_dim, self.n_kv_heads * self.head_dim, self.n_kv_heads * self.head_dim], dim=2)
     else:
       xq, xk, xv = self.wq(x), self.wk(x.contiguous_backward()), self.wv(x)
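For reference, a minimal sketch of what the fused path computes, written in plain numpy rather than tinygrad and with made-up shapes (dim=16, n_heads=4, n_kv_heads=2, head_dim=4); it only checks that concatenating the q/k/v weights into one matrix, doing a single matmul, and splitting the result matches the three separate projections:

# Not tinygrad code: a plain-numpy illustration of the fused QKV projection.
import numpy as np

dim, n_heads, n_kv_heads, head_dim = 16, 4, 2, 4
rng = np.random.default_rng(0)

# Separate projection weights, stored as (out_features, in_features) like a Linear layer.
wq = rng.standard_normal((n_heads * head_dim, dim)).astype(np.float32)
wk = rng.standard_normal((n_kv_heads * head_dim, dim)).astype(np.float32)
wv = rng.standard_normal((n_kv_heads * head_dim, dim)).astype(np.float32)

x = rng.standard_normal((2, 8, dim)).astype(np.float32)  # (batch, seq, dim)

# Non-fused path: three separate matmuls over the same input.
xq_ref, xk_ref, xv_ref = x @ wq.T, x @ wk.T, x @ wv.T

# Fused path: one weight of shape (n_heads*head_dim + 2*n_kv_heads*head_dim, dim),
# one matmul, then a split on the last axis at the q/k boundaries.
wqkv = np.concatenate([wq, wk, wv], axis=0)
xqkv = x @ wqkv.T
xq, xk, xv = np.split(xqkv, [n_heads * head_dim, n_heads * head_dim + n_kv_heads * head_dim], axis=2)

assert np.allclose(xq, xq_ref) and np.allclose(xk, xk_ref) and np.allclose(xv, xv_ref)
print("fused QKV matches separate q/k/v projections")

The point of the fused layer is to replace three matmuls over the same input with one larger one; the split itself is just slicing and adds no arithmetic.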