mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
use more getitem in gpt2 (#12343)
This commit is contained in:
@@ -26,8 +26,8 @@ class Attention:
|
||||
start_pos = start_pos.val
|
||||
|
||||
if HALF: x = x.half()
|
||||
xqkv = self.c_attn(x)
|
||||
xq, xk, xv = [xqkv.shrink((None, None, (i*self.dim, (i+1)*self.dim))).reshape(None, None, self.n_heads, self.head_dim) for i in range(3)]
|
||||
xqkv = self.c_attn(x).reshape(None, None, 3, self.n_heads, self.head_dim)
|
||||
xq, xk, xv = [xqkv[:, :, i, :, :] for i in range(3)]
|
||||
bsz, seqlen, _, _ = xq.shape
|
||||
|
||||
# create kv cache
|
||||
@@ -35,11 +35,11 @@ class Attention:
|
||||
self.cache_kv = Tensor.zeros(2, bsz, MAX_CONTEXT, self.n_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
|
||||
|
||||
# update the cache
|
||||
self.cache_kv.shrink((None, None,(start_pos,start_pos+seqlen),None,None)).assign(Tensor.stack(xk, xv)).realize()
|
||||
self.cache_kv[:, :, start_pos:start_pos+seqlen, :, :].assign(Tensor.stack(xk, xv)).realize()
|
||||
|
||||
if start_pos > 0:
|
||||
keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None))
|
||||
values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None))
|
||||
keys = self.cache_kv[0][:, :start_pos+seqlen, :, :]
|
||||
values = self.cache_kv[1][:, :start_pos+seqlen, :, :]
|
||||
else:
|
||||
keys = xk
|
||||
values = xv
|
||||
|
||||
Reference in New Issue
Block a user