From 3a480b858f16ca2dfef20c2df978900e6ca7a64f Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 30 Sep 2025 11:08:03 +0800 Subject: [PATCH] use more getitem in gpt2 (#12343) --- examples/gpt2.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/gpt2.py b/examples/gpt2.py index 2ec1810e02..8f1a2836af 100644 --- a/examples/gpt2.py +++ b/examples/gpt2.py @@ -26,8 +26,8 @@ class Attention: start_pos = start_pos.val if HALF: x = x.half() - xqkv = self.c_attn(x) - xq, xk, xv = [xqkv.shrink((None, None, (i*self.dim, (i+1)*self.dim))).reshape(None, None, self.n_heads, self.head_dim) for i in range(3)] + xqkv = self.c_attn(x).reshape(None, None, 3, self.n_heads, self.head_dim) + xq, xk, xv = [xqkv[:, :, i, :, :] for i in range(3)] bsz, seqlen, _, _ = xq.shape # create kv cache @@ -35,11 +35,11 @@ class Attention: self.cache_kv = Tensor.zeros(2, bsz, MAX_CONTEXT, self.n_heads, self.head_dim, dtype=x.dtype).contiguous().realize() # update the cache - self.cache_kv.shrink((None, None,(start_pos,start_pos+seqlen),None,None)).assign(Tensor.stack(xk, xv)).realize() + self.cache_kv[:, :, start_pos:start_pos+seqlen, :, :].assign(Tensor.stack(xk, xv)).realize() if start_pos > 0: - keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) - values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) + keys = self.cache_kv[0][:, :start_pos+seqlen, :, :] + values = self.cache_kv[1][:, :start_pos+seqlen, :, :] else: keys = xk values = xv