use more getitem in gpt2 (#12343)

This commit is contained in:
chenyu
2025-09-30 11:08:03 +08:00
committed by GitHub
parent 32d69d07d7
commit 3a480b858f

View File

@@ -26,8 +26,8 @@ class Attention:
start_pos = start_pos.val
if HALF: x = x.half()
xqkv = self.c_attn(x)
xq, xk, xv = [xqkv.shrink((None, None, (i*self.dim, (i+1)*self.dim))).reshape(None, None, self.n_heads, self.head_dim) for i in range(3)]
xqkv = self.c_attn(x).reshape(None, None, 3, self.n_heads, self.head_dim)
xq, xk, xv = [xqkv[:, :, i, :, :] for i in range(3)]
bsz, seqlen, _, _ = xq.shape
# create kv cache
@@ -35,11 +35,11 @@ class Attention:
self.cache_kv = Tensor.zeros(2, bsz, MAX_CONTEXT, self.n_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
# update the cache
self.cache_kv.shrink((None, None,(start_pos,start_pos+seqlen),None,None)).assign(Tensor.stack(xk, xv)).realize()
self.cache_kv[:, :, start_pos:start_pos+seqlen, :, :].assign(Tensor.stack(xk, xv)).realize()
if start_pos > 0:
keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None))
values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None))
keys = self.cache_kv[0][:, :start_pos+seqlen, :, :]
values = self.cache_kv[1][:, :start_pos+seqlen, :, :]
else:
keys = xk
values = xv