diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 90e135c2a2..234cceb224 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -508,6 +508,8 @@ jobs:
        run: echo "What's a male chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.apps.llm --model llama3.2:1b | tee /dev/stderr | grep -i rooster
      - name: Test 1B LLM (llama q4)
        run: echo "What's a male chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.apps.llm --model llama3.2:1b-q4 | tee /dev/stderr | grep -i rooster
+     - name: Test 1B LLM (qwen3.5)
+       run: echo "What's a male chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.apps.llm --model qwen3.5:0.8b | tee /dev/stderr | grep -i rooster
      - name: Test 1B LLM (qwen) # NOTE: qwen is dumb and only knows about female chickens
        run: echo "What's a female chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.apps.llm --model qwen3:0.6b | tee /dev/stderr | grep -i hen
diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py
index 9c285c7d4c..b07324397b 100644
--- a/tinygrad/apps/llm.py
+++ b/tinygrad/apps/llm.py
@@ -98,6 +98,14 @@ def pairwise_topk(x: Tensor, k: int) -> tuple[Tensor, Tensor]:
   sel = Tensor.zeros_like(x).scatter(-1, cmp.sum(axis=-1).cast('int32'), vals)[:,:,n-k:].cast('int32')
   return x.gather(-1, sel), sel
 
+@dataclass(frozen=True)
+class SSMConfig:
+  conv_kernel: int
+  state_size: int
+  group_count: int
+  time_step_rank: int
+  inner_size: int
+
 @dataclass(frozen=True)
 class TransformerConfig:
   num_blocks: int
@@ -118,6 +126,9 @@ class TransformerConfig:
   norm_topk_prob: bool = False
   kv_lora_rank: int = 0
   shared_expert_dim: int = 0
+  full_attention_interval: int = 0
+  attn_output_gate: bool = False
+  ssm: SSMConfig|None = None
   shared_expert_gate: bool = True
   leading_dense_blocks: int = 0
   dense_hidden_dim: int = 0
@@ -171,6 +182,10 @@ class FFNBlock:
     # TODO: remove the need for this contiguous
     return self.ffn_down(self.ffn_gate(x).silu().contiguous() * self.ffn_up(x))
 
+  # given the token-prefix match, return how much cached state this block can still reuse
+  def _reusable_prefix_len(self, prefix_len:int, cached_len:int) -> int: return prefix_len
+  # return writes that reset this block's state after a cache mismatch
+  def _state_reset_ops(self) -> list[Tensor]: return []
   def _init_state(self, x:Tensor): raise NotImplementedError
   def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor: raise NotImplementedError
 
@@ -189,12 +204,12 @@ class TransformerBlock(FFNBlock):
     assert config.v_head_dim == config.head_dim, "TransformerBlock requires v_head_dim == head_dim"
 
     # --- attention projections (all linear, bias-free) ------------------
-    q_proj_out = config.head_dim * config.n_heads
+    q_proj_out = config.head_dim * config.n_heads * (2 if config.attn_output_gate else 1)
     kv_proj_out = config.head_dim * config.n_kv_heads
     self.attn_q = nn.Linear(config.dim, q_proj_out, bias=False)
     self.attn_k = nn.Linear(config.dim, kv_proj_out, bias=False)
     self.attn_v = nn.Linear(config.dim, kv_proj_out, bias=False)
-    self.attn_output = nn.Linear(q_proj_out, config.dim, bias=False)
+    self.attn_output = nn.Linear(config.head_dim * config.n_heads, config.dim, bias=False)
     if config.qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(config.qk_norm, config.norm_eps), nn.RMSNorm(config.qk_norm, config.norm_eps)
 
   def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
@@ -202,6 +217,9 @@ class TransformerBlock(FFNBlock):
     q, k, v = self.attn_q(x), self.attn_k(x), self.attn_v(x)
     if self.config.qk_norm and self.config.qk_norm != self.config.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
     B, T, _ = x.shape
+    if self.config.attn_output_gate:
+      qg = q.reshape(B, T, self.config.n_heads, 2, self.config.head_dim)
+      q, gate = qg[:, :, :, 0, :], qg[:, :, :, 1, :].reshape(B, T, self.config.n_heads * self.config.head_dim)
     q = q.reshape(B, T, self.config.n_heads, self.config.head_dim).transpose(1, 2)    # (B,H,T,Hd)
     k = k.reshape(B, T, self.config.n_kv_heads, self.config.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
     v = v.reshape(B, T, self.config.n_kv_heads, self.config.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
@@ -224,7 +242,7 @@ class TransformerBlock(FFNBlock):
     mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if resolve(T != 1) else None
     attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True) # (B,H,T,Hd)
     attn = attn.transpose(1, 2).reshape(B, T, -1) # back to (B,T,D)
-    return self.attn_output(attn)
+    return self.attn_output(attn if not self.config.attn_output_gate else (attn * gate.sigmoid()))
 
   def _init_state(self, x:Tensor):
     if not hasattr(self, "cache_kv"):
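Note on the gating above: when attn_output_gate is set, attn_q is built twice as wide, and _attention splits it per head into a query half and a gate half; the gate sigmoid-scales the attention result right before attn_output, which is why the output projection keeps its original head_dim * n_heads input width instead of q_proj_out. A minimal numpy sketch of just the gating math, with made-up shapes (B, T, n_heads, head_dim here are placeholders, not values from any real config):

import numpy as np

B, T, n_heads, head_dim = 1, 4, 2, 8
qg = np.random.randn(B, T, n_heads * 2 * head_dim)         # doubled-width attn_q output
qg = qg.reshape(B, T, n_heads, 2, head_dim)
q = qg[:, :, :, 0, :]                                      # query half, goes into attention
gate = qg[:, :, :, 1, :].reshape(B, T, n_heads * head_dim)

attn = np.random.randn(B, T, n_heads * head_dim)           # stand-in for the attention result
gated = attn * (1.0 / (1.0 + np.exp(-gate)))               # attn * gate.sigmoid()
# `gated`, not `attn`, is what feeds the (head_dim*n_heads -> dim) output projection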
@@ -274,15 +292,71 @@ class MLATransformerBlock(FFNBlock):
       self.cache_v = Tensor.empty(x.shape[0], 1, self.config.max_context, self.config.kv_lora_rank, device=x.device)
 
+class GatedDeltaNetBlock(FFNBlock):
+  def __init__(self, config:TransformerConfig, ssm:SSMConfig):
+    super().__init__(config)
+    self.head_k_dim, self.num_k_heads, self.num_v_heads = ssm.state_size, ssm.group_count, ssm.time_step_rank
+    self.head_v_dim, self.ssm_conv_kernel = ssm.inner_size // ssm.time_step_rank, ssm.conv_kernel
+    self.conv_channels, self.q_dim = ssm.inner_size + 2*ssm.group_count*ssm.state_size, ssm.state_size*ssm.group_count
+    self.attn_qkv, self.attn_gate = nn.Linear(config.dim, self.conv_channels, bias=False), nn.Linear(config.dim, ssm.inner_size, bias=False)
+    self.ssm_alpha, self.ssm_beta = nn.Linear(config.dim, self.num_v_heads, bias=False), nn.Linear(config.dim, self.num_v_heads, bias=False)
+    self.ssm_conv1d = {"weight": Tensor.zeros(self.conv_channels, self.ssm_conv_kernel)}
+    self.ssm_dt = {"bias": Tensor.zeros(self.num_v_heads)}
+    self.ssm_a = Tensor.zeros(self.num_v_heads)
+    self.ssm_norm, self.ssm_out = nn.RMSNorm(self.head_v_dim, config.norm_eps), nn.Linear(ssm.inner_size, config.dim, bias=False)
+
+  def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
+    B, T, _ = x.shape
+    assert T == 1, "GatedDeltaNetBlock currently only supports T=1"
+    x = x.half()
+    out_gate = self.attn_gate(x).reshape(B, 1, self.num_v_heads, self.head_v_dim)
+    beta = self.ssm_beta(x).sigmoid().reshape(B, self.num_v_heads, 1, 1)
+    alpha = ((self.ssm_alpha(x).float() + self.ssm_dt["bias"]).softplus() * self.ssm_a).reshape(B, self.num_v_heads, 1, 1).exp()
+    conv_flat = (self.ssm_conv_kernel - 1) * self.conv_channels
+    ssm_flat = self.num_v_heads * self.head_v_dim * self.head_v_dim
+    conv_state = self.delta_cache[:, :conv_flat].reshape(B, self.ssm_conv_kernel - 1, self.conv_channels)
+    recurrent_state = self.delta_cache[:, conv_flat:conv_flat + ssm_flat].reshape(B, self.num_v_heads, self.head_v_dim, self.head_v_dim)
+    conv_window = conv_state.cat(self.attn_qkv(x), dim=1)
+    conv_out = (conv_window * self.ssm_conv1d["weight"].T.unsqueeze(0)).sum(1).silu()
+    q, k, v = conv_out.split([self.q_dim, self.q_dim, self.conv_channels - 2*self.q_dim], dim=-1)
+    q, k = q.reshape(B, self.num_k_heads, self.head_k_dim).normalize(dim=-1), k.reshape(B, self.num_k_heads, self.head_k_dim).normalize(dim=-1)
+    v = v.reshape(B, self.num_v_heads, self.head_v_dim)
+    if self.num_v_heads != self.num_k_heads:
+      k_repeat = self.num_v_heads // self.num_k_heads
+      q = q.unsqueeze(1).expand(B, k_repeat, self.num_k_heads, self.head_k_dim).reshape(B, self.num_v_heads, self.head_k_dim)
+      k = k.unsqueeze(1).expand(B, k_repeat, self.num_k_heads, self.head_k_dim).reshape(B, self.num_v_heads, self.head_k_dim)
+    q, k, v = (q * self.head_k_dim**-0.5).unsqueeze(-1), k.unsqueeze(-1), v.unsqueeze(-1)
+    recurrent_state = recurrent_state * alpha
+    recurrent_state = recurrent_state + ((v - recurrent_state@k) * beta)@k.transpose(-1, -2)
+    new_cache = conv_window[:, 1:, :].reshape(B, -1).cat(recurrent_state.reshape(B, -1), dim=-1).contiguous()
+    assigned = self.delta_cache.uop.after(self.delta_cache.uop.store(new_cache.cast(self.delta_cache.dtype).uop))
+    cache_tensor = Tensor(assigned, device=self.delta_cache.device)
+    final_state = cache_tensor[:, conv_flat:conv_flat + ssm_flat].reshape(B, self.num_v_heads, self.head_v_dim, self.head_v_dim)
+    core_attn_out = self.ssm_norm((final_state@q).squeeze(-1).reshape(B, 1, self.num_v_heads, self.head_v_dim))
+    return self.ssm_out((core_attn_out * out_gate.silu()).reshape(B, 1, -1).cast(x.dtype))
+
+  # recurrent state can't be partially reused after divergence, force a full rebuild
+  def _state_reset_ops(self): return [self.delta_cache.assign(Tensor.zeros_like(self.delta_cache))] if hasattr(self, "delta_cache") else []
+  def _reusable_prefix_len(self, prefix_len:int, cached_len:int) -> int: return 0 if prefix_len != cached_len else prefix_len
+
+  def _init_state(self, x):
+    if not hasattr(self, "delta_cache"):
+      conv_flat = (self.ssm_conv_kernel - 1) * self.conv_channels
+      ssm_flat = self.num_v_heads * self.head_v_dim * self.head_v_dim
+      self.delta_cache = Tensor.zeros(x.shape[0], conv_flat + ssm_flat, device=x.device).clone()
+
 class Transformer:
   def __init__(self, config:TransformerConfig):
     dense_config = replace(config, num_experts=0, num_experts_per_tok=0, shared_expert_dim=0, hidden_dim=config.dense_hidden_dim or config.hidden_dim)
+    if config.ssm: config = replace(config, qk_norm=config.head_dim)
     block_cls = MLATransformerBlock if config.kv_lora_rank > 0 else TransformerBlock
-    self.blk = [block_cls(dense_config if i < config.leading_dense_blocks else config) for i in range(config.num_blocks)]
+    self.blk:list[FFNBlock] = [GatedDeltaNetBlock(config, config.ssm) if config.ssm and (i+1) % config.full_attention_interval != 0 else
+                               block_cls(dense_config if i < config.leading_dense_blocks else config) for i in range(config.num_blocks)]
     self.token_embd = nn.Embedding(config.vocab_size, config.dim)
     self.output_norm = nn.RMSNorm(config.dim, config.norm_eps)
     self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
     self.max_context = config.max_context
+    self.has_recurrent_block = any(isinstance(b, GatedDeltaNetBlock) for b in self.blk)
     self._cached_tokens: list[int] = []
 
     # we specialize the JIT for prefill and rollout
     self.prefill_jit = TinyJit(self.forward)
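The heart of _attention above is a gated delta-rule recurrence on a per-head matrix state of shape (head_v_dim, head_v_dim): decay the state by alpha, nudge it toward the new (k, v) pair with write strength beta, then read it out against q. A single-head numpy sketch of that update, assuming head_k_dim == head_v_dim and with illustrative gate values (the real alpha/beta come from the ssm_alpha/ssm_beta projections per token):

import numpy as np

d = 4                                          # stands in for head_v_dim (== head_k_dim here)
S = np.zeros((d, d))                           # recurrent state, one head
for _ in range(10):                            # one token per step, matching the T == 1 assert
  k = np.random.randn(d); k /= np.linalg.norm(k)   # q/k are L2-normalized in the block
  q = np.random.randn(d); q /= np.linalg.norm(q)
  v = np.random.randn(d)
  alpha, beta = 0.9, 0.5                       # decay gate and write strength, made up here
  S = S * alpha                                # forget a little
  S = S + np.outer((v - S @ k) * beta, k)      # delta rule: move the prediction S @ k toward v
  out = S @ (q * d**-0.5)                      # readout; the real block then applies RMSNorm + gate

The conv_window/conv_state half of delta_cache is orthogonal to this recurrence: it just keeps the last conv_kernel-1 inputs so the depthwise causal convolution can run one token at a time.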
@@ -296,7 +370,7 @@ class Transformer:
     return (logits / temperature.maximum(1e-12) - (Tensor.rand_like(logits).maximum(1e-12).log().neg()).log()).argmax(-1, keepdim=True)
 
   def __call__(self, tokens:Tensor, start_pos:int|UOp, temperature:Tensor) -> Tensor:
-    return (self.prefill_jit if resolve(tokens.shape[1] != 1) else self.rollout_jit)(tokens, start_pos, temperature)
+    return (self.prefill_jit if resolve(tokens.shape[1] != 1) else self.rollout_jit)(tokens.contiguous(), start_pos, temperature)
 
   @staticmethod
   def from_gguf(gguf:Tensor, max_context:int|None=None, realize=bool(getenv("REALIZE", 0))) -> tuple[Transformer, dict]:
@@ -313,6 +387,11 @@ class Transformer:
     max_context = min(max_context, kv[f'{arch}.context_length']) if max_context is not None else kv[f'{arch}.context_length']
     n_heads, n_kv_heads = kv[f'{arch}.attention.head_count'], kv[f'{arch}.attention.head_count_kv']
 
+    ssm = None
+    if arch in ('qwen35', 'qwen35moe'):
+      ssm = SSMConfig(**{k: kv[f'{arch}.ssm.{k}'] for k in ('conv_kernel','state_size','group_count','time_step_rank','inner_size')})
+      state_dict = {k.replace('post_attention_norm', 'ffn_norm'):v for k,v in state_dict.items()}
+
     kv_lora_rank = kv.get(f'{arch}.attention.kv_lora_rank', 0)
     head_dim = kv.get(f'{arch}.attention.key_length_mla', kv.get(f'{arch}.attention.key_length', kv[f'{arch}.embedding_length'] // n_heads))
     rope_dim = kv.get(f'{arch}.rope.dimension_count', head_dim)
@@ -330,7 +409,7 @@ class Transformer:
       state_dict[name] = state_dict[name][:kv_lora_rank].cat(state_dict[name][kv_lora_rank:].rearrange("(h two) d -> (two h) d", two=2), dim=0)
     config = TransformerConfig(
       num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'],
-      hidden_dim=kv.get(f'{arch}.expert_feed_forward_length', kv[f'{arch}.feed_forward_length']),
+      hidden_dim=kv.get(f'{arch}.expert_feed_forward_length', kv.get(f'{arch}.feed_forward_length', 0)),
       n_heads=n_heads, n_kv_heads=n_kv_heads, norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'],
       vocab_size=len(kv['tokenizer.ggml.tokens']), head_dim=head_dim,
@@ -348,7 +427,8 @@ class Transformer:
                         kv.get(f'{arch}.expert_shared_count', 0) * kv.get(f'{arch}.expert_feed_forward_length', 0)),
       shared_expert_gate=f"blk.{kv.get(f'{arch}.leading_dense_block_count', 0)}.ffn_gate_inp_shexp.weight" in state_dict,
       dense_hidden_dim=kv.get(f'{arch}.feed_forward_length', 0) if kv.get(f'{arch}.leading_dense_block_count', 0) else 0,
-      routed_scaling_factor=kv.get(f'{arch}.expert_weights_scale', 1.0))
+      routed_scaling_factor=kv.get(f'{arch}.expert_weights_scale', 1.0), attn_output_gate=arch in ('qwen35', 'qwen35moe'), ssm=ssm,
+      full_attention_interval=kv.get(f'{arch}.full_attention_interval', 0))
     model = Transformer(config)
     nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False)  # NOTE: rope_freqs.weight (32,) is unused
     # NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
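How the blocks interleave: in Transformer.__init__ above, layer i becomes a GatedDeltaNetBlock unless (i+1) is a multiple of full_attention_interval, which keeps a full-attention TransformerBlock. A quick sketch of the resulting layout (the numbers are illustrative, not from a real qwen3.5 config):

num_blocks, full_attention_interval = 12, 4   # illustrative values
pattern = ["attn" if (i + 1) % full_attention_interval == 0 else "delta" for i in range(num_blocks)]
print(pattern)  # 3x 'delta' then 'attn', repeated: [..., 'delta', 'attn', 'delta', 'delta', 'delta', 'attn']

So every full_attention_interval-th layer is ordinary attention and everything in between is the recurrent delta-net block.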
@@ -357,18 +437,21 @@ class Transformer:
     Tensor.realize(*params)
     return model, kv
 
-  def get_start_pos(self, tokens:list[int]):
-    return sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens[:-1], self._cached_tokens)))
+  def get_start_pos(self, tokens:list[int]) -> int:
+    prefix_len = sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens[:-1], self._cached_tokens)))
+    return min(block._reusable_prefix_len(prefix_len, len(self._cached_tokens)) for block in self.blk)
 
   def generate(self, tokens:list[int], chunk_size:int=32, temperature:float=0.0):
+    if self.has_recurrent_block: chunk_size = 1
     v_start_pos = UOp.variable("start_pos", 0, self.max_context-1)
     v_toks = UOp.variable("toks", 1, chunk_size)
     # TODO: use UOp.variable for temperature once float variables are supported
     temp = Tensor(temperature).contiguous()
     # assign all input tokens once, then slice from start_pos for the model call
     t = Tensor(tokens + [0] * (self.max_context - len(tokens)), dtype="int32").reshape(1, self.max_context)
-    # recompute start_pos from what's currently valid in the kv cache
+    # recompute start_pos from what's currently valid in the caches
    start_pos = self.get_start_pos(tokens)
+    if start_pos < len(self._cached_tokens) and (resets := [r for b in self.blk for r in b._state_reset_ops()]): Tensor.realize(*resets)
     out, prompt_len = None, len(tokens)
     while len(tokens) < self.max_context:
       sp, nt = v_start_pos.bind(start_pos), v_toks.bind(min(chunk_size, len(tokens) - start_pos))
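The cache-reuse rules compose through a min: get_start_pos computes the longest matching token prefix, then lets every block cap it. KV-cache blocks keep the matched prefix (the default _reusable_prefix_len), but GatedDeltaNetBlock's rolled-up state has no per-position entries, so it is all-or-nothing: reuse everything when the cached tokens are an exact prefix, otherwise return 0 and let _state_reset_ops zero delta_cache before re-prefilling. A small sketch of the two policies as standalone functions (mirroring, not importing, the methods above):

def kv_reusable(prefix_len: int, cached_len: int) -> int:
  return prefix_len                                     # default: keep whatever prefix matched

def delta_reusable(prefix_len: int, cached_len: int) -> int:
  return prefix_len if prefix_len == cached_len else 0  # all-or-nothing

cached, new = [1, 2, 3, 4], [1, 2, 9, 9]
prefix = 2                                              # only [1, 2] still matches
print(kv_reusable(prefix, len(cached)))                 # 2 -> attention blocks replay from position 2
print(delta_reusable(prefix, len(cached)))              # 0 -> recurrent state reset, full re-prefill

Since one recurrent block returning 0 drags the min down for the whole model, a single diverging token costs a full prefill here, where a pure-attention model would only replay the changed suffix.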
@@ -390,6 +473,11 @@ models = {
   "qwen3:1.7b": "https://huggingface.co/unsloth/Qwen3-1.7B-GGUF/resolve/main/Qwen3-1.7B-Q4_K_M.gguf",
   "qwen3:8b": "https://huggingface.co/Qwen/Qwen3-8B-GGUF/resolve/main/Qwen3-8B-Q4_K_M.gguf",
   "qwen3:30b-a3b": "https://huggingface.co/Qwen/Qwen3-30B-A3B-GGUF/resolve/main/Qwen3-30B-A3B-Q4_K_M.gguf",
+  "qwen3.5:0.8b": "https://huggingface.co/unsloth/Qwen3.5-0.8B-GGUF/resolve/main/Qwen3.5-0.8B-Q8_0.gguf",
+  "qwen3.5:4b": "https://huggingface.co/unsloth/Qwen3.5-4B-GGUF/resolve/main/Qwen3.5-4B-Q4_K_M.gguf",
+  "qwen3.5:9b": "https://huggingface.co/unsloth/Qwen3.5-9B-GGUF/resolve/main/Qwen3.5-9B-Q4_K_M.gguf",
+  "qwen3.5:27b": "https://huggingface.co/unsloth/Qwen3.5-27B-GGUF/resolve/main/Qwen3.5-27B-Q4_K_M.gguf",
+  "qwen3.5:35b-a3b": "https://huggingface.co/unsloth/Qwen3.5-35B-A3B-GGUF/resolve/main/Qwen3.5-35B-A3B-Q4_K_M.gguf",
   "olmoe": "https://huggingface.co/allenai/OLMoE-1B-7B-0924-Instruct-GGUF/resolve/main/olmoe-1b-7b-0924-instruct-q4_k_m.gguf",
   "moonlight": "https://huggingface.co/gabriellarson/Moonlight-16B-A3B-Instruct-GGUF/resolve/main/Moonlight-16B-A3B-Instruct-Q4_K_M.gguf",
 }
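The new entries follow the existing scheme: the tag on the left selects a GGUF URL, so the qwen3.5 models run through the same CLI path the CI exercises above (echo "..." | python3 -m tinygrad.apps.llm --model qwen3.5:0.8b). The table is also plain importable Python; a sketch, assuming nothing beyond the module-level dict defined above:

from tinygrad.apps.llm import models

print(models["qwen3.5:0.8b"])  # the Q8_0 GGUF URL added above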