tinygrad/examples/mamba.py
chenyu aa76d566c2 cleanup mamba (#4004)
make it read nicer and clean up some movement methods and math simplification.
the 790m, 1.4b, and 2.8b models do not really run.
sampling is not implemented.
jit is incorrect.
some dead code / wrong code paths and code copied from torch remain.
2024-03-30 02:50:13 -04:00

349 lines
13 KiB
Python

import os, sys, math, argparse, time
sys.path.append(os.getcwd())
from typing import Any, Optional, Dict
from dataclasses import dataclass, field
from tinygrad import Tensor, TinyJit, nn
from tinygrad.helpers import fetch
from tinygrad.nn.state import load_state_dict, torch_load
from extra.models.llama import RMSNorm
from tqdm import tqdm
from transformers import AutoTokenizer
MODELS = {
"130m": {"dim": 768, "n_layers": 24, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
"370m": {"dim": 1024, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
"790m": {"dim": 1536, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
"1.4b": {"dim": 2048, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
"2.8b": {"dim": 2560, "n_layers": 64, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
}
def fetch_weights(model_name: str) -> Dict[str, Tensor]:
  if model_name not in MODELS:
    raise ValueError(f"Requested unknown mamba model: {model_name}")
  downloaded = fetch(f"https://huggingface.co/state-spaces/mamba-{model_name}/resolve/main/pytorch_model.bin?download=true")
  return torch_load(downloaded)
def selective_scan_ref(
    u,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    delta_softplus=False,
    return_last_state=False,
):
  """
  u: r(B D L)
  delta: r(B D L)
  A: c(D N) or r(D N)
  B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
  C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
  D: r(D)
  z: r(B D L)
  delta_bias: r(D), fp32
  out: r(B D L)
  last_state (optional): r(B D dstate) or c(B D dstate)
  """
  u = u.float()
  delta = delta.float()
  if delta_bias is not None:
    delta = delta + delta_bias[..., None].float()
  if delta_softplus:
    delta = delta.softplus()
  batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
  is_variable_B = len(B.shape) >= 3
  is_variable_C = len(C.shape) >= 3
  x = Tensor.zeros(batch, dim, dstate)
  ys = []
  deltaA = Tensor.einsum("bdl,dn->bdln", delta, A).exp()
  if not is_variable_B:
    deltaB_u = Tensor.einsum("bdl,dn,bdl->bdln", delta, B, u)
  else:
    if len(B.shape) == 3:
      deltaB_u = Tensor.einsum("bdl,bnl,bdl->bdln", delta, B, u)
    else:
      B = B.repeat((1, dim // B.shape[1], 1, 1))
      deltaB_u = Tensor.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
  if is_variable_C and len(C.shape) == 4:
    C = C.repeat((1, dim // C.shape[1], 1, 1))
  last_state = None
  for i in range(u.shape[2]):
    x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
    if not is_variable_C:
      y = Tensor.einsum("bdn,dn->bd", x, C)
    else:
      if len(C.shape) == 3:
        y = Tensor.einsum("bdn,bn->bd", x, C[:, :, i])
      else:
        y = Tensor.einsum("bdn,bdn->bd", x, C[:, :, :, i])
    if i == u.shape[2] - 1:
      last_state = x
    ys.append(y)
  y = Tensor.stack(ys, dim=2)  # (batch dim L)
  out = y if D is None else y + u * D.reshape((-1, 1))
  if z is not None:
    out = out * z.silu()
  return out if not return_last_state else (out, last_state)
class MambaMixer:
  def __init__(
      self,
      dim,
      d_state=16,
      d_conv=4,
      expand=2,
      dt_rank="auto",
      dt_min=0.001,
      dt_max=0.1,
      dt_init="random",
      dt_scale=1.0,
      dt_init_floor=1e-4,
      conv_bias=True,
      bias=False,
      layer_idx=None,
  ):
    self.dim = dim
    self.d_state = d_state
    self.d_conv = d_conv
    self.expand = expand
    self.d_inner = self.expand * self.dim
    self.dt_rank = math.ceil(self.dim / 16) if dt_rank == "auto" else dt_rank
    self.layer_idx = layer_idx
    self.in_proj = nn.Linear(self.dim, self.d_inner * 2, bias=bias)
    self.conv1d = nn.Conv1d(in_channels=self.d_inner, out_channels=self.d_inner, bias=conv_bias,
                            kernel_size=d_conv, groups=self.d_inner, padding=d_conv-1)
    self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
    self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
    # Initialize special dt projection to preserve variance at initialization
    dt_init_std = self.dt_rank**-0.5 * dt_scale
    if dt_init == "constant":
      self.dt_proj.weight = Tensor.full(self.dt_proj.weight.shape, dt_init_std)
    elif dt_init == "random":
      self.dt_proj.weight = Tensor.uniform(self.dt_proj.weight.shape, low=-dt_init_std, high=dt_init_std)
    else:
      raise NotImplementedError
    dt = Tensor.uniform(self.d_inner, low=math.log(dt_min), high=math.log(dt_max)).exp().maximum(dt_init_floor)
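    # inverse of softplus, so that softplus(dt_proj.bias) recovers dt in [dt_min, dt_max]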
    inv_dt = dt + (1 - (-dt).exp()).log()
    self.dt_proj.bias.assign(inv_dt)
    # S4D real initialization
    self.A_log = Tensor.arange(1, self.d_state+1).repeat([self.d_inner, 1]).log()
    # D "skip" parameter
    self.D = Tensor.ones(self.d_inner)  # Keep in fp32
    self.out_proj = nn.Linear(self.d_inner, self.dim, bias=bias)
  def __call__(self, hidden_states: Tensor, inference_params=None):
    batch, seqlen, dim = hidden_states.shape
    conv_state, ssm_state = None, None
    if inference_params is not None:
      conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
      if inference_params.seqlen_offset > 0:
        # The states are updated inplace
        out, _, _ = self.step(hidden_states[:, -1:, :], conv_state, ssm_state)
        return out
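    # apply in_proj to all batch*seqlen tokens as one matmul: (2*d_inner, dim) @ (dim, seqlen*batch),
    # then reshape back to (batch, 2*d_inner, seqlen); x and z are the two halves along the channel dim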
    xz = self.in_proj.weight @ hidden_states.permute(2,0,1).reshape(hidden_states.shape[2],hidden_states.shape[1]*hidden_states.shape[0])
    xz = xz.reshape(xz.shape[0],xz.shape[1]//seqlen, seqlen).permute(1,0,2)
    if self.in_proj.bias is not None:
      xz = xz + self.in_proj.bias.reshape((-1, 1))
    A = -self.A_log.exp()
    x, z = xz.chunk(2, dim=1)
    # Compute short convolution
    if conv_state is not None:
      conv_state.assign(x[:, :, -self.d_conv :])  # Update state (B D W)
    x = self.conv1d(x)[..., :seqlen].swish()
    x_dbl = self.x_proj(x.permute(0,2,1).reshape(x.shape[0]*x.shape[2], x.shape[1]))
    dt, B, C = Tensor.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
    dt = self.dt_proj.weight @ dt.T
    dt = dt.reshape(dt.shape[0], dt.shape[1]//seqlen, seqlen).permute(1,0,2)
    B = B.reshape(B.shape[0]//seqlen, seqlen, B.shape[1]).permute(0,2,1)
    C = C.reshape(C.shape[0]//seqlen, seqlen, C.shape[1]).permute(0,2,1)
    # TODO: actually implement selective_scan_fn
    y = selective_scan_ref(x, dt, A, B, C, self.D, z=z, delta_bias=self.dt_proj.bias, delta_softplus=True,
                           return_last_state=ssm_state is not None)
    if ssm_state is not None:
      y, last_state = y
      ssm_state.assign(last_state)
    y = y.permute(0,2,1)
    out = self.out_proj(y)
    return out
  def step(self, hidden_states: Tensor, conv_state: Tensor, ssm_state: Tensor):
    assert hidden_states.shape[1] == 1, f"Only support decoding with 1 token at a time for now, attempted {hidden_states.shape[1]}"
    xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
    x, z = xz.chunk(2, dim=-1)  # (B D)
    # Conv step
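    # roll the (B, D, d_conv) window left by one, append the new token, and evaluate the
    # depthwise conv as an elementwise multiply + sum over the window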
    conv_state.assign(conv_state[:, :, 1:].cat(x.unsqueeze(-1), dim=-1))
    x = (conv_state * self.conv1d.weight.squeeze(1)).sum(-1)
    if self.conv1d.bias is not None:
      x = x + self.conv1d.bias
    x = x.swish()
    x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
    dt = x_db[:, : self.dt_rank]
    B = x_db[:, self.dt_rank : (self.dt_rank + self.d_state)]
    C = x_db[:, (self.dt_rank + self.d_state) :]
    # Don't add dt_bias here
    dt = self.dt_proj.weight @ dt.T
    A = -self.A_log.exp()
    # SSM step
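    # single-token version of selective_scan_ref: state = exp(dt*A)*state + dt*B*x, then y = C.state + D*x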
    dt = (dt + self.dt_proj.bias.unsqueeze(-1)).softplus()
    dA = Tensor.einsum("db,dn->bdn", dt, A).exp()
    dB = Tensor.einsum("db,bn->bdn", dt, B)
    ssm_state.assign(ssm_state * dA + x.unsqueeze(-1) * dB)
    y = Tensor.einsum("bdn,bn->bd", ssm_state, C)
    y = y + self.D * x
    y = y * z.swish()  # (B D)
    out = self.out_proj(y)
    return out.unsqueeze(1), conv_state, ssm_state
  def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
    assert self.layer_idx is not None
    if self.layer_idx not in inference_params.key_value_memory_dict:
      conv_state = Tensor.zeros(batch_size, self.dim * self.expand, self.d_conv).contiguous().realize()
      ssm_state = Tensor.zeros(batch_size, self.dim * self.expand, self.d_state).realize()
      inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
    else:
      conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
    return conv_state, ssm_state
class MambaBlock:
  def __init__(self, dim: int, norm_eps: float = 1e-5, rms_norm: bool = True, layer_idx: Optional[int] = None):
    self.mixer = MambaMixer(dim, layer_idx=layer_idx)
    if rms_norm:
      self.norm = RMSNorm(dim, norm_eps)
    else:
      raise NotImplementedError
  def __call__(self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None):
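    # pre-norm residual: fold the previous block's output into the running residual, normalize, then mix;
    # the final residual add happens in MambaBackbone after the last block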
    residual = (hidden_states + residual) if residual is not None else hidden_states
    hidden_states = self.norm(residual)
    hidden_states = self.mixer(hidden_states, inference_params=inference_params)
    return hidden_states, residual
class MambaBackbone:
  def __init__(self, dim: int, n_layers: int, vocab_size: int, rms_norm: bool = True, norm_eps: float = 1e-5):
    self.embedding = nn.Embedding(vocab_size, dim)
    self.layers = [MambaBlock(dim, rms_norm=rms_norm, layer_idx=i) for i in range(n_layers)]
    if rms_norm:
      self.norm_f = RMSNorm(dim, norm_eps)
  def __call__(self, input_ids: Tensor, inference_params=None) -> Any:
    hidden_states = self.embedding(input_ids)
    residual = None
    for layer in self.layers:
      hidden_states, residual = layer(hidden_states, residual, inference_params=inference_params)
    residual = (hidden_states + residual) if residual is not None else hidden_states
    hidden_states = self.norm_f(residual)
    return hidden_states
class Mamba:
  def __init__(self, dim: int, n_layers: int, vocab_size: int, pad_vocab_size_multiple: int = 1):
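    # round vocab_size up to the next multiple of pad_vocab_size_multiple (50277 -> 50280 for these models)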
    if vocab_size % pad_vocab_size_multiple != 0:
      vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
    self.backbone = MambaBackbone(dim, n_layers, vocab_size)
    self.lm_head = nn.Linear(dim, vocab_size, bias=False)
    self.forward_jit = TinyJit(self.forward)
  def forward(self, input_ids, inference_params, num_last_tokens):
    hidden_states = self.backbone(input_ids, inference_params=inference_params)
    if num_last_tokens > 0:
      hidden_states = hidden_states[:, -num_last_tokens:]
    return self.lm_head(hidden_states).realize()
  def __call__(self, input_ids, inference_params=None, num_last_tokens=0, jit=True):
    if inference_params is None:
      return self.forward(input_ids, inference_params, num_last_tokens)
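    # the jitted path is only taken while decoding (seqlen_offset > 0), where every call sees exactly one
    # new token and shapes stay fixed; the commit notes the jit path is still incorrect, and generate() passes jit=False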
    if jit and inference_params.seqlen_offset > 0:
      return self.forward_jit(input_ids, inference_params, num_last_tokens)
    else:
      return self.forward(input_ids, inference_params, num_last_tokens)
  @staticmethod
  def from_pretrained(model_name: str):
    weights = fetch_weights(model_name)
    model = Mamba(**MODELS[model_name])
    load_state_dict(model, weights)
    return model
@dataclass
class InferenceParams:
  """Inference parameters that are passed to the main model in order
  to efficiently calculate and store the context during inference."""
  max_seqlen: int
  max_batch_size: int
  seqlen_offset: int = 0
  batch_size_offset: int = 0
  key_value_memory_dict: dict = field(default_factory=dict)
  lengths_per_sample: Optional[Tensor] = None
  def reset(self, max_seqlen, max_batch_size):
    self.max_seqlen = max_seqlen
    self.max_batch_size = max_batch_size
    self.seqlen_offset = 0
    if self.lengths_per_sample is not None:
      self.lengths_per_sample.zero_()
def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 10, sample: bool = False, top_k: Optional[int] = None):
  tks = tokenizer(prompt)["input_ids"]
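  # left-pad the prompt to at least 4 tokens with token id 50279, presumably so the prefill
  # covers the depthwise conv window (d_conv=4)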
  while len(tks) < 4:
    tks = [50279] + tks
  # TODO: sampling
  temperature = 0.5
  start_pos = 0
  inference_params = InferenceParams(max_seqlen=1, max_batch_size=1, seqlen_offset=0)
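  # the first iteration runs the full prompt (prefill); afterwards start_pos == len(tks), so each call feeds
  # only the newly generated token while the cached conv/ssm states carry the context. decoding is greedy (argmax).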
  for _ in tqdm(range(n_tokens_to_gen), desc="Speed Gen"):
    logits = model(Tensor([tks[start_pos:]]), inference_params, start_pos, jit=False)
    inference_params.seqlen_offset = len(tks)
    tok = logits[:, -1, :].argmax(axis=-1).item()
    start_pos = len(tks)
    tks.append(tok)
  output_completions = ''.join([tokenizer.decode(output) for output in tks])
  return output_completions
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run Mamba in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--prompt", type=str, default="Why is gravity ", help="Prompt for LLM completion")
parser.add_argument("--size", type=str, default="370m",
help=f"Size of model to use [{', '.join([k for k in MODELS.keys()])}]")
parser.add_argument("--n_tokens", type=int, default=10, help="Number of tokens to generate")
args = parser.parse_args()
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = Mamba.from_pretrained(args.size)
prompt = args.prompt
num_toks = args.n_tokens
s = time.time()
tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks)
print(tinyoutput)
print('TIME: ', time.time() - s)
TORCHOUTPUT = "Why is gravity \nso important?\nBecause it's the only"
print('Outputs Match:', tinyoutput == TORCHOUTPUT)