mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
Mamba Implementation (#3456)
* first commit
* state back to orig
* mamba comparisons
* rm file
* rename file
* use Tensor.einsum and make default model 370M
* Cleaned code and made a comparison test
* Simplify pull request. Only has 1 mamba implementation now.
* Update prompt
* rm whitespaces
* last space
* remove einops dependency
* rm unused code
* add tests
* rm print statement
* rm imports
* skip CLANG
* Update skipIf description
* skip model test in CI and add CLANG fix
* rm Device import
* don't be stupid
* Fix conv assign: when the prompt is too short, the logic for the conv_state assign breaks. Fixed by padding the tokenized array to a minimum length of 4, using the empty-string token (unclear whether proper practice is to use the PAD token instead)
* fix p1
* temp
* fix jit import

---------

Co-authored-by: schlimeszn <schlimeszn@gmail.com>
Co-authored-by: reddyn <nikidsniper@gmail.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
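
For reference, a minimal usage sketch of the API this commit adds (it mirrors what test/models/test_mamba.py does; the size and prompt here are arbitrary):

from transformers import AutoTokenizer
from examples.mamba import Mamba, generate

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")  # tokenizer shared by all mamba checkpoints
model = Mamba.from_pretrained("370m")  # fetches weights from huggingface.co/state-spaces
print(generate(model, tokenizer, "Why is gravity ", n_tokens_to_gen=10))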
.gitignore (vendored): 1 addition
@@ -28,6 +28,7 @@ extra/datasets/kits19/
 extra/datasets/squad/
 extra/datasets/img_align_celeba*
 extra/datasets/open-images-v6-mlperf
 extra/datasets/open-images-v6TEST
 extra/datasets/kits/
 extra/datasets/COCO/
+extra/datasets/audio*

examples/mamba.py (new file): 470 additions
@@ -0,0 +1,470 @@
import os, sys, math, argparse
from tqdm import tqdm
sys.path.append(os.getcwd())
from typing import Any, Optional
from dataclasses import dataclass, field
import time
from tinygrad import Tensor, dtypes, nn
from tinygrad.engine.jit import TinyJit
from tinygrad.helpers import fetch
from tinygrad.nn.state import get_state_dict, load_state_dict, torch_load

from extra.models.llama import RMSNorm
from transformers import AutoTokenizer
# from einops import rearrange, repeat
import numpy as np

MODELS = {
  "130m": {
    "dim": 768,
    "n_layers": 24,
    "vocab_size": 50277,
    "pad_vocab_size_multiple": 8,
  },
  "370m": {
    "dim": 1024,
    "n_layers": 48,
    "vocab_size": 50277,
    "pad_vocab_size_multiple": 8,
  },
  "790m": {
    "dim": 1536,
    "n_layers": 48,
    "vocab_size": 50277,
    "pad_vocab_size_multiple": 8,
  },
  "1.4b": {
    "dim": 2048,
    "n_layers": 48,
    "vocab_size": 50277,
    "pad_vocab_size_multiple": 8,
  },
  "2.8b": {
    "dim": 2560,
    "n_layers": 64,
    "vocab_size": 50277,
    "pad_vocab_size_multiple": 8,
  },
}


def fetch_weights(model_name: str):
  if model_name not in MODELS:
    raise Exception(f"Requested unknown mamba model: {model_name}")
  downloaded = fetch(f"https://huggingface.co/state-spaces/mamba-{model_name}/resolve/main/pytorch_model.bin?download=true")
  weights = torch_load(downloaded)
  return weights


def selective_scan_ref(
  u,
  delta,
  A,
  B,
  C,
  D=None,
  z=None,
  delta_bias=None,
  delta_softplus=False,
  return_last_state=False,
):
  """
  u: r(B D L)
  delta: r(B D L)
  A: c(D N) or r(D N)
  B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or r(B G N 2L)
  C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or r(B G N 2L)
  D: r(D)
  z: r(B D L)
  delta_bias: r(D), fp32

  out: r(B D L)
  last_state (optional): r(B D dstate) or c(B D dstate)
  """
  dtype_in = u.dtype
  u = u.float()
  delta = delta.float()
  if delta_bias is not None:
    delta = delta + delta_bias[..., None].float()
  if delta_softplus:
    delta = delta.softplus()
  batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
  is_variable_B = len(B.shape) >= 3
  is_variable_C = len(C.shape) >= 3
  x = Tensor.zeros(batch, dim, dstate)
  ys = []
  deltaA = Tensor.einsum("bdl,dn->bdln", delta, A).exp()
  if not is_variable_B:
    deltaB_u = Tensor.einsum("bdl,dn,bdl->bdln", delta, B, u)
  else:
    if len(B.shape) == 3:
      deltaB_u = Tensor.einsum("bdl,bnl,bdl->bdln", delta, B, u)
    else:
      B = B.repeat((1, dim // B.shape[1], 1, 1))
      deltaB_u = Tensor.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
  if is_variable_C and len(C.shape) == 4:
    C = C.repeat((1, dim // C.shape[1], 1, 1))
  last_state = None
  for i in range(u.shape[2]):
    x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
    if not is_variable_C:
      y = Tensor.einsum("bdn,dn->bd", x, C)
    else:
      if len(C.shape) == 3:
        y = Tensor.einsum("bdn,bn->bd", x, C[:, :, i])
      else:
        y = Tensor.einsum("bdn,bdn->bd", x, C[:, :, :, i])
    if i == u.shape[2] - 1:
      last_state = x
    ys.append(y)
  y = Tensor.stack(ys, dim=2)  # (batch dim L)
  out = y if D is None else y + u * D.reshape((D.numel(), 1))
  if z is not None:
    out = out * z.silu()
  return out if not return_last_state else (out, last_state)
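
# shapes as used from MambaMixer.__call__: u and delta are (B, d_inner, L), A is
# (d_inner, d_state), B and C are (B, d_state, L), and the output is (B, d_inner, L)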


class MambaMixer:
  def __init__(
    self,
    dim,
    d_state=16,
    d_conv=4,
    expand=2,
    dt_rank="auto",
    dt_min=0.001,
    dt_max=0.1,
    dt_init="random",
    dt_scale=1.0,
    dt_init_floor=1e-4,
    conv_bias=True,
    bias=False,
    layer_idx=None,
  ):
    self.dim = dim
    self.d_state = d_state
    self.d_conv = d_conv
    self.expand = expand
    self.d_inner = int(self.expand * self.dim)
    self.dt_rank = math.ceil(self.dim / 16) if dt_rank == "auto" else dt_rank
    self.layer_idx = layer_idx

    self.in_proj = nn.Linear(self.dim, self.d_inner * 2, bias=bias)

    self.conv1d = nn.Conv1d(
      in_channels=self.d_inner,
      out_channels=self.d_inner,
      bias=conv_bias,
      kernel_size=d_conv,
      groups=self.d_inner,
      padding=d_conv - 1,
    )

    self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
    self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

    # Initialize special dt projection to preserve variance at initialization
    dt_init_std = self.dt_rank**-0.5 * dt_scale
    if dt_init == "constant":
      self.dt_proj.weight = Tensor.full(self.dt_proj.weight.shape, dt_init_std)
    elif dt_init == "random":
      self.dt_proj.weight = Tensor.uniform(self.dt_proj.weight.shape, low=-dt_init_std, high=dt_init_std)
    else:
      raise NotImplementedError

    dt = (
      (Tensor.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
      .exp()
      .maximum(dt_init_floor)
    )
    inv_dt = dt + (-((-dt).exp() - Tensor.ones(*dt.shape))).log()
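    # inv_dt inverts softplus: softplus(x) = log(1 + exp(x)), so
    # softplus^-1(y) = y + log(1 - exp(-y)); storing inv_dt in the bias means
    # softplus(dt_proj.bias) recovers a dt sampled log-uniformly in [dt_min, dt_max]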

    self.dt_proj.bias.assign(inv_dt)

    # S4D real initialization
    self.A_log = Tensor.arange(1, self.d_state + 1).repeat([self.d_inner, 1]).contiguous().log()

    # D "skip" parameter
    self.D = Tensor.ones(self.d_inner)  # Keep in fp32

    self.out_proj = nn.Linear(self.d_inner, self.dim, bias=bias)

  def __call__(self, hidden_states: Tensor, inference_params=None):
    batch, seqlen, dim = hidden_states.shape

    conv_state, ssm_state = None, None
    if inference_params is not None:
      conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
      if inference_params.seqlen_offset > 0:
        # The states are updated inplace
        out, _, _ = self.step(hidden_states[:, -1:, :], conv_state, ssm_state)
        return out

    # project (B, L, D) -> (B, 2*d_inner, L) with a single matmul by flattening
    # batch and length, then splitting the sequence dimension back out
    xz = self.in_proj.weight @ hidden_states.permute(2, 0, 1).reshape(hidden_states.shape[2], hidden_states.shape[1] * hidden_states.shape[0])
    xz = xz.reshape(xz.shape[0], xz.shape[1] // seqlen, seqlen).permute(1, 0, 2)

    if self.in_proj.bias is not None:
      xz = xz + self.in_proj.bias.reshape((self.in_proj.bias.numel(), 1))

    A = -self.A_log.exp()
    x, z = xz.chunk(2, dim=1)
    # Compute short convolution
    if conv_state is not None:
      conv_state.assign(x[:, :, -self.d_conv:])  # Update state (B D W)
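      # the slice above requires seqlen >= d_conv; generate() left-pads short
      # prompts to at least 4 tokens so this assign always sees a full (B D W) window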
    x = self.conv1d(x)[..., :seqlen].swish()

    x_dbl = self.x_proj(x.permute(0, 2, 1).reshape(x.shape[0] * x.shape[2], x.shape[1]))
    dt, B, C = Tensor.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
    dt = self.dt_proj.weight @ dt.T
    dt = dt.reshape(dt.shape[0], dt.shape[1] // seqlen, seqlen).permute(1, 0, 2)
    B = B.reshape(B.shape[0] // seqlen, seqlen, B.shape[1]).permute(0, 2, 1).contiguous()
    C = C.reshape(C.shape[0] // seqlen, seqlen, C.shape[1]).permute(0, 2, 1).contiguous()

    y = selective_scan_ref(  # TODO: actually implement selective_scan_fn
      x,
      dt,
      A,
      B,
      C,
      self.D,
      z=z,
      delta_bias=self.dt_proj.bias,
      delta_softplus=True,
      return_last_state=ssm_state is not None,
    )

    if ssm_state is not None:
      y, last_state = y
      ssm_state.assign(last_state)

    y = y.permute(0, 2, 1)
    out = self.out_proj(y)

    return out

  def step(self, hidden_states, conv_state, ssm_state):
    assert hidden_states.shape[1] == 1, f"Only support decoding with 1 token at a time for now, attempted {hidden_states.shape[1]}"
    xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
    x, z = xz.chunk(2, dim=-1)  # (B D)

    # Conv step: shift the cached window left by one, append the new token, and
    # apply the depthwise kernel as an elementwise multiply and sum
    conv_state.assign(conv_state[:, :, 1:].cat(x.unsqueeze(-1), dim=-1))
    x = (conv_state * self.conv1d.weight.reshape(self.conv1d.weight.shape[0], self.conv1d.weight.shape[2])).sum(-1)
    if self.conv1d.bias is not None:
      x = x + self.conv1d.bias
    x = x.swish()

    x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
    dt = x_db[:, :self.dt_rank]
    B = x_db[:, self.dt_rank:(self.dt_rank + self.d_state)]
    C = x_db[:, (self.dt_rank + self.d_state):]
    # Don't add dt_bias here
    dt = self.dt_proj.weight @ dt.T
    A = -self.A_log.exp()

    # SSM step
    dt = (dt + self.dt_proj.bias.unsqueeze(-1)).softplus()
    # single-token recurrence: ssm_state <- ssm_state * exp(dt A) + (dt B) x, then y = C . ssm_state
    dA = Tensor.einsum("db,dn->bdn", dt, A).exp()
    dB = Tensor.einsum("db,bn->bdn", dt, B)
    ssm_state.assign(ssm_state * dA + x.reshape(x.shape[0], x.shape[1], 1) * dB)
    y = Tensor.einsum("bdn,bn->bd", ssm_state, C)
    y = y + self.D * x
    y = y * z.swish()  # (B D)

    out = self.out_proj(y)
    return out.unsqueeze(1), conv_state, ssm_state

  def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
    assert self.layer_idx is not None
    if self.layer_idx not in inference_params.key_value_memory_dict:
      conv_state = Tensor.zeros(batch_size, self.dim * self.expand, self.d_conv).contiguous().realize()
      ssm_state = Tensor.zeros(batch_size, self.dim * self.expand, self.d_state).realize()
      inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
    else:
      conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
    return conv_state, ssm_state


class MambaBlock:
  def __init__(
    self,
    dim: int,
    norm_eps: float = 1e-5,
    rms_norm: bool = True,
    layer_idx: Optional[int] = None,
  ):
    self.mixer = MambaMixer(dim, layer_idx=layer_idx)
    if rms_norm:
      self.norm = RMSNorm(dim, norm_eps)
    else:
      raise NotImplementedError

  def __call__(
    self,
    hidden_states: Tensor,
    residual: Optional[Tensor] = None,
    inference_params=None,
  ):
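    # pre-norm residual: add into the residual stream, normalize, then mix; the
    # un-normalized residual is threaded through to the next block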
    residual = (hidden_states + residual) if residual is not None else hidden_states
    hidden_states = self.norm(residual)
    hidden_states = self.mixer(hidden_states, inference_params=inference_params)
    return hidden_states, residual


class MambaBackbone:
  def __init__(
    self,
    dim: int,
    n_layers: int,
    vocab_size: int,
    rms_norm: bool = True,
    norm_eps: float = 1e-5,
  ):
    self.embedding = nn.Embedding(vocab_size, dim)
    self.layers = [MambaBlock(dim, rms_norm=rms_norm, layer_idx=i) for i in range(n_layers)]
    if rms_norm:
      self.norm_f = RMSNorm(dim, norm_eps)

  def __call__(self, input_ids: Tensor, inference_params=None) -> Any:
    hidden_states = self.embedding(input_ids)
    residual = None
    for layer in self.layers:
      hidden_states, residual = layer(hidden_states, residual, inference_params=inference_params)

    residual = (hidden_states + residual) if residual is not None else hidden_states
    hidden_states = self.norm_f(residual)

    return hidden_states


class Mamba:
  def __init__(self, dim: int, n_layers: int, vocab_size: int, pad_vocab_size_multiple: int = 1):
    # pad the vocab up to a multiple of pad_vocab_size_multiple (50277 -> 50280 for these models)
    if vocab_size % pad_vocab_size_multiple != 0:
      vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)

    self.backbone = MambaBackbone(dim, n_layers, vocab_size)
    self.lm_head = nn.Linear(dim, vocab_size, bias=False)

    self.forward_jit = TinyJit(self.forward)

  def forward(self, input_ids, inference_params, num_last_tokens):
    hidden_states = self.backbone(input_ids, inference_params=inference_params)
    if num_last_tokens > 0:
      hidden_states = hidden_states[:, -num_last_tokens:]
    return self.lm_head(hidden_states).realize()

  def __call__(self, input_ids, inference_params=None, num_last_tokens=0, jit=True):
    if inference_params is None:
      return self.forward(input_ids, inference_params, num_last_tokens)
    # TinyJit wants fixed input shapes, so the jitted path is only taken once
    # decoding has started (seqlen_offset > 0) and each call is a single token
    if jit and inference_params.seqlen_offset > 0:
      return self.forward_jit(input_ids, inference_params, num_last_tokens)
    else:
      return self.forward(input_ids, inference_params, num_last_tokens)

  @staticmethod
  def from_pretrained(model_name: str):
    weights = fetch_weights(model_name)
    model = Mamba(**MODELS[model_name])
    load_state_dict(model, weights)
    return model


@dataclass
class InferenceParams:
  """Inference parameters that are passed to the main model in order
  to efficiently calculate and store the context during inference."""

  max_seqlen: int
  max_batch_size: int
  seqlen_offset: int = 0
  batch_size_offset: int = 0
  key_value_memory_dict: dict = field(default_factory=dict)
  lengths_per_sample: Optional[Tensor] = None

  def reset(self, max_seqlen, max_batch_size):
    self.max_seqlen = max_seqlen
    self.max_batch_size = max_batch_size
    self.seqlen_offset = 0
    if self.lengths_per_sample is not None:
      self.lengths_per_sample.zero_()


def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 10, sample: bool = False, top_k: int = None):
  # NOTE: decoding below is greedy argmax; sample, top_k, and the temperature set
  # below are accepted but currently unused
  tks = tokenizer(prompt)["input_ids"]
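  # left-pad to at least d_conv (= 4) tokens so the conv_state assign in
  # MambaMixer.__call__ always sees a full window; 50279 is used as the pad token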
  while len(tks) < 4: tks = [50279] + tks
  temperature = 0.5
  start_pos = 0
  inference_params = InferenceParams(max_seqlen=1, max_batch_size=1, seqlen_offset=0)
  for i in tqdm(range(n_tokens_to_gen), desc="Speed Gen"):
    logits = model(Tensor([tks[start_pos:]]), inference_params, start_pos, jit=False)
    inference_params.seqlen_offset = len(tks)
    tok = logits[:, -1, :].softmax().argmax(axis=-1).item()
    start_pos = len(tks)
    tks.append(tok)
  output_completions = ''.join([tokenizer.decode(output) for output in tks])
  return output_completions


if __name__ == "__main__":
  # reference completion from the PyTorch mamba implementation (370m, default prompt)
  TORCHOUTPUT = '''Why is gravity \nso important?\nBecause it's the only'''

  parser = argparse.ArgumentParser(
    description="Run Mamba in tinygrad",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  )
  parser.add_argument("--prompt", type=str, default='Why is gravity ', help="Prompt for LLM completion")
  parser.add_argument("--size", type=str, default="370m", help=f"Size of model to use [{', '.join(MODELS.keys())}]")
  parser.add_argument("--n_tokens", type=int, default=10, help="Number of tokens to generate")
  args = parser.parse_args()

  tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
  model = Mamba.from_pretrained(args.size)
  prompt = args.prompt
  num_toks = args.n_tokens
  s = time.time()
  tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks)
  print(tinyoutput)
  print('TIME: ', time.time() - s)
  print('Outputs Match:', tinyoutput == TORCHOUTPUT)
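
Example invocation with the flags defined above (the final Outputs Match check is
only meaningful with the default prompt and --size 370m, since TORCHOUTPUT is the
370m reference): python3 examples/mamba.py --size 370m --n_tokens 10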

test/models/test_mamba.py (new file): 24 additions
@@ -0,0 +1,24 @@
import unittest
from tinygrad.helpers import CI
from examples.mamba import Mamba, generate
from transformers import AutoTokenizer

PROMPT = 'Why is gravity '
TOKENIZER = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

@unittest.skipIf(CI, "model is slow for CI")
class TestMamba(unittest.TestCase):
  def test_mamba_130M(self):
    OUT_130M = '''Why is gravity \nnot a good idea?\n\nA:'''
    model = Mamba.from_pretrained('130m')
    tinyoutput = generate(model, TOKENIZER, PROMPT, n_tokens_to_gen=10)
    self.assertEqual(OUT_130M, tinyoutput)
    del model

  def test_mamba_370M(self):
    OUT_370M = '''Why is gravity \nso important?\nBecause it's the only'''
    model = Mamba.from_pretrained('370m')
    tinyoutput = generate(model, TOKENIZER, PROMPT, n_tokens_to_gen=10)
    self.assertEqual(OUT_370M, tinyoutput)
    del model

if __name__ == '__main__':
  unittest.main()
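
These tests download the real 130m and 370m checkpoints and are skipped when CI is
set; locally they can be run directly (python3 test/models/test_mamba.py) or through
pytest.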