From 6d301ad2c46004b10738f779073e95e1edd082e4 Mon Sep 17 00:00:00 2001
From: wozeparrot
Date: Tue, 17 Feb 2026 23:01:33 -0800
Subject: [PATCH] feat: llama wqkv (#14841)

---
 .../implementations/tinybox_8xMI350X/dev_beam.sh |  1 +
 .../implementations/tinybox_8xMI350X/dev_run.sh  |  1 +
 extra/models/llama.py                            | 15 +++++++++------
 3 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_beam.sh b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_beam.sh
index e8d58658a3..62c1048632 100755
--- a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_beam.sh
+++ b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_beam.sh
@@ -11,6 +11,7 @@ export HK_FLASH_ATTENTION=${HK_FLASH_ATTENTION:-1}
 export ALL2ALL=${ALL2ALL:-1}
 export USE_ATOMICS=${USE_ATOMICS:-1}
 export ASM_GEMM=${ASM_GEMM:-1}
+export WQKV=${WQKV:-0}
 
 export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
 export DP=${DP:-8} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-2}
diff --git a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh
index a239d06179..c729d1b947 100755
--- a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh
+++ b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh
@@ -11,6 +11,7 @@ export HK_FLASH_ATTENTION=${HK_FLASH_ATTENTION:-1}
 export ALL2ALL=${ALL2ALL:-1}
 export USE_ATOMICS=${USE_ATOMICS:-1}
 export ASM_GEMM=${ASM_GEMM:-1}
+export WQKV=${WQKV:-0}
 
 export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
 export DP=${DP:-8} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-2}
diff --git a/extra/models/llama.py b/extra/models/llama.py
index feff880202..45bd3d4045 100644
--- a/extra/models/llama.py
+++ b/extra/models/llama.py
@@ -41,9 +41,13 @@ class Attention:
     self.n_rep = self.n_heads // self.n_kv_heads
     self.max_context = max_context
 
-    self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
-    self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
-    self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
+    if getenv("WQKV"):
+      self.wqkv = linear(dim, self.n_heads * self.head_dim + self.n_kv_heads * self.head_dim * 2, bias=False)
+    else:
+      self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
+      self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
+      self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
+
     self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
 
     self.q_norm = nn.RMSNorm(dim, qk_norm) if qk_norm is not None else None
@@ -51,9 +55,8 @@ class Attention:
 
   def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]=None) -> Tensor:
     if getenv("WQKV"):
-      if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
-      xqkv = x @ self.wqkv.T
-      xq, xk, xv = xqkv.split([self.wq.weight.shape[0], self.wk.weight.shape[0], self.wv.weight.shape[0]], dim=2)
+      xqkv = self.wqkv(x)
+      xq, xk, xv = xqkv.split([self.n_heads * self.head_dim, self.n_kv_heads * self.head_dim, self.n_kv_heads * self.head_dim], dim=2)
     else:
       xq, xk, xv = self.wq(x), self.wk(x.contiguous_backward()), self.wv(x)
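
Notes (not part of the patch). The WQKV flag defaults to 0 in both scripts, so a run opts in with WQKV=1 ./dev_run.sh. With the flag set, the attention projection is one GEMM of output width n_heads * head_dim + 2 * n_kv_heads * head_dim instead of three narrower GEMMs, and splitting the fused output recovers exactly what the separate wq/wk/wv projections compute. A minimal numpy sketch of that equivalence, with illustrative shapes rather than the llama8b config:

    # minimal sketch (not from this repo): fusing wq/wk/wv into one matrix and
    # splitting the result reproduces the three separate projections.
    import numpy as np

    dim, n_heads, n_kv_heads, head_dim = 64, 8, 2, 8
    rng = np.random.default_rng(0)
    wq = rng.standard_normal((n_heads * head_dim, dim))     # (out, in), Linear layout
    wk = rng.standard_normal((n_kv_heads * head_dim, dim))
    wv = rng.standard_normal((n_kv_heads * head_dim, dim))
    x = rng.standard_normal((4, 16, dim))                   # (batch, seqlen, dim)

    # one fused GEMM instead of three smaller ones
    wqkv = np.concatenate([wq, wk, wv], axis=0)
    xqkv = x @ wqkv.T
    q_out, kv_out = n_heads * head_dim, n_kv_heads * head_dim
    xq, xk, xv = np.split(xqkv, [q_out, q_out + kv_out], axis=2)

    assert np.allclose(xq, x @ wq.T)
    assert np.allclose(xk, x @ wk.T)
    assert np.allclose(xv, x @ wv.T)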
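
One practical consequence of moving the fusion from __call__ into __init__: with WQKV=1 the attention parameters live under a single wqkv.weight key rather than wq/wk/wv, so a checkpoint saved in the split layout has to be folded once at load time. A hypothetical helper sketching that fold; the key names and the dict-style state dict are assumptions for illustration, not this repo's loader API:

    # hypothetical loader shim (not in this patch): fold split q/k/v checkpoint
    # weights into the fused layout the WQKV=1 path expects.
    import numpy as np

    def fuse_qkv_weights(state_dict: dict, prefix: str) -> dict:
      # e.g. prefix = "layers.0.attention"; key names are assumptions
      q = state_dict.pop(f"{prefix}.wq.weight")
      k = state_dict.pop(f"{prefix}.wk.weight")
      v = state_dict.pop(f"{prefix}.wv.weight")
      # Linear weights are (out_features, in_features); stack along out_features,
      # matching the q/k/v order the split in __call__ expects
      state_dict[f"{prefix}.wqkv.weight"] = np.concatenate([q, k, v], axis=0)
      return state_dict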