cleanup mamba (#4004)

Make it read nicer: clean up some movement methods and simplify the math.
The 790m, 1.4b, and 2.8b models do not really run.
Sampling is not implemented.
JIT is incorrect.
There is still some dead code, wrong code paths, and stuff copied from torch.
chenyu
2024-03-30 02:50:13 -04:00
committed by GitHub
parent f35f9d32f2
commit aa76d566c2


@@ -1,62 +1,29 @@
import os, sys, math, argparse
from tqdm import tqdm
import os, sys, math, argparse, time
sys.path.append(os.getcwd())
from typing import Any, Optional
from typing import Any, Optional, Dict
from dataclasses import dataclass, field
import time
from tinygrad import Tensor, dtypes, nn
from tinygrad.engine.jit import TinyJit
from tinygrad.helpers import fetch
from tinygrad.nn.state import get_state_dict, load_state_dict, torch_load
from tinygrad import Tensor, TinyJit, nn
from tinygrad.helpers import fetch
from tinygrad.nn.state import load_state_dict, torch_load
from extra.models.llama import RMSNorm
from tqdm import tqdm
from transformers import AutoTokenizer
# from einops import rearrange, repeat
import numpy as np
MODELS = {
"130m": {
"dim": 768,
"n_layers": 24,
"vocab_size": 50277,
"pad_vocab_size_multiple": 8,
},
"370m": {
"dim": 1024,
"n_layers": 48,
"vocab_size": 50277,
"pad_vocab_size_multiple": 8,
},
"790m": {
"dim": 1536,
"n_layers": 48,
"vocab_size": 50277,
"pad_vocab_size_multiple": 8,
},
"1.4b": {
"dim": 2048,
"n_layer": 48,
"vocab_size": 50277,
"pad_vocab_size_multiple": 8,
},
"2.8b": {
"dim": 2560,
"n_layer": 64,
"vocab_size": 50277,
"pad_vocab_size_multiple": 8,
},
"130m": {"dim": 768, "n_layers": 24, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
"370m": {"dim": 1024, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
"790m": {"dim": 1536, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
"1.4b": {"dim": 2048, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
"2.8b": {"dim": 2560, "n_layers": 64, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
}
def fetch_weights(model_name: str):
if model_name not in MODELS.keys():
raise Exception(f"Requested unknown mamba model: {model_name}")
downloaded = fetch(
f"https://huggingface.co/state-spaces/mamba-{model_name}/resolve/main/pytorch_model.bin?download=true"
)
weights = torch_load(downloaded)
return weights
def fetch_weights(model_name: str) -> Dict[str, Tensor]:
if model_name not in MODELS:
raise ValueError(f"Requested unknown mamba model: {model_name}")
downloaded = fetch(f"https://huggingface.co/state-spaces/mamba-{model_name}/resolve/main/pytorch_model.bin?download=true")
return torch_load(downloaded)
def selective_scan_ref(
u,
@@ -83,7 +50,6 @@ def selective_scan_ref(
out: r(B D L)
last_state (optional): r(B D dstate) or c(B D dstate)
"""
dtype_in = u.dtype
u = u.float()
delta = delta.float()
if delta_bias is not None:
@@ -102,10 +68,10 @@ def selective_scan_ref(
if len(B.shape) == 3:
deltaB_u = Tensor.einsum("bdl,bnl,bdl->bdln", delta, B, u)
else:
B = B.repeat((1,dim//B.shape[1],1,1))
B = B.repeat((1, dim // B.shape[1], 1, 1))
deltaB_u = Tensor.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
if is_variable_C and len(C.shape) == 4:
C = C.repeat((1,dim//C.shape[1],1,1))
C = C.repeat((1, dim // C.shape[1], 1, 1))
last_state = None
for i in range(u.shape[2]):
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
@@ -120,12 +86,11 @@ def selective_scan_ref(
last_state = x
ys.append(y)
y = Tensor.stack(ys, dim=2) # (batch dim L)
out = y if D is None else y + u * D.reshape((D.numel(), 1))
out = y if D is None else y + u * D.reshape((-1, 1))
if z is not None:
out = out * z.silu()
return out if not return_last_state else (out, last_state)
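
Note on this hunk: selective_scan_ref materializes the discretized SSM recurrence x_t = exp(delta_t * A) * x_{t-1} + delta_t * B_t * u_t, y_t = C_t . x_t + D * u_t, scanned over the sequence. A minimal single-batch sketch of the same math (NumPy, hypothetical names, z-gating and dtype handling omitted):

import numpy as np

def selective_scan_sketch(u, delta, A, B, C, D):
  # u, delta: (d, L); A: (d, n); B, C: (n, L); D: (d,)
  d, L = u.shape
  x = np.zeros((d, A.shape[1]))
  ys = []
  for t in range(L):
    dA = np.exp(delta[:, t:t+1] * A)                     # (d, n) state transition
    dBu = delta[:, t:t+1] * B[None, :, t] * u[:, t:t+1]  # (d, n) input injection
    x = dA * x + dBu                                     # recurrence
    ys.append(x @ C[:, t])                               # (d,) readout
  return np.stack(ys, axis=1) + u * D[:, None]           # add the D "skip" term
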
class MambaMixer:
def __init__(
self,
@@ -147,20 +112,14 @@ class MambaMixer:
self.d_state = d_state
self.d_conv = d_conv
self.expand = expand
self.d_inner = int(self.expand * self.dim)
self.d_inner = self.expand * self.dim
self.dt_rank = math.ceil(self.dim / 16) if dt_rank == "auto" else dt_rank
self.layer_idx = layer_idx
self.in_proj = nn.Linear(self.dim, self.d_inner * 2, bias=bias)
self.conv1d = nn.Conv1d(
in_channels=self.d_inner,
out_channels=self.d_inner,
bias=conv_bias,
kernel_size=d_conv,
groups=self.d_inner,
padding=d_conv - 1,
)
self.conv1d = nn.Conv1d(in_channels=self.d_inner, out_channels=self.d_inner, bias=conv_bias,
kernel_size=d_conv, groups=self.d_inner, padding=d_conv-1)
self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
@@ -170,30 +129,17 @@ class MambaMixer:
if dt_init == "constant":
self.dt_proj.weight = Tensor.full(self.dt_proj.weight.shape, dt_init_std)
elif dt_init == "random":
self.dt_proj.weight = Tensor.uniform(
self.dt_proj.weight.shape, low=-dt_init_std, high=dt_init_std
)
self.dt_proj.weight = Tensor.uniform(self.dt_proj.weight.shape, low=-dt_init_std, high=dt_init_std)
else:
raise NotImplementedError
dt = (
(
Tensor.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
)
.exp()
.maximum(dt_init_floor)
)
inv_dt = (
dt + (-((-dt).exp() - Tensor.ones(*dt.shape))).log()
)
dt = Tensor.uniform(self.d_inner, low=math.log(dt_min), high=math.log(dt_max)).exp().maximum(dt_init_floor)
inv_dt = dt + (1 - (-dt).exp()).log()
self.dt_proj.bias.assign(inv_dt)
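
The inv_dt rewrite above is the closed form of inverse softplus: softplus(x) = log(1 + e^x), so x = log(e^y - 1) = y + log(1 - e^-y), which is what both the old and the new expression compute. A quick sanity check (hypothetical snippet, not part of the diff):

import math
dt = 0.0123
inv_dt = dt + math.log(1 - math.exp(-dt))                # new form: dt + log(1 - e^-dt)
assert abs(math.log(1 + math.exp(inv_dt)) - dt) < 1e-9   # softplus(inv_dt) recovers dt
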
# S4D real initialization
self.A_log = (
Tensor.arange(1, self.d_state + 1).repeat([self.d_inner, 1]).contiguous().log()
)
self.A_log = Tensor.arange(1, self.d_state+1).repeat([self.d_inner, 1]).log()
# D "skip" parameter
self.D = Tensor.ones(self.d_inner) # Keep in fp32
@@ -215,7 +161,7 @@ class MambaMixer:
xz = xz.reshape(xz.shape[0],xz.shape[1]//seqlen, seqlen).permute(1,0,2)
if self.in_proj.bias is not None:
xz = xz + self.in_proj.bias.reshape((self.in_proj.bias.numel(), 1))
xz = xz + self.in_proj.bias.reshape((-1, 1))
A = -self.A_log.exp()
x, z = xz.chunk(2, dim=1)
@@ -228,41 +174,28 @@ class MambaMixer:
dt, B, C = Tensor.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = self.dt_proj.weight @ dt.T
dt = dt.reshape(dt.shape[0], dt.shape[1]//seqlen, seqlen).permute(1,0,2)
B = B.reshape(B.shape[0]//seqlen, seqlen, B.shape[1]).permute(0,2,1).contiguous()
C = C.reshape(C.shape[0]//seqlen, seqlen, C.shape[1]).permute(0,2,1).contiguous()
y = selective_scan_ref( # TODO: actually implement selective_scan_fn
x,
dt,
A,
B,
C,
self.D,
z=z,
delta_bias=self.dt_proj.bias,
delta_softplus=True,
return_last_state=ssm_state is not None,
)
B = B.reshape(B.shape[0]//seqlen, seqlen, B.shape[1]).permute(0,2,1)
C = C.reshape(C.shape[0]//seqlen, seqlen, C.shape[1]).permute(0,2,1)
# TODO: actually implement selective_scan_fn
y = selective_scan_ref(x, dt, A, B, C, self.D, z=z, delta_bias=self.dt_proj.bias, delta_softplus=True,
return_last_state=ssm_state is not None)
if ssm_state is not None:
y, last_state = y
ssm_state.assign(last_state)
y = y.permute(0,2,1)
out = self.out_proj(y)
return out
def step(self, hidden_states, conv_state, ssm_state):
assert (
hidden_states.shape[1] == 1
), f"Only support decoding with 1 token at a time for now, attempted {hidden_states.shape[1]}"
def step(self, hidden_states: Tensor, conv_state: Tensor, ssm_state: Tensor):
assert hidden_states.shape[1] == 1, f"Only support decoding with 1 token at a time for now, attempted {hidden_states.shape[1]}"
xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
x, z = xz.chunk(2, dim=-1) # (B D)
# Conv step
conv_state.assign(conv_state[:, :, 1:].cat(x.unsqueeze(-1), dim=-1))
x = (conv_state * self.conv1d.weight.reshape(self.conv1d.weight.shape[0],self.conv1d.weight.shape[2])).sum(-1)
x = (conv_state * self.conv1d.weight.squeeze(1)).sum(-1)
if self.conv1d.bias is not None:
x = x + self.conv1d.bias
x = x.swish()
@@ -275,13 +208,11 @@ class MambaMixer:
dt = self.dt_proj.weight @ dt.T
A = -self.A_log.exp()
# SSM step
dt = (dt + self.dt_proj.bias.unsqueeze(-1)).softplus()
# TODO: Tensor.einsum?
dA = Tensor.einsum("db,dn->bdn", dt, A).exp()
dB = Tensor.einsum("db,bn->bdn", dt, B)
ssm_state.assign(ssm_state * dA + x.reshape(x.shape[0],x.shape[1], 1) * dB)
ssm_state.assign(ssm_state * dA + x.unsqueeze(-1) * dB)
y = Tensor.einsum("bdn,bn->bd", ssm_state, C)
y = y + self.D * x
y = y * z.swish() # (B D)
@@ -289,63 +220,34 @@ class MambaMixer:
out = self.out_proj(y)
return out.unsqueeze(1), conv_state, ssm_state
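
For context, step keeps two per-layer caches: conv_state, a rolling window of the last d_conv inputs (so the depthwise conv1d collapses to a multiply-and-sum against the squeezed kernel), and ssm_state, the SSM hidden state. A batchless NumPy sketch of the two updates, with activations and z-gating omitted and all names hypothetical:

import numpy as np

def decode_step_sketch(x, conv_state, ssm_state, conv_w, dt, A, B, C, D):
  # x, dt, D: (d,); conv_state, conv_w: (d, d_conv); ssm_state, A: (d, n); B, C: (n,)
  conv_state = np.concatenate([conv_state[:, 1:], x[:, None]], axis=1)  # shift new token in
  x = (conv_state * conv_w).sum(-1)             # depthwise conv over the cached window
  dA = np.exp(dt[:, None] * A)                  # discretized transition, (d, n)
  dB = dt[:, None] * B[None, :]                 # discretized input matrix, (d, n)
  ssm_state = ssm_state * dA + x[:, None] * dB  # one recurrence step
  y = (ssm_state * C[None, :]).sum(-1) + D * x  # readout plus skip
  return y, conv_state, ssm_state
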
def _get_states_from_cache(
self, inference_params, batch_size, initialize_states=False
):
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
assert self.layer_idx is not None
if self.layer_idx not in inference_params.key_value_memory_dict:
batch_shape = (batch_size,)
conv_state = Tensor.zeros(
batch_size, self.dim * self.expand, self.d_conv
).contiguous().realize()
ssm_state = Tensor.zeros(
batch_size, self.dim * self.expand, self.d_state
).realize()
conv_state = Tensor.zeros(batch_size, self.dim * self.expand, self.d_conv).contiguous().realize()
ssm_state = Tensor.zeros(batch_size, self.dim * self.expand, self.d_state).realize()
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
else:
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
return conv_state, ssm_state
class MambaBlock:
def __init__(
self,
dim: int,
norm_eps: float = 1e-5,
rms_norm: bool = True,
layer_idx: Optional[int] = None,
):
def __init__(self, dim: int, norm_eps: float = 1e-5, rms_norm: bool = True, layer_idx: Optional[int] = None):
self.mixer = MambaMixer(dim, layer_idx=layer_idx)
if rms_norm:
self.norm = RMSNorm(dim, norm_eps)
else:
raise NotImplementedError
def __call__(
self,
hidden_states: Tensor,
residual: Optional[Tensor] = None,
inference_params=None,
):
def __call__(self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None):
residual = (hidden_states + residual) if residual is not None else hidden_states
hidden_states = self.norm(residual)
hidden_states = self.mixer(hidden_states, inference_params=inference_params)
return hidden_states, residual
class MambaBackbone:
def __init__(
self,
dim: int,
n_layers: int,
vocab_size: int,
rms_norm: bool = True,
norm_eps: float = 1e-5,
):
def __init__(self, dim: int, n_layers: int, vocab_size: int, rms_norm: bool = True, norm_eps: float = 1e-5):
self.embedding = nn.Embedding(vocab_size, dim)
self.layers = [
MambaBlock(dim, rms_norm=rms_norm, layer_idx=i) for i in range(n_layers)
]
self.layers = [MambaBlock(dim, rms_norm=rms_norm, layer_idx=i) for i in range(n_layers)]
if rms_norm:
self.norm_f = RMSNorm(dim, norm_eps)
@@ -353,20 +255,15 @@ class MambaBackbone:
hidden_states = self.embedding(input_ids)
residual = None
for layer in self.layers:
hidden_states, residual = layer(
hidden_states, residual, inference_params=inference_params
)
hidden_states, residual = layer(hidden_states, residual, inference_params=inference_params)
residual = (hidden_states + residual) if residual is not None else hidden_states
hidden_states = self.norm_f(residual)
return hidden_states
class Mamba:
def __init__(
self, dim: int, n_layers: int, vocab_size: int, pad_vocab_size_multiple: int = 1
):
def __init__(self, dim: int, n_layers: int, vocab_size: int, pad_vocab_size_multiple: int = 1):
if vocab_size % pad_vocab_size_multiple != 0:
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
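
With the default GPT-NeoX vocab this rounds 50277 up to the next multiple of 8: 50277 % 8 == 5, so 3 is added and the embedding is sized 50280:

vocab_size, mult = 50277, 8
if vocab_size % mult != 0:
  vocab_size += mult - (vocab_size % mult)
assert vocab_size == 50280
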
@@ -400,7 +297,6 @@ class Mamba:
class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference."""
max_seqlen: int
max_batch_size: int
seqlen_offset: int = 0
@@ -415,48 +311,29 @@ class InferenceParams:
if self.lengths_per_sample is not None:
self.lengths_per_sample.zero_()
def generate(model,
tokenizer,
prompt: str,
n_tokens_to_gen: int = 10,
sample: bool = False,
top_k: int = None):
def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 10, sample: bool = False, top_k: int = None):
tks = tokenizer(prompt)["input_ids"]
while(len(tks)<4):tks = [50279] + tks
while len(tks) < 4:
tks = [50279] + tks
# TODO: sampling
temperature = 0.5
start_pos = 0
inference_params = InferenceParams(max_seqlen=1, max_batch_size=1, seqlen_offset=0)
for i in tqdm(range(n_tokens_to_gen), desc="Speed Gen"):
for _ in tqdm(range(n_tokens_to_gen), desc="Speed Gen"):
logits = model(Tensor([tks[start_pos:]]), inference_params, start_pos, jit=False)
inference_params.seqlen_offset = len(tks)
tok = (logits[:, -1, :]).softmax().argmax(axis=-1).item()
tok = logits[:, -1, :].argmax(axis=-1).item()
start_pos = len(tks)
tks.append(tok)
output_completions = ''.join([tokenizer.decode(output) for output in tks])
return output_completions
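
Dropping the .softmax() before .argmax() in the hunk above is safe for greedy decoding: softmax is strictly increasing elementwise, so it never changes which logit is largest. For example:

import numpy as np
logits = np.array([1.3, -0.2, 4.1, 0.7])
probs = np.exp(logits - logits.max()); probs /= probs.sum()  # softmax
assert int(logits.argmax()) == int(probs.argmax())           # same greedy token
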
if __name__ == "__main__":
TORCHOUTPUT = '''Why is gravity \nso important?\nBecause it's the only'''
parser = argparse.ArgumentParser(
description="Run Mamba in tinygrad",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--prompt", type=str, default='Why is gravity ', help="Prompt for LLM completion"
)
parser.add_argument(
"--size",
type=str,
default="370m",
help=f"Size of model to use [{', '.join([k for k in MODELS.keys()])}]",
)
parser.add_argument(
"--n_tokens",
type=int,
default=10,
help="Number of tokens to generate",
)
parser = argparse.ArgumentParser(description="Run Mamba in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--prompt", type=str, default="Why is gravity ", help="Prompt for LLM completion")
parser.add_argument("--size", type=str, default="370m",
help=f"Size of model to use [{', '.join([k for k in MODELS.keys()])}]")
parser.add_argument("--n_tokens", type=int, default=10, help="Number of tokens to generate")
args = parser.parse_args()
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
@@ -467,4 +344,5 @@ if __name__ == "__main__":
tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks)
print(tinyoutput)
print('TIME: ', time.time() - s)
TORCHOUTPUT = "Why is gravity \nso important?\nBecause it's the only"
print('Outputs Match:', tinyoutput == TORCHOUTPUT)
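
To try the cleaned-up script (assuming it lives at examples/mamba.py in the tinygrad repo, as the sys.path.append(os.getcwd()) setup suggests, and that transformers is installed for the tokenizer):

python3 examples/mamba.py --size 370m --prompt "Why is gravity " --n_tokens 10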