simplify GatedDeltaNetBlock using two state tensors (#15704)
* test double after
* simpler ssm
* no double test
@@ -70,11 +70,14 @@ class TestGatedDeltaNetBlock(unittest.TestCase):
     return block._attention(x_norm, start_pos).realize().numpy()
 
   def _cache_views(self, block:GatedDeltaNetBlock) -> tuple[np.ndarray, np.ndarray]:
-    conv_flat = (block.ssm_conv_kernel - 1) * block.conv_channels
-    cache = block.delta_cache.numpy()
-    conv_state = cache[:, :conv_flat].reshape(cache.shape[0], block.ssm_conv_kernel - 1, block.conv_channels)
-    recurrent_state = cache[:, conv_flat:].reshape(cache.shape[0], block.num_v_heads, block.head_v_dim, block.head_v_dim)
-    return conv_state, recurrent_state
+    if hasattr(block, 'conv_state'):
+      return block.conv_state.numpy(), block.recurrent_state.numpy()
+    else:
+      conv_flat = (block.ssm_conv_kernel - 1) * block.conv_channels
+      cache = block.delta_cache.numpy()
+      conv_state = cache[:, :conv_flat].reshape(cache.shape[0], block.ssm_conv_kernel - 1, block.conv_channels)
+      recurrent_state = cache[:, conv_flat:].reshape(cache.shape[0], block.num_v_heads, block.head_v_dim, block.head_v_dim)
+      return conv_state, recurrent_state
 
   def _linear_np(self, x:np.ndarray, weight:np.ndarray) -> np.ndarray:
     return x.astype(np.float32) @ weight.T.astype(np.float32)
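The fallback branch above unpacks the old single flat delta_cache into exactly the two views the new code keeps as separate tensors. A minimal NumPy sketch of that layout; all sizes here are illustrative assumptions, not the block's real configuration:

import numpy as np

# illustrative sizes only (assumptions): batch, conv kernel, conv channels, value heads, head_v_dim
B, kernel, channels, heads, head_v_dim = 1, 4, 8, 2, 3
conv_flat = (kernel - 1) * channels               # flattened conv-window slots
ssm_flat = heads * head_v_dim * head_v_dim        # flattened recurrent-state slots

# old layout: one flat cache row per batch element, conv window first, recurrent state second
delta_cache = np.arange(B * (conv_flat + ssm_flat), dtype=np.float32).reshape(B, -1)
conv_state = delta_cache[:, :conv_flat].reshape(B, kernel - 1, channels)
recurrent_state = delta_cache[:, conv_flat:].reshape(B, heads, head_v_dim, head_v_dim)

# new layout: the same two arrays kept as separate tensors, with no packing or slicing
assert conv_state.size + recurrent_state.size == delta_cache.size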
@@ -321,13 +321,9 @@ class GatedDeltaNetBlock(FFNBlock):
     beta = self.ssm_beta(x).sigmoid().reshape(B, self.num_v_heads, 1, 1)
     alpha = ((self.ssm_alpha(x).float() + self.ssm_dt["bias"]).softplus() * self.ssm_a).reshape(B, self.num_v_heads, 1, 1).exp()
 
-    # conv
-    conv_flat = (self.ssm_conv_kernel - 1) * self.conv_channels
-    conv_state = self.delta_cache[:, :conv_flat].reshape(B, self.ssm_conv_kernel - 1, self.conv_channels)
-    conv_window = conv_state.cat(self.attn_qkv(x), dim=1)
+    # qkv conv
+    conv_window = self.conv_state.cat(self.attn_qkv(x), dim=1)
     conv_out = (conv_window * self.ssm_conv1d["weight"].T.unsqueeze(0)).sum(1).silu()
 
     # qkv
     q, k, v = conv_out.split([self.q_dim, self.q_dim, self.conv_channels - 2*self.q_dim], dim=-1)
     q = q.reshape(B, self.num_k_heads, self.head_k_dim).normalize(dim=-1).repeat(1, self.num_v_heads//self.num_k_heads, 1)
     k = k.reshape(B, self.num_k_heads, self.head_k_dim).normalize(dim=-1).repeat(1, self.num_v_heads//self.num_k_heads, 1)
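The qkv conv in this hunk concatenates the cached last kernel-1 qkv rows with the current token's projection, applies the depthwise conv weights across the window, and keeps the trailing rows as the next conv state. A rough NumPy sketch under assumed shapes; the names and sizes are placeholders, not tinygrad's:

import numpy as np

def silu(z): return z / (1.0 + np.exp(-z))

# illustrative shapes (assumptions): batch, conv kernel, conv channels
B, kernel, channels = 1, 4, 8
conv_state = np.zeros((B, kernel - 1, channels), dtype=np.float32)   # last kernel-1 qkv rows
conv_weight = np.random.randn(channels, kernel).astype(np.float32)   # one filter per channel (depthwise)
qkv_t = np.random.randn(B, 1, channels).astype(np.float32)           # qkv projection of the current token

conv_window = np.concatenate([conv_state, qkv_t], axis=1)            # (B, kernel, channels)
conv_out = silu((conv_window * conv_weight.T[None]).sum(axis=1))     # depthwise conv over the window, then silu
next_conv_state = conv_window[:, 1:, :]                              # shift: drop the oldest row, keep the newest kernel-1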
@@ -335,28 +331,28 @@ class GatedDeltaNetBlock(FFNBlock):
     q, k, v = q.mul(self.head_k_dim**-0.5).unsqueeze(-1), k.unsqueeze(-1), v.unsqueeze(-1)
 
     # recurrent
-    ssm_flat = self.num_v_heads * self.head_v_dim * self.head_v_dim
-    recurrent_state = self.delta_cache[:, conv_flat:conv_flat + ssm_flat].reshape(B, self.num_v_heads, self.head_v_dim, self.head_v_dim)
-    recurrent_state = recurrent_state * alpha
+    recurrent_state = self.recurrent_state * alpha
     recurrent_state = recurrent_state + ((v - recurrent_state@k) * beta)@k.transpose(-1, -2)
-    new_cache = conv_window[:, 1:, :].reshape(B, -1).cat(recurrent_state.reshape(B, -1), dim=-1).contiguous()
-    assigned = self.delta_cache.uop.after(self.delta_cache.uop.store(new_cache.cast(self.delta_cache.dtype).uop))
-    cache_tensor = Tensor(assigned, device=self.delta_cache.device)
 
-    # final
-    final_state = cache_tensor[:, conv_flat:conv_flat + ssm_flat].reshape(B, self.num_v_heads, self.head_v_dim, self.head_v_dim)
-    core_attn_out = self.ssm_norm((final_state@q).squeeze(-1).reshape(B, 1, self.num_v_heads, self.head_v_dim))
+    # store the updated state
+    conv_state_store = self.conv_state.uop.store(conv_window[:, 1:, :].cast(self.conv_state.dtype).uop)
+    recurrent_state_store = self.recurrent_state.uop.store(recurrent_state.cast(self.recurrent_state.dtype).uop)
+    recurrent_state = Tensor(self.recurrent_state.uop.after(recurrent_state_store, conv_state_store))
+
+    # output
+    core_attn_out = self.ssm_norm((recurrent_state@q).squeeze(-1).reshape(B, 1, self.num_v_heads, self.head_v_dim))
     return self.ssm_out((core_attn_out * out_gate.silu()).reshape(B, 1, -1).cast(x.dtype))
 
   # recurrent state can't be partially reused after divergence, force a full rebuild
-  def _state_reset_ops(self): return [self.delta_cache.assign(Tensor.zeros_like(self.delta_cache))] if hasattr(self, "delta_cache") else []
+  def _state_reset_ops(self):
+    return [self.conv_state.assign(Tensor.zeros_like(self.conv_state)),
+            self.recurrent_state.assign(Tensor.zeros_like(self.recurrent_state))] if hasattr(self, "conv_state") else []
   def _reusable_prefix_len(self, prefix_len:int, cached_len:int) -> int: return 0 if prefix_len != cached_len else prefix_len
 
   def _init_state(self, x):
-    if not hasattr(self, "delta_cache"):
-      conv_flat = (self.ssm_conv_kernel - 1) * self.conv_channels
-      ssm_flat = self.num_v_heads * self.head_v_dim * self.head_v_dim
-      self.delta_cache = Tensor.zeros(x.shape[0], conv_flat + ssm_flat, device=x.device).clone()
+    if not hasattr(self, "conv_state"):
+      self.conv_state = Tensor.zeros(x.shape[0], self.ssm_conv_kernel-1, self.conv_channels, device=x.device).clone()
+      self.recurrent_state = Tensor.zeros(x.shape[0], self.num_v_heads, self.head_v_dim, self.head_v_dim, device=x.device).clone()
 
 class Transformer:
   def __init__(self, config:TransformerConfig):
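The recurrent section of this hunk is the gated delta rule on a per-head fast-weight matrix: decay the state by alpha, write the correction (v - S k) scaled by beta along k, then read out with q. A minimal single-head NumPy sketch of that update; the dimension and the gate values are placeholders, not values the block would produce:

import numpy as np

d = 4                                               # head_v_dim (illustrative)
S = np.random.randn(d, d).astype(np.float32)        # recurrent state for one head
q = np.random.randn(d, 1).astype(np.float32)        # query, already scaled by head_k_dim**-0.5 in the block
k = np.random.randn(d, 1).astype(np.float32)        # key, L2-normalized in the block
v = np.random.randn(d, 1).astype(np.float32)
alpha, beta = 0.9, 0.5                              # per-head decay gate in (0, 1) and write gate from a sigmoid

S = alpha * S                                       # decay the old state
S = S + (beta * (v - S @ k)) @ k.T                  # delta-rule write: move the prediction S@k toward v
out = S @ q                                         # read out for the current query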
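With the flat cache gone, the per-block state is just the two zero-initialized tensors created in _init_state, and a reset simply re-zeroes both. A small standalone tinygrad sketch of that pattern, with illustrative sizes rather than the model's real dimensions:

from tinygrad import Tensor

# illustrative sizes (assumptions): batch, conv kernel, conv channels, value heads, head_v_dim
B, kernel, channels, heads, d = 1, 4, 8, 2, 3
conv_state = Tensor.zeros(B, kernel - 1, channels).clone().realize()       # last kernel-1 qkv rows
recurrent_state = Tensor.zeros(B, heads, d, d).clone().realize()           # one d x d fast-weight matrix per value head
# .clone() mirrors the diff: it gives the lazy zeros its own writable buffer

# a full reset (what the new _state_reset_ops returns) just assigns zeros to both tensors
reset_ops = [conv_state.assign(Tensor.zeros_like(conv_state)),
             recurrent_state.assign(Tensor.zeros_like(recurrent_state))]
for op in reset_ops: op.realize()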