Mamba Implementation (#3456)

* first commit

* state back to orig

* mamba comparisons

* rm file

* rename file

* use Tensor.einsum and make default model 370M

* Cleaned code and made a comparison test

* Simplify pull request. Only has 1 mamba implementation now.

* Update prompt

* rm whitespaces

* last space

* remove Einops dependency

* rm unused code

* add tests

* rm print statement

* rm imports

* skip CLANG

* Update skipIf description

* skip model test in CI and add CLANG fix

* rm Device import

* don't be stupid

* Fix conv assign

When the prompt is too short, the logic for the conv_state assign breaks. This can be fixed by padding the tokenized array to a minimum length of 4. I padded using the empty-string token, but I'm not sure whether proper practice is to use the PAD token; a short sketch of the padding follows after this list.

* fix p1

* temp

* fix jit import
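
A minimal sketch of the padding described in the conv assign fix above, matching what `generate()` in examples/mamba.py ends up doing (the helper name here is only illustrative; 50279 is the id the example prepends):

```python
from transformers import AutoTokenizer

def pad_prompt_tokens(prompt: str, min_len: int = 4, pad_id: int = 50279) -> list:
    # pad short prompts up to the conv kernel width (d_conv = 4), so
    # conv_state.assign(x[:, :, -d_conv:]) always has a full window to slice
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    tks = tokenizer(prompt)["input_ids"]
    while len(tks) < min_len:
        tks = [pad_id] + tks  # 50279: the id generate() prepends in examples/mamba.py
    return tks
```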

---------

Co-authored-by: schlimeszn <schlimeszn@gmail.com>
Co-authored-by: reddyn <nikidsniper@gmail.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Author: reddyn12
Date: 2024-03-28 20:49:12 -04:00
Committed by: GitHub
Parent: d085837179
Commit: 9b5e15db6e

3 changed files with 495 additions and 0 deletions

.gitignore (vendored): 1 addition

@@ -28,6 +28,7 @@ extra/datasets/kits19/
extra/datasets/squad/
extra/datasets/img_align_celeba*
extra/datasets/open-images-v6-mlperf
extra/datasets/open-images-v6TEST
extra/datasets/kits/
extra/datasets/COCO/
extra/datasets/audio*

examples/mamba.py (new file): 470 additions

@@ -0,0 +1,470 @@
import os, sys, math, argparse
from tqdm import tqdm
sys.path.append(os.getcwd())
from typing import Any, Optional
from dataclasses import dataclass, field
import time
from tinygrad import Tensor, dtypes, nn
from tinygrad.engine.jit import TinyJit
from tinygrad.helpers import fetch
from tinygrad.nn.state import get_state_dict, load_state_dict, torch_load
from extra.models.llama import RMSNorm
from transformers import AutoTokenizer
# from einops import rearrange, repeat
import numpy as np

MODELS = {
    "130m": {
        "dim": 768,
        "n_layers": 24,
        "vocab_size": 50277,
        "pad_vocab_size_multiple": 8,
    },
    "370m": {
        "dim": 1024,
        "n_layers": 48,
        "vocab_size": 50277,
        "pad_vocab_size_multiple": 8,
    },
    "790m": {
        "dim": 1536,
        "n_layers": 48,
        "vocab_size": 50277,
        "pad_vocab_size_multiple": 8,
    },
    "1.4b": {
        "dim": 2048,
        "n_layers": 48,
        "vocab_size": 50277,
        "pad_vocab_size_multiple": 8,
    },
    "2.8b": {
        "dim": 2560,
        "n_layers": 64,
        "vocab_size": 50277,
        "pad_vocab_size_multiple": 8,
    },
}

def fetch_weights(model_name: str):
    if model_name not in MODELS.keys():
        raise Exception(f"Requested unknown mamba model: {model_name}")
    downloaded = fetch(
        f"https://huggingface.co/state-spaces/mamba-{model_name}/resolve/main/pytorch_model.bin?download=true"
    )
    weights = torch_load(downloaded)
    return weights

def selective_scan_ref(
    u,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    delta_softplus=False,
    return_last_state=False,
):
    """
    u: r(B D L)
    delta: r(B D L)
    A: c(D N) or r(D N)
    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32
    out: r(B D L)
    last_state (optional): r(B D dstate) or c(B D dstate)
    """
    dtype_in = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = delta.softplus()
    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    is_variable_B = len(B.shape) >= 3
    is_variable_C = len(C.shape) >= 3
    x = Tensor.zeros(batch, dim, dstate)
    ys = []
    deltaA = Tensor.einsum("bdl,dn->bdln", delta, A).exp()
    if not is_variable_B:
        deltaB_u = Tensor.einsum("bdl,dn,bdl->bdln", delta, B, u)
    else:
        if len(B.shape) == 3:
            deltaB_u = Tensor.einsum("bdl,bnl,bdl->bdln", delta, B, u)
        else:
            B = B.repeat((1,dim//B.shape[1],1,1))
            deltaB_u = Tensor.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
    if is_variable_C and len(C.shape) == 4:
        C = C.repeat((1,dim//C.shape[1],1,1))
    last_state = None
    for i in range(u.shape[2]):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        if not is_variable_C:
            y = Tensor.einsum("bdn,dn->bd", x, C)
        else:
            if len(C.shape) == 3:
                y = Tensor.einsum("bdn,bn->bd", x, C[:, :, i])
            else:
                y = Tensor.einsum("bdn,bdn->bd", x, C[:, :, :, i])
        if i == u.shape[2] - 1:
            last_state = x
        ys.append(y)
    y = Tensor.stack(ys, dim=2)  # (batch dim L)
    out = y if D is None else y + u * D.reshape((D.numel(), 1))
    if z is not None:
        out = out * z.silu()
    return out if not return_last_state else (out, last_state)

class MambaMixer:
    def __init__(
        self,
        dim,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        layer_idx=None,
    ):
        self.dim = dim
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.dim)
        self.dt_rank = math.ceil(self.dim / 16) if dt_rank == "auto" else dt_rank
        self.layer_idx = layer_idx

        self.in_proj = nn.Linear(self.dim, self.d_inner * 2, bias=bias)
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
        )
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            self.dt_proj.weight = Tensor.full(self.dt_proj.weight.shape, dt_init_std)
        elif dt_init == "random":
            self.dt_proj.weight = Tensor.uniform(
                self.dt_proj.weight.shape, low=-dt_init_std, high=dt_init_std
            )
        else:
            raise NotImplementedError

        dt = (
            (
                Tensor.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min))
                + math.log(dt_min)
            )
            .exp()
            .maximum(dt_init_floor)
        )
        inv_dt = (
            dt + (-((-dt).exp() - Tensor.ones(*dt.shape))).log()
        )
        self.dt_proj.bias.assign(inv_dt)

        # S4D real initialization
        self.A_log = (
            Tensor.arange(1, self.d_state + 1).repeat([self.d_inner, 1]).contiguous().log()
        )

        # D "skip" parameter
        self.D = Tensor.ones(self.d_inner)  # Keep in fp32

        self.out_proj = nn.Linear(self.d_inner, self.dim, bias=bias)

    def __call__(self, hidden_states: Tensor, inference_params=None):
        batch, seqlen, dim = hidden_states.shape
        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                out, _, _ = self.step(hidden_states[:, -1:, :], conv_state, ssm_state)
                return out

        xz = self.in_proj.weight @ hidden_states.permute(2,0,1).reshape(hidden_states.shape[2],hidden_states.shape[1]*hidden_states.shape[0])
        xz = xz.reshape(xz.shape[0],xz.shape[1]//seqlen, seqlen).permute(1,0,2)
        if self.in_proj.bias is not None:
            xz = xz + self.in_proj.bias.reshape((self.in_proj.bias.numel(), 1))

        A = -self.A_log.exp()
        x, z = xz.chunk(2, dim=1)

        # Compute short convolution
        if conv_state is not None:
            conv_state.assign(x[:, :, -self.d_conv :])  # Update state (B D W)
        x = self.conv1d(x)[..., :seqlen].swish()

        x_dbl = self.x_proj(x.permute(0,2,1).reshape(x.shape[0]*x.shape[2], x.shape[1]))
        dt, B, C = Tensor.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        dt = self.dt_proj.weight @ dt.T
        dt = dt.reshape(dt.shape[0], dt.shape[1]//seqlen, seqlen).permute(1,0,2)
        B = B.reshape(B.shape[0]//seqlen, seqlen, B.shape[1]).permute(0,2,1).contiguous()
        C = C.reshape(C.shape[0]//seqlen, seqlen, C.shape[1]).permute(0,2,1).contiguous()

        y = selective_scan_ref(  # TODO: actually implement selective_scan_fn
            x,
            dt,
            A,
            B,
            C,
            self.D,
            z=z,
            delta_bias=self.dt_proj.bias,
            delta_softplus=True,
            return_last_state=ssm_state is not None,
        )
        if ssm_state is not None:
            y, last_state = y
            ssm_state.assign(last_state)
        y = y.permute(0,2,1)
        out = self.out_proj(y)
        return out

    def step(self, hidden_states, conv_state, ssm_state):
        assert (
            hidden_states.shape[1] == 1
        ), f"Only support decoding with 1 token at a time for now, attempted {hidden_states.shape[1]}"
        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
        x, z = xz.chunk(2, dim=-1)  # (B D)

        # Conv step
        conv_state.assign(conv_state[:, :, 1:].cat(x.unsqueeze(-1), dim=-1))
        x = (conv_state * self.conv1d.weight.reshape(self.conv1d.weight.shape[0],self.conv1d.weight.shape[2])).sum(-1)
        if self.conv1d.bias is not None:
            x = x + self.conv1d.bias
        x = x.swish()

        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
        dt = x_db[:, : self.dt_rank]
        B = x_db[:, self.dt_rank : (self.dt_rank + self.d_state)]
        C = x_db[:, (self.dt_rank + self.d_state) :]
        # Don't add dt_bias here
        dt = self.dt_proj.weight @ dt.T
        A = -self.A_log.exp()

        # SSM step
        dt = (dt + self.dt_proj.bias.unsqueeze(-1)).softplus()
        # TODO: Tensor.einsum?
        dA = Tensor.einsum("db,dn->bdn", dt, A).exp()
        dB = Tensor.einsum("db,bn->bdn", dt, B)
        ssm_state.assign(ssm_state * dA + x.reshape(x.shape[0],x.shape[1], 1) * dB)
        y = Tensor.einsum("bdn,bn->bd", ssm_state, C)
        y = y + self.D * x
        y = y * z.swish()  # (B D)
        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state, ssm_state

    def _get_states_from_cache(
        self, inference_params, batch_size, initialize_states=False
    ):
        assert self.layer_idx is not None
        if self.layer_idx not in inference_params.key_value_memory_dict:
            batch_shape = (batch_size,)
            conv_state = Tensor.zeros(
                batch_size, self.dim * self.expand, self.d_conv
            ).contiguous().realize()
            ssm_state = Tensor.zeros(
                batch_size, self.dim * self.expand, self.d_state
            ).realize()
            inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
        return conv_state, ssm_state

class MambaBlock:
    def __init__(
        self,
        dim: int,
        norm_eps: float = 1e-5,
        rms_norm: bool = True,
        layer_idx: Optional[int] = None,
    ):
        self.mixer = MambaMixer(dim, layer_idx=layer_idx)
        if rms_norm:
            self.norm = RMSNorm(dim, norm_eps)
        else:
            raise NotImplementedError

    def __call__(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        inference_params=None,
    ):
        residual = (hidden_states + residual) if residual is not None else hidden_states
        hidden_states = self.norm(residual)
        hidden_states = self.mixer(hidden_states, inference_params=inference_params)
        return hidden_states, residual

class MambaBackbone:
    def __init__(
        self,
        dim: int,
        n_layers: int,
        vocab_size: int,
        rms_norm: bool = True,
        norm_eps: float = 1e-5,
    ):
        self.embedding = nn.Embedding(vocab_size, dim)
        self.layers = [
            MambaBlock(dim, rms_norm=rms_norm, layer_idx=i) for i in range(n_layers)
        ]
        if rms_norm:
            self.norm_f = RMSNorm(dim, norm_eps)

    def __call__(self, input_ids: Tensor, inference_params=None) -> Any:
        hidden_states = self.embedding(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params
            )
        residual = (hidden_states + residual) if residual is not None else hidden_states
        hidden_states = self.norm_f(residual)
        return hidden_states

class Mamba:
    def __init__(
        self, dim: int, n_layers: int, vocab_size: int, pad_vocab_size_multiple: int = 1
    ):
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
        self.backbone = MambaBackbone(dim, n_layers, vocab_size)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        self.forward_jit = TinyJit(self.forward)

    def forward(self, input_ids, inference_params, num_last_tokens):
        hidden_states = self.backbone(input_ids, inference_params=inference_params)
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        return self.lm_head(hidden_states).realize()

    def __call__(self, input_ids, inference_params=None, num_last_tokens=0, jit=True):
        if inference_params is None:
            return self.forward(input_ids, inference_params, num_last_tokens)
        if jit and inference_params.seqlen_offset > 0:
            return self.forward_jit(input_ids, inference_params, num_last_tokens)
        else:
            return self.forward(input_ids, inference_params, num_last_tokens)

    @staticmethod
    def from_pretrained(model_name: str):
        weights = fetch_weights(model_name)
        model = Mamba(**MODELS[model_name])
        load_state_dict(model, weights)
        return model

@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    max_seqlen: int
    max_batch_size: int
    seqlen_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen, max_batch_size):
        self.max_seqlen = max_seqlen
        self.max_batch_size = max_batch_size
        self.seqlen_offset = 0
        if self.lengths_per_sample is not None:
            self.lengths_per_sample.zero_()

def generate(model,
             tokenizer,
             prompt: str,
             n_tokens_to_gen: int = 10,
             sample: bool = False,
             top_k: int = None):
    tks = tokenizer(prompt)["input_ids"]
    # pad short prompts up to the conv kernel width (d_conv = 4) so the conv_state assign works
    while len(tks) < 4: tks = [50279] + tks
    temperature = 0.5
    start_pos = 0
    inference_params = InferenceParams(max_seqlen=1, max_batch_size=1, seqlen_offset=0)
    for i in tqdm(range(n_tokens_to_gen), desc="Speed Gen"):
        logits = model(Tensor([tks[start_pos:]]), inference_params, start_pos, jit=False)
        inference_params.seqlen_offset = len(tks)
        tok = (logits[:, -1, :]).softmax().argmax(axis=-1).item()
        start_pos = len(tks)
        tks.append(tok)
    output_completions = ''.join([tokenizer.decode(output) for output in tks])
    return output_completions

if __name__ == "__main__":
    TORCHOUTPUT = '''Why is gravity \nso important?\nBecause it's the only'''
    parser = argparse.ArgumentParser(
        description="Run Mamba in tinygrad",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--prompt", type=str, default='Why is gravity ', help="Prompt for LLM completion"
    )
    parser.add_argument(
        "--size",
        type=str,
        default="370m",
        help=f"Size of model to use [{', '.join([k for k in MODELS.keys()])}]",
    )
    parser.add_argument(
        "--n_tokens",
        type=int,
        default=10,
        help="Number of tokens to generate",
    )
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    model = Mamba.from_pretrained(args.size)
    prompt = args.prompt
    num_toks = args.n_tokens

    s = time.time()
    tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks)
    print(tinyoutput)
    print('TIME: ', time.time() - s)
    print('Outputs Match:', tinyoutput == TORCHOUTPUT)

test/models/test_mamba.py (new file): 24 additions

@@ -0,0 +1,24 @@
import unittest
from tinygrad.helpers import CI
from examples.mamba import Mamba, generate
from transformers import AutoTokenizer

PROMPT = 'Why is gravity '
TOKENIZER = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

@unittest.skipIf(CI, "model is slow for CI")
class TestMamba(unittest.TestCase):
    def test_mamba_130M(self):
        OUT_130M = '''Why is gravity \nnot a good idea?\n\nA:'''
        model = Mamba.from_pretrained('130m')
        tinyoutput = generate(model, TOKENIZER, PROMPT, n_tokens_to_gen=10)
        self.assertEqual(OUT_130M, tinyoutput)
        del model

    def test_mamba_370M(self):
        OUT_370M = '''Why is gravity \nso important?\nBecause it's the only'''
        model = Mamba.from_pretrained('370m')
        tinyoutput = generate(model, TOKENIZER, PROMPT, n_tokens_to_gen=10)
        self.assertEqual(OUT_370M, tinyoutput)
        del model

if __name__ == '__main__':
    unittest.main()
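
For reference, the new example can be exercised the same way the tests above do; a minimal sketch using the defaults from examples/mamba.py (370m weights, the "Why is gravity " prompt, 10 generated tokens):

```python
from transformers import AutoTokenizer
from examples.mamba import Mamba, generate

# mirrors test/models/test_mamba.py: fetch pretrained weights, then greedily decode
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = Mamba.from_pretrained("370m")
print(generate(model, tokenizer, "Why is gravity ", n_tokens_to_gen=10))
```

The same run is available from the command line via the example's argparse interface, e.g. `python3 examples/mamba.py --size 370m --n_tokens 10`.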