diff --git a/test/unit/test_attention.py b/test/unit/test_attention.py
index 2ef112b741..7eab6c80c8 100644
--- a/test/unit/test_attention.py
+++ b/test/unit/test_attention.py
@@ -70,11 +70,14 @@ class TestGatedDeltaNetBlock(unittest.TestCase):
     return block._attention(x_norm, start_pos).realize().numpy()
 
   def _cache_views(self, block:GatedDeltaNetBlock) -> tuple[np.ndarray, np.ndarray]:
-    conv_flat = (block.ssm_conv_kernel - 1) * block.conv_channels
-    cache = block.delta_cache.numpy()
-    conv_state = cache[:, :conv_flat].reshape(cache.shape[0], block.ssm_conv_kernel - 1, block.conv_channels)
-    recurrent_state = cache[:, conv_flat:].reshape(cache.shape[0], block.num_v_heads, block.head_v_dim, block.head_v_dim)
-    return conv_state, recurrent_state
+    if hasattr(block, 'conv_state'):
+      return block.conv_state.numpy(), block.recurrent_state.numpy()
+    else:
+      conv_flat = (block.ssm_conv_kernel - 1) * block.conv_channels
+      cache = block.delta_cache.numpy()
+      conv_state = cache[:, :conv_flat].reshape(cache.shape[0], block.ssm_conv_kernel - 1, block.conv_channels)
+      recurrent_state = cache[:, conv_flat:].reshape(cache.shape[0], block.num_v_heads, block.head_v_dim, block.head_v_dim)
+      return conv_state, recurrent_state
 
   def _linear_np(self, x:np.ndarray, weight:np.ndarray) -> np.ndarray:
     return x.astype(np.float32) @ weight.T.astype(np.float32)
diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py
index 8f7f066f82..4f5ae6b570 100644
--- a/tinygrad/apps/llm.py
+++ b/tinygrad/apps/llm.py
@@ -321,13 +321,9 @@ class GatedDeltaNetBlock(FFNBlock):
     beta = self.ssm_beta(x).sigmoid().reshape(B, self.num_v_heads, 1, 1)
     alpha = ((self.ssm_alpha(x).float() + self.ssm_dt["bias"]).softplus() * self.ssm_a).reshape(B, self.num_v_heads, 1, 1).exp()
 
-    # conv
-    conv_flat = (self.ssm_conv_kernel - 1) * self.conv_channels
-    conv_state = self.delta_cache[:, :conv_flat].reshape(B, self.ssm_conv_kernel - 1, self.conv_channels)
-    conv_window = conv_state.cat(self.attn_qkv(x), dim=1)
+    # qkv conv
+    conv_window = self.conv_state.cat(self.attn_qkv(x), dim=1)
     conv_out = (conv_window * self.ssm_conv1d["weight"].T.unsqueeze(0)).sum(1).silu()
-
-    # qkv
     q, k, v = conv_out.split([self.q_dim, self.q_dim, self.conv_channels - 2*self.q_dim], dim=-1)
     q = q.reshape(B, self.num_k_heads, self.head_k_dim).normalize(dim=-1).repeat(1, self.num_v_heads//self.num_k_heads, 1)
     k = k.reshape(B, self.num_k_heads, self.head_k_dim).normalize(dim=-1).repeat(1, self.num_v_heads//self.num_k_heads, 1)
@@ -335,28 +331,28 @@ class GatedDeltaNetBlock(FFNBlock):
     q, k, v = q.mul(self.head_k_dim**-0.5).unsqueeze(-1), k.unsqueeze(-1), v.unsqueeze(-1)
 
     # recurrent
-    ssm_flat = self.num_v_heads * self.head_v_dim * self.head_v_dim
-    recurrent_state = self.delta_cache[:, conv_flat:conv_flat + ssm_flat].reshape(B, self.num_v_heads, self.head_v_dim, self.head_v_dim)
-    recurrent_state = recurrent_state * alpha
+    recurrent_state = self.recurrent_state * alpha
     recurrent_state = recurrent_state + ((v - recurrent_state@k) * beta)@k.transpose(-1, -2)
-    new_cache = conv_window[:, 1:, :].reshape(B, -1).cat(recurrent_state.reshape(B, -1), dim=-1).contiguous()
-    assigned = self.delta_cache.uop.after(self.delta_cache.uop.store(new_cache.cast(self.delta_cache.dtype).uop))
-    cache_tensor = Tensor(assigned, device=self.delta_cache.device)
-
-    # final
-    final_state = cache_tensor[:, conv_flat:conv_flat + ssm_flat].reshape(B, self.num_v_heads, self.head_v_dim, self.head_v_dim)
-    core_attn_out = self.ssm_norm((final_state@q).squeeze(-1).reshape(B, 1, self.num_v_heads, self.head_v_dim))
+
+    # store the updated state
+    conv_state_store = self.conv_state.uop.store(conv_window[:, 1:, :].cast(self.conv_state.dtype).uop)
+    recurrent_state_store = self.recurrent_state.uop.store(recurrent_state.cast(self.recurrent_state.dtype).uop)
+    recurrent_state = Tensor(self.recurrent_state.uop.after(recurrent_state_store, conv_state_store))
+
+    # output
+    core_attn_out = self.ssm_norm((recurrent_state@q).squeeze(-1).reshape(B, 1, self.num_v_heads, self.head_v_dim))
     return self.ssm_out((core_attn_out * out_gate.silu()).reshape(B, 1, -1).cast(x.dtype))
 
   # recurrent state can't be partially reused after divergence, force a full rebuild
-  def _state_reset_ops(self): return [self.delta_cache.assign(Tensor.zeros_like(self.delta_cache))] if hasattr(self, "delta_cache") else []
+  def _state_reset_ops(self):
+    return [self.conv_state.assign(Tensor.zeros_like(self.conv_state)),
+            self.recurrent_state.assign(Tensor.zeros_like(self.recurrent_state))] if hasattr(self, "conv_state") else []
   def _reusable_prefix_len(self, prefix_len:int, cached_len:int) -> int: return 0 if prefix_len != cached_len else prefix_len
 
   def _init_state(self, x):
-    if not hasattr(self, "delta_cache"):
-      conv_flat = (self.ssm_conv_kernel - 1) * self.conv_channels
-      ssm_flat = self.num_v_heads * self.head_v_dim * self.head_v_dim
-      self.delta_cache = Tensor.zeros(x.shape[0], conv_flat + ssm_flat, device=x.device).clone()
+    if not hasattr(self, "conv_state"):
+      self.conv_state = Tensor.zeros(x.shape[0], self.ssm_conv_kernel-1, self.conv_channels, device=x.device).clone()
+      self.recurrent_state = Tensor.zeros(x.shape[0], self.num_v_heads, self.head_v_dim, self.head_v_dim, device=x.device).clone()
 
 class Transformer:
   def __init__(self, config:TransformerConfig):
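
For reference, a minimal NumPy sketch of the single-token recurrence the patched _attention keeps in self.conv_state / self.recurrent_state: the conv window rolls forward by one token, and the recurrent state follows the gated delta rule S = alpha*S + beta*(v - S@k)@k^T with read-out S@q. The function name, toy shapes, and usage below are illustrative assumptions, not tinygrad API.

# illustrative sketch only (assumed shapes, batch dim omitted) -- not tinygrad code
import numpy as np

def gated_delta_step(conv_state, S, qkv, q, k, v, alpha, beta):
  # conv_state: (kernel-1, channels)  rolling window of past qkv projections
  # S:          (heads, d, d)         per-head recurrent state matrix
  # qkv:        (1, channels)         current token's qkv projection
  # q, k, v:    (heads, d, 1)         per-head query/key/value columns
  # alpha/beta: (heads, 1, 1)         decay gate and write gate
  conv_state = np.concatenate([conv_state, qkv], axis=0)[1:]   # roll the conv window by one token
  S = S * alpha                                                # decay the old state
  S = S + ((v - S @ k) * beta) @ k.transpose(0, 2, 1)          # delta-rule write of the new association
  out = (S @ q).squeeze(-1)                                    # per-head read-out, shape (heads, d)
  return conv_state, S, out

# toy usage: kernel=4, channels=8, heads=2, d=3
rng = np.random.default_rng(0)
conv_state, S = np.zeros((3, 8)), np.zeros((2, 3, 3))
q, k, v = (rng.standard_normal((2, 3, 1)) for _ in range(3))
alpha, beta = np.full((2, 1, 1), 0.9), np.full((2, 1, 1), 0.5)
conv_state, S, out = gated_delta_step(conv_state, S, rng.standard_normal((1, 8)), q, k, v, alpha, beta)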