rope half-split (#13706)

* rope half

* nicer

* this

* rearrange
Author: George Hotz
Date: 2025-12-15 14:31:11 -05:00
Committed by: GitHub
Parent: 2359e88f0c
Commit: ee4a7ee12f

@@ -56,16 +56,13 @@ class SimpleTokenizer:
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
   freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
   freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
-  return Tensor.stack(freqs.cos(), freqs.sin(), dim=-1).contiguous()
+  return freqs.cos().cat(freqs.sin(), dim=-1).contiguous()

 def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
-  B, H, T, Hd = x.shape
-  assert isinstance(Hd, int) and (Hd & 1) == 0, "RoPE requires an even head dimension"
-  x_pairs = x.reshape(B, H, T, Hd//2, 2)
-  cos = freqs_cis.reshape(1, 1, T, Hd//2, 2)[..., 0]
-  sin = freqs_cis.reshape(1, 1, T, Hd//2, 2)[..., 1]
-  return Tensor.stack(x_pairs[..., 0] * cos - x_pairs[..., 1] * sin,
-                      x_pairs[..., 0] * sin + x_pairs[..., 1] * cos, dim=-1).reshape(B, H, T, Hd)
+  assert x.shape[-1] % 2 == 0
+  cos, sin = freqs_cis.reshape(1, 1, x.shape[2], -1).chunk(2, dim=-1)
+  x1, x2 = x.chunk(2, dim=-1)
+  return (x1 * cos - x2 * sin).cat(x2 * cos + x1 * sin, dim=-1)

 class TransformerBlock:
   def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int=0):
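Why this refactor is safe: interleaved RoPE rotates adjacent pairs (x0,x1), (x2,x3), ..., while the half-split form pairs element i with element i+Hd/2, so permuting the head dimension [0,1,2,3,...] -> [0,2,4,...,1,3,5,...] maps one layout onto the other. A minimal NumPy sketch (not part of this commit; names and sizes are illustrative) that checks the equivalence:

# illustrative sketch, not from the diff
import numpy as np

def rope_interleaved(x, cos, sin):
  # old layout: (x0,x1), (x2,x3), ... are the rotated pairs
  x1, x2 = x[..., 0::2], x[..., 1::2]
  out = np.empty_like(x)
  out[..., 0::2] = x1 * cos - x2 * sin
  out[..., 1::2] = x1 * sin + x2 * cos
  return out

def rope_half_split(x, cos, sin):
  # new layout: element i pairs with element i + Hd//2
  x1, x2 = np.split(x, 2, axis=-1)
  return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)

Hd, T, theta = 8, 5, 10000.0
freqs = 1.0 / (theta ** (np.arange(0, Hd, 2) / Hd))    # (Hd//2,)
angles = np.arange(T)[:, None] * freqs[None, :]        # (T, Hd//2)
cos, sin = np.cos(angles), np.sin(angles)

x = np.random.randn(T, Hd)
perm = np.r_[np.arange(0, Hd, 2), np.arange(1, Hd, 2)]  # [0,2,4,...,1,3,5,...]
# half-split on permuted activations == permutation of interleaved-rotated activations
assert np.allclose(rope_half_split(x[..., perm], cos, sin),
                   rope_interleaved(x, cos, sin)[..., perm])
print("half-split matches interleaved after permutation")

The half-split form also avoids the reshape-to-pairs and Tensor.stack of the old code: a single chunk and cat over the last axis does the whole rotation.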
@@ -159,8 +156,15 @@ class Transformer:
     arch = kv['general.architecture']
     max_context = min(max_context, kv[f'{arch}.context_length']) if max_context is not None else kv[f'{arch}.context_length']
+    n_heads, n_kv_heads = kv[f'{arch}.attention.head_count'], kv[f'{arch}.attention.head_count_kv']
+    # permute Q/K weights from interleaved to half-split RoPE layout: [0,1,2,3,4,5...] -> [0,2,4,...,1,3,5,...]
+    for name in state_dict:
+      if 'attn_q.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_heads, two=2)
+      if 'attn_k.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_kv_heads, two=2)
     model = Transformer(num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'], hidden_dim=kv[f'{arch}.feed_forward_length'],
-                        n_heads=kv[f'{arch}.attention.head_count'], n_kv_heads=kv[f'{arch}.attention.head_count_kv'],
+                        n_heads=n_heads, n_kv_heads=n_kv_heads,
                         norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'], vocab_size=len(kv['tokenizer.ggml.tokens']), max_context=max_context)
     nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False)  # NOTE: rope_freqs.weight (32,) is unused
     # NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
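To make the weight permutation concrete, here is a hypothetical NumPy equivalent of the einops-style rearrange "(n h two) d -> (n two h) d" on a tiny made-up weight matrix (sizes and names are illustrative, not from the commit):

# illustrative sketch, not from the diff
import numpy as np

n_heads, head_dim, dim = 2, 6, 4                # tiny illustrative sizes
w = np.arange(n_heads * head_dim * dim).reshape(n_heads * head_dim, dim)

# (n h two) d -> (n two h) d, with h = head_dim // 2 rotated pairs per head
h = head_dim // 2
w_half = w.reshape(n_heads, h, 2, dim).transpose(0, 2, 1, 3).reshape(n_heads * head_dim, dim)

# per head, row order is now even (first-of-pair) rows, then odd rows
first_head_rows = w_half[:head_dim, 0] // dim   # recover original row indices
print(first_head_rows)                          # [0 2 4 1 3 5]

Because the permutation is folded into the Q/K weights once at load time, apply_rope never has to interleave at runtime: the activations come out of the projection already in half-split order.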