diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py
index 53d8544bf8..5160a46ac0 100644
--- a/tinygrad/apps/llm.py
+++ b/tinygrad/apps/llm.py
@@ -56,16 +56,13 @@ class SimpleTokenizer:
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
   freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
   freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
-  return Tensor.stack(freqs.cos(), freqs.sin(), dim=-1).contiguous()
+  return freqs.cos().cat(freqs.sin(), dim=-1).contiguous()
 
 def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
-  B, H, T, Hd = x.shape
-  assert isinstance(Hd, int) and (Hd & 1) == 0, "RoPE requires an even head dimension"
-  x_pairs = x.reshape(B, H, T, Hd//2, 2)
-  cos = freqs_cis.reshape(1, 1, T, Hd//2, 2)[..., 0]
-  sin = freqs_cis.reshape(1, 1, T, Hd//2, 2)[..., 1]
-  return Tensor.stack(x_pairs[..., 0] * cos - x_pairs[..., 1] * sin,
-                      x_pairs[..., 0] * sin + x_pairs[..., 1] * cos, dim=-1).reshape(B, H, T, Hd)
+  assert x.shape[-1] % 2 == 0
+  cos, sin = freqs_cis.reshape(1, 1, x.shape[2], -1).chunk(2, dim=-1)
+  x1, x2 = x.chunk(2, dim=-1)
+  return (x1 * cos - x2 * sin).cat(x2 * cos + x1 * sin, dim=-1)
 
 class TransformerBlock:
   def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int=0):
@@ -159,8 +156,15 @@ class Transformer:
     arch = kv['general.architecture']
     max_context = min(max_context, kv[f'{arch}.context_length']) if max_context is not None else kv[f'{arch}.context_length']
+    n_heads, n_kv_heads = kv[f'{arch}.attention.head_count'], kv[f'{arch}.attention.head_count_kv']
+
+    # permute Q/K weights from interleaved to half-split RoPE layout: [0,1,2,3,4,5,...] -> [0,2,4,...,1,3,5,...]
+    for name in state_dict:
+      if 'attn_q.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_heads, two=2)
+      if 'attn_k.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_kv_heads, two=2)
+
     model = Transformer(num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'], hidden_dim=kv[f'{arch}.feed_forward_length'],
-                        n_heads=kv[f'{arch}.attention.head_count'], n_kv_heads=kv[f'{arch}.attention.head_count_kv'],
+                        n_heads=n_heads, n_kv_heads=n_kv_heads,
                         norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'], vocab_size=len(kv['tokenizer.ggml.tokens']), max_context=max_context)
     nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False)  # NOTE: rope_freqs.weight (32,) is unused
     # NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
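
Not part of the patch: a small numpy sanity sketch of why the loader-side Q/K permutation keeps outputs unchanged when apply_rope switches from the interleaved-pair form to the half-split form. The helper names (rope_interleaved, rope_half_split) and the tiny shapes are illustrative only; the check is that rotating permuted channels the new way matches the old rotation followed by the same permutation.

import numpy as np

def rope_interleaved(x, cos, sin):
  # old layout: channels are [even0, odd0, even1, odd1, ...], rotated pairwise
  x1, x2 = x[..., 0::2], x[..., 1::2]
  out = np.empty_like(x)
  out[..., 0::2] = x1 * cos - x2 * sin
  out[..., 1::2] = x1 * sin + x2 * cos
  return out

def rope_half_split(x, cos, sin):
  # new layout: channels are [even0, even1, ..., odd0, odd1, ...], rotated half against half
  half = x.shape[-1] // 2
  x1, x2 = x[..., :half], x[..., half:]
  return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)

Hd, T = 8, 5
freqs = np.outer(np.arange(T), 1.0 / 10000.0 ** (np.arange(0, Hd, 2) / Hd))
cos, sin = np.cos(freqs), np.sin(freqs)          # (T, Hd//2) each
x = np.random.randn(T, Hd)

perm = np.r_[np.arange(0, Hd, 2), np.arange(1, Hd, 2)]   # [0,2,4,...,1,3,5,...], same reorder as the rearrange above
a = rope_interleaved(x, cos, sin)[..., perm]             # old path, outputs permuted
b = rope_half_split(x[..., perm], cos, sin)              # inputs permuted (like the weights), new path
assert np.allclose(a, b)

Since the permutation commutes with the rotation this way, baking it into attn_q.weight / attn_k.weight at load time lets apply_rope drop the stack/reshape bookkeeping and use two chunks and a cat instead.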