feat: llama wqkv (#14841)
Adds an optional fused QKV projection to tinygrad's llama Attention: when the new WQKV env flag is set, the three q/k/v linears are replaced by a single wqkv linear whose output is split back into q, k, and v. The flag defaults to 0, so the existing separate-projection path is unchanged.
@@ -11,6 +11,7 @@ export HK_FLASH_ATTENTION=${HK_FLASH_ATTENTION:-1}
 export ALL2ALL=${ALL2ALL:-1}
 export USE_ATOMICS=${USE_ATOMICS:-1}
 export ASM_GEMM=${ASM_GEMM:-1}
+export WQKV=${WQKV:-0}
 
 export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
 export DP=${DP:-8} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-2}
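For context, these scripts gate features through tinygrad's getenv helper, which reads an environment variable and falls back to the given default, so WQKV is opt-in. A minimal sketch of the pattern (the Attention internals are elided; getenv is real tinygrad API):

    from tinygrad.helpers import getenv

    # Run with WQKV=1 to take the fused branch; unset or WQKV=0 keeps the default.
    if getenv("WQKV"):
      print("using fused wqkv projection")
    else:
      print("using separate wq/wk/wv projections")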
@@ -41,9 +41,13 @@ class Attention:
     self.n_rep = self.n_heads // self.n_kv_heads
     self.max_context = max_context
 
-    self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
-    self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
-    self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
+    if getenv("WQKV"):
+      self.wqkv = linear(dim, self.n_heads * self.head_dim + self.n_kv_heads * self.head_dim * 2, bias=False)
+    else:
+      self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
+      self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
+      self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
+
     self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
 
     self.q_norm = nn.RMSNorm(dim, qk_norm) if qk_norm is not None else None
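Why the fusion is safe: linear(dim, out) computes x @ W.T with W of shape (out, dim), so stacking the q, k, and v weight rows into one (n_heads*head_dim + 2*n_kv_heads*head_dim, dim) matrix turns three GEMMs into one, and slicing the wider output recovers the three projections exactly. For a llama-3.1-405B-shaped config (n_heads=128, n_kv_heads=8, head_dim=128, per the published sizes) the fused output width is (128 + 2*8) * 128 = 18432. A minimal numpy sketch of the equivalence (an illustrative stand-in for the tinygrad Tensors; the toy sizes are made up):

    import numpy as np

    dim, n_heads, n_kv_heads, head_dim = 64, 8, 2, 8
    rng = np.random.default_rng(0)

    # nn.Linear-style weights: (out_features, in_features)
    wq = rng.standard_normal((n_heads * head_dim, dim))
    wk = rng.standard_normal((n_kv_heads * head_dim, dim))
    wv = rng.standard_normal((n_kv_heads * head_dim, dim))
    wqkv = np.concatenate([wq, wk, wv], axis=0)  # fused weight, rows stacked

    x = rng.standard_normal((2, 5, dim))  # (batch, seqlen, dim)
    xqkv = x @ wqkv.T                     # one matmul instead of three
    q_end = n_heads * head_dim
    k_end = q_end + n_kv_heads * head_dim
    xq, xk, xv = xqkv[..., :q_end], xqkv[..., q_end:k_end], xqkv[..., k_end:]

    assert np.allclose(xq, x @ wq.T) and np.allclose(xk, x @ wk.T) and np.allclose(xv, x @ wv.T)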
@@ -51,9 +55,8 @@ class Attention:
   def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]=None) -> Tensor:
     if getenv("WQKV"):
-      if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
-      xqkv = x @ self.wqkv.T
-      xq, xk, xv = xqkv.split([self.wq.weight.shape[0], self.wk.weight.shape[0], self.wv.weight.shape[0]], dim=2)
+      xqkv = self.wqkv(x)
+      xq, xk, xv = xqkv.split([self.n_heads * self.head_dim, self.n_kv_heads * self.head_dim, self.n_kv_heads * self.head_dim], dim=2)
     else:
       xq, xk, xv = self.wq(x), self.wk(x.contiguous_backward()), self.wv(x)
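Design note: the old WQKV path kept separate wq/wk/wv layers and lazily concatenated their weights into a wqkv tensor on first call; the new path allocates wqkv once in __init__, so the fused weight is the actual parameter rather than a cat of three. One consequence (an inference, not stated in the commit) is that a checkpoint storing separate q/k/v weights would need them fused at load time when WQKV=1. A hypothetical helper for that, with the stacking order matching the split above:

    import numpy as np

    def fuse_qkv_weights(wq: np.ndarray, wk: np.ndarray, wv: np.ndarray) -> np.ndarray:
      # Weights are (out_features, in_features); concatenating along axis 0
      # matches the [q, k, v] order that xqkv.split(...) uses in the forward pass.
      return np.concatenate([wq, wk, wv], axis=0)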