mirror of https://github.com/tinygrad/tinygrad.git
@@ -56,16 +56,13 @@ class SimpleTokenizer:
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
   freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
   freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
-  return Tensor.stack(freqs.cos(), freqs.sin(), dim=-1).contiguous()
+  return freqs.cos().cat(freqs.sin(), dim=-1).contiguous()
 
 def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
-  B, H, T, Hd = x.shape
-  assert isinstance(Hd, int) and (Hd & 1) == 0, "RoPE requires an even head dimension"
-  x_pairs = x.reshape(B, H, T, Hd//2, 2)
-  cos = freqs_cis.reshape(1, 1, T, Hd//2, 2)[..., 0]
-  sin = freqs_cis.reshape(1, 1, T, Hd//2, 2)[..., 1]
-  return Tensor.stack(x_pairs[..., 0] * cos - x_pairs[..., 1] * sin,
-                      x_pairs[..., 0] * sin + x_pairs[..., 1] * cos, dim=-1).reshape(B, H, T, Hd)
+  assert x.shape[-1] % 2 == 0
+  cos, sin = freqs_cis.reshape(1, 1, x.shape[2], -1).chunk(2, dim=-1)
+  x1, x2 = x.chunk(2, dim=-1)
+  return (x1 * cos - x2 * sin).cat(x2 * cos + x1 * sin, dim=-1)
 
 class TransformerBlock:
   def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int=0):
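
As a sanity check on this hunk (a sketch of mine, not part of the diff): the new half-split apply_rope produces the same rotation as the removed interleaved version once the head channels are permuted [0,1,2,3,...] -> [0,2,4,...,1,3,5,...], which is exactly the Q/K weight permutation the loader applies in the next hunk. Assuming the precompute_freqs_cis and apply_rope defined above are in scope:

from tinygrad import Tensor

def apply_rope_interleaved(x:Tensor, freqs_pairs:Tensor) -> Tensor:
  # re-statement of the removed implementation: freqs_pairs is (T, Hd//2, 2) with (cos, sin) stacked per pair
  B, H, T, Hd = x.shape
  xp = x.reshape(B, H, T, Hd//2, 2)
  fc = freqs_pairs.reshape(1, 1, T, Hd//2, 2)
  return Tensor.stack(xp[..., 0]*fc[..., 0] - xp[..., 1]*fc[..., 1],
                      xp[..., 0]*fc[..., 1] + xp[..., 1]*fc[..., 0], dim=-1).reshape(B, H, T, Hd)

B, H, T, Hd = 1, 2, 4, 8                                             # toy sizes
x = Tensor.randn(B, H, T, Hd)
# channel permutation per head: interleaved [re0, im0, re1, im1, ...] -> half-split [re0, re1, ..., im0, im1, ...]
half_split = lambda t: t.reshape(B, H, T, Hd//2, 2).permute(0, 1, 2, 4, 3).reshape(B, H, T, Hd)

fc = precompute_freqs_cis(Hd, T)                                     # new layout: (T, Hd) = [cos | sin]
cos_h, sin_h = fc.chunk(2, dim=-1)                                   # each (T, Hd//2)
old = apply_rope_interleaved(x, Tensor.stack(cos_h, sin_h, dim=-1))  # old layout in, old layout out
new = apply_rope(half_split(x), fc)                                  # new layout in, new layout out
print((half_split(old) - new).abs().max().item())                    # ~0.0
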
@@ -159,8 +156,15 @@ class Transformer:
 
     arch = kv['general.architecture']
     max_context = min(max_context, kv[f'{arch}.context_length']) if max_context is not None else kv[f'{arch}.context_length']
+    n_heads, n_kv_heads = kv[f'{arch}.attention.head_count'], kv[f'{arch}.attention.head_count_kv']
+
+    # permute Q/K weights from interleaved to half-split RoPE layout: [0,1,2,3,4,5...] -> [0,2,4,...,1,3,5,...]
+    for name in state_dict:
+      if 'attn_q.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_heads, two=2)
+      if 'attn_k.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_kv_heads, two=2)
+
     model = Transformer(num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'], hidden_dim=kv[f'{arch}.feed_forward_length'],
-                        n_heads=kv[f'{arch}.attention.head_count'], n_kv_heads=kv[f'{arch}.attention.head_count_kv'],
+                        n_heads=n_heads, n_kv_heads=n_kv_heads,
                         norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'], vocab_size=len(kv['tokenizer.ggml.tokens']), max_context=max_context)
     nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
     # NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
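
For reference, a small standalone check (toy sizes of my choosing, not part of the diff) that the rearrange pattern used above really is the per-head interleaved -> half-split row permutation:

from tinygrad import Tensor

n_heads, head_dim, dim = 2, 6, 4                                 # made-up toy sizes
w = Tensor.arange(n_heads*head_dim*dim).reshape(n_heads*head_dim, dim)
permuted = w.rearrange("(n h two) d -> (n two h) d", n=n_heads, two=2)

# the same permutation written out by hand: within each head, rows 0,2,4,... first, then rows 1,3,5,...
manual = w.reshape(n_heads, head_dim//2, 2, dim).permute(0, 2, 1, 3).reshape(n_heads*head_dim, dim)
print((permuted.numpy() == manual.numpy()).all())                # True
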