simplify GatedDeltaNetBlock using two state tensors (#15704)

* test double after

* simpler ssm

* no double test
This commit is contained in:
George Hotz
2026-04-16 21:14:00 +08:00
committed by GitHub
parent c04f3eaa70
commit f57380cbc2
2 changed files with 24 additions and 25 deletions

View File

@@ -70,11 +70,14 @@ class TestGatedDeltaNetBlock(unittest.TestCase):
return block._attention(x_norm, start_pos).realize().numpy()
def _cache_views(self, block:GatedDeltaNetBlock) -> tuple[np.ndarray, np.ndarray]:
conv_flat = (block.ssm_conv_kernel - 1) * block.conv_channels
cache = block.delta_cache.numpy()
conv_state = cache[:, :conv_flat].reshape(cache.shape[0], block.ssm_conv_kernel - 1, block.conv_channels)
recurrent_state = cache[:, conv_flat:].reshape(cache.shape[0], block.num_v_heads, block.head_v_dim, block.head_v_dim)
return conv_state, recurrent_state
if hasattr(block, 'conv_state'):
return block.conv_state.numpy(), block.recurrent_state.numpy()
else:
conv_flat = (block.ssm_conv_kernel - 1) * block.conv_channels
cache = block.delta_cache.numpy()
conv_state = cache[:, :conv_flat].reshape(cache.shape[0], block.ssm_conv_kernel - 1, block.conv_channels)
recurrent_state = cache[:, conv_flat:].reshape(cache.shape[0], block.num_v_heads, block.head_v_dim, block.head_v_dim)
return conv_state, recurrent_state
def _linear_np(self, x:np.ndarray, weight:np.ndarray) -> np.ndarray:
return x.astype(np.float32) @ weight.T.astype(np.float32)

View File

@@ -321,13 +321,9 @@ class GatedDeltaNetBlock(FFNBlock):
beta = self.ssm_beta(x).sigmoid().reshape(B, self.num_v_heads, 1, 1)
alpha = ((self.ssm_alpha(x).float() + self.ssm_dt["bias"]).softplus() * self.ssm_a).reshape(B, self.num_v_heads, 1, 1).exp()
# conv
conv_flat = (self.ssm_conv_kernel - 1) * self.conv_channels
conv_state = self.delta_cache[:, :conv_flat].reshape(B, self.ssm_conv_kernel - 1, self.conv_channels)
conv_window = conv_state.cat(self.attn_qkv(x), dim=1)
# qkv conv
conv_window = self.conv_state.cat(self.attn_qkv(x), dim=1)
conv_out = (conv_window * self.ssm_conv1d["weight"].T.unsqueeze(0)).sum(1).silu()
# qkv
q, k, v = conv_out.split([self.q_dim, self.q_dim, self.conv_channels - 2*self.q_dim], dim=-1)
q = q.reshape(B, self.num_k_heads, self.head_k_dim).normalize(dim=-1).repeat(1, self.num_v_heads//self.num_k_heads, 1)
k = k.reshape(B, self.num_k_heads, self.head_k_dim).normalize(dim=-1).repeat(1, self.num_v_heads//self.num_k_heads, 1)
@@ -335,28 +331,28 @@ class GatedDeltaNetBlock(FFNBlock):
q, k, v = q.mul(self.head_k_dim**-0.5).unsqueeze(-1), k.unsqueeze(-1), v.unsqueeze(-1)
# recurrent
ssm_flat = self.num_v_heads * self.head_v_dim * self.head_v_dim
recurrent_state = self.delta_cache[:, conv_flat:conv_flat + ssm_flat].reshape(B, self.num_v_heads, self.head_v_dim, self.head_v_dim)
recurrent_state = recurrent_state * alpha
recurrent_state = self.recurrent_state * alpha
recurrent_state = recurrent_state + ((v - recurrent_state@k) * beta)@k.transpose(-1, -2)
new_cache = conv_window[:, 1:, :].reshape(B, -1).cat(recurrent_state.reshape(B, -1), dim=-1).contiguous()
assigned = self.delta_cache.uop.after(self.delta_cache.uop.store(new_cache.cast(self.delta_cache.dtype).uop))
cache_tensor = Tensor(assigned, device=self.delta_cache.device)
# final
final_state = cache_tensor[:, conv_flat:conv_flat + ssm_flat].reshape(B, self.num_v_heads, self.head_v_dim, self.head_v_dim)
core_attn_out = self.ssm_norm((final_state@q).squeeze(-1).reshape(B, 1, self.num_v_heads, self.head_v_dim))
# store the updated state
conv_state_store = self.conv_state.uop.store(conv_window[:, 1:, :].cast(self.conv_state.dtype).uop)
recurrent_state_store = self.recurrent_state.uop.store(recurrent_state.cast(self.recurrent_state.dtype).uop)
recurrent_state = Tensor(self.recurrent_state.uop.after(recurrent_state_store, conv_state_store))
# output
core_attn_out = self.ssm_norm((recurrent_state@q).squeeze(-1).reshape(B, 1, self.num_v_heads, self.head_v_dim))
return self.ssm_out((core_attn_out * out_gate.silu()).reshape(B, 1, -1).cast(x.dtype))
# recurrent state can't be partially reused after divergence, force a full rebuild
def _state_reset_ops(self):
    """Return the assign ops that zero both SSM state tensors.

    No-op (empty list) before `_init_state` has ever run.  The old one-line
    definition that reset the packed `delta_cache` was leftover diff residue
    (that tensor no longer exists) and has been dropped.
    """
    if not hasattr(self, "conv_state"): return []
    return [self.conv_state.assign(Tensor.zeros_like(self.conv_state)),
            self.recurrent_state.assign(Tensor.zeros_like(self.recurrent_state))]
def _reusable_prefix_len(self, prefix_len:int, cached_len:int) -> int: return 0 if prefix_len != cached_len else prefix_len
def _init_state(self, x):
    """Lazily allocate the per-batch SSM state tensors on first use.

    Allocates on x's device: `conv_state` of shape
    (batch, ssm_conv_kernel-1, conv_channels) and `recurrent_state` of shape
    (batch, num_v_heads, head_v_dim, head_v_dim), both zero-initialized.
    The obsolete branch that allocated the packed flat `delta_cache` was
    leftover diff residue and has been removed.
    """
    if not hasattr(self, "conv_state"):
        self.conv_state = Tensor.zeros(x.shape[0], self.ssm_conv_kernel-1, self.conv_channels, device=x.device).clone()
        self.recurrent_state = Tensor.zeros(x.shape[0], self.num_v_heads, self.head_v_dim, self.head_v_dim, device=x.device).clone()
class Transformer:
def __init__(self, config:TransformerConfig):