cleanup mamba (#4004)

Make the example read nicer, clean up some movement methods, and simplify the math.
The 790m, 1.4b, and 2.8b models do not really run.
Sampling is not implemented.
The JIT path is incorrect.
There is still some dead code, a wrong code path, and leftover code copied from the torch reference.
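
One of the math simplifications, for reference: the old and new forms of inv_dt in MambaMixer are both the inverse of softplus, so softplus(dt_proj.bias) recovers the sampled dt. A minimal sketch (not part of the diff) that checks the identity:

  import math
  from tinygrad import Tensor

  dt = Tensor.uniform(8, low=math.log(0.001), high=math.log(0.1)).exp().maximum(1e-4)
  old_inv_dt = dt + (-((-dt).exp() - Tensor.ones(8))).log()  # removed form
  new_inv_dt = dt + (1 - (-dt).exp()).log()                  # form kept by this cleanup
  assert (old_inv_dt - new_inv_dt).abs().max().item() < 1e-4
  assert (new_inv_dt.softplus() - dt).abs().max().item() < 1e-4  # softplus undoes the init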
This commit is contained in:
chenyu
2024-03-30 02:50:13 -04:00
committed by GitHub
parent f35f9d32f2
commit aa76d566c2


@@ -1,62 +1,29 @@
-import os, sys, math, argparse
+import os, sys, math, argparse, time
+from tqdm import tqdm
 sys.path.append(os.getcwd())
-from typing import Any, Optional
+from typing import Any, Optional, Dict
 from dataclasses import dataclass, field
-import time
-from tinygrad import Tensor, dtypes, nn
-from tinygrad.engine.jit import TinyJit
-from tinygrad.helpers import fetch
-from tinygrad.nn.state import get_state_dict, load_state_dict, torch_load
+from tinygrad import Tensor, TinyJit, nn
+from tinygrad.helpers import fetch
+from tinygrad.nn.state import load_state_dict, torch_load
 from extra.models.llama import RMSNorm
-from tqdm import tqdm
 from transformers import AutoTokenizer
-# from einops import rearrange, repeat
-import numpy as np

 MODELS = {
-  "130m": {
-    "dim": 768,
-    "n_layers": 24,
-    "vocab_size": 50277,
-    "pad_vocab_size_multiple": 8,
-  },
-  "370m": {
-    "dim": 1024,
-    "n_layers": 48,
-    "vocab_size": 50277,
-    "pad_vocab_size_multiple": 8,
-  },
-  "790m": {
-    "dim": 1536,
-    "n_layers": 48,
-    "vocab_size": 50277,
-    "pad_vocab_size_multiple": 8,
-  },
-  "1.4b": {
-    "dim": 2048,
-    "n_layer": 48,
-    "vocab_size": 50277,
-    "pad_vocab_size_multiple": 8,
-  },
-  "2.8b": {
-    "dim": 2560,
-    "n_layer": 64,
-    "vocab_size": 50277,
-    "pad_vocab_size_multiple": 8,
-  },
+  "130m": {"dim": 768, "n_layers": 24, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
+  "370m": {"dim": 1024, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
+  "790m": {"dim": 1536, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
+  "1.4b": {"dim": 2048, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
+  "2.8b": {"dim": 2560, "n_layers": 64, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
 }

-def fetch_weights(model_name: str):
-  if model_name not in MODELS.keys():
-    raise Exception(f"Requested unknown mamba model: {model_name}")
-  downloaded = fetch(
-    f"https://huggingface.co/state-spaces/mamba-{model_name}/resolve/main/pytorch_model.bin?download=true"
-  )
-  weights = torch_load(downloaded)
-  return weights
+def fetch_weights(model_name: str) -> Dict[str, Tensor]:
+  if model_name not in MODELS:
+    raise ValueError(f"Requested unknown mamba model: {model_name}")
+  downloaded = fetch(f"https://huggingface.co/state-spaces/mamba-{model_name}/resolve/main/pytorch_model.bin?download=true")
+  return torch_load(downloaded)

 def selective_scan_ref(
   u,
@@ -83,7 +50,6 @@ def selective_scan_ref(
   out: r(B D L)
   last_state (optional): r(B D dstate) or c(B D dstate)
   """
-  dtype_in = u.dtype
   u = u.float()
   delta = delta.float()
   if delta_bias is not None:
@@ -102,10 +68,10 @@ def selective_scan_ref(
   if len(B.shape) == 3:
     deltaB_u = Tensor.einsum("bdl,bnl,bdl->bdln", delta, B, u)
   else:
-    B = B.repeat((1,dim//B.shape[1],1,1))
+    B = B.repeat((1, dim // B.shape[1], 1, 1))
     deltaB_u = Tensor.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
   if is_variable_C and len(C.shape) == 4:
-    C = C.repeat((1,dim//C.shape[1],1,1))
+    C = C.repeat((1, dim // C.shape[1], 1, 1))
   last_state = None
   for i in range(u.shape[2]):
     x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
@@ -120,12 +86,11 @@ def selective_scan_ref(
       last_state = x
     ys.append(y)
   y = Tensor.stack(ys, dim=2) # (batch dim L)
-  out = y if D is None else y + u * D.reshape((D.numel(), 1))
+  out = y if D is None else y + u * D.reshape((-1, 1))
   if z is not None:
     out = out * z.silu()
   return out if not return_last_state else (out, last_state)

 class MambaMixer:
   def __init__(
     self,
@@ -147,20 +112,14 @@ class MambaMixer:
     self.d_state = d_state
     self.d_conv = d_conv
     self.expand = expand
-    self.d_inner = int(self.expand * self.dim)
+    self.d_inner = self.expand * self.dim
     self.dt_rank = math.ceil(self.dim / 16) if dt_rank == "auto" else dt_rank
     self.layer_idx = layer_idx

     self.in_proj = nn.Linear(self.dim, self.d_inner * 2, bias=bias)
-    self.conv1d = nn.Conv1d(
-      in_channels=self.d_inner,
-      out_channels=self.d_inner,
-      bias=conv_bias,
-      kernel_size=d_conv,
-      groups=self.d_inner,
-      padding=d_conv - 1,
-    )
+    self.conv1d = nn.Conv1d(in_channels=self.d_inner, out_channels=self.d_inner, bias=conv_bias,
+                            kernel_size=d_conv, groups=self.d_inner, padding=d_conv-1)
     self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
     self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
@@ -170,30 +129,17 @@ class MambaMixer:
     if dt_init == "constant":
       self.dt_proj.weight = Tensor.full(self.dt_proj.weight.shape, dt_init_std)
     elif dt_init == "random":
-      self.dt_proj.weight = Tensor.uniform(
-        self.dt_proj.weight.shape, low=-dt_init_std, high=dt_init_std
-      )
+      self.dt_proj.weight = Tensor.uniform(self.dt_proj.weight.shape, low=-dt_init_std, high=dt_init_std)
     else:
       raise NotImplementedError

-    dt = (
-      (
-        Tensor.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min))
-        + math.log(dt_min)
-      )
-      .exp()
-      .maximum(dt_init_floor)
-    )
-    inv_dt = (
-      dt + (-((-dt).exp() - Tensor.ones(*dt.shape))).log()
-    )
+    dt = Tensor.uniform(self.d_inner, low=math.log(dt_min), high=math.log(dt_max)).exp().maximum(dt_init_floor)
+    inv_dt = dt + (1 - (-dt).exp()).log()
     self.dt_proj.bias.assign(inv_dt)

     # S4D real initialization
-    self.A_log = (
-      Tensor.arange(1, self.d_state + 1).repeat([self.d_inner, 1]).contiguous().log()
-    )
+    self.A_log = Tensor.arange(1, self.d_state+1).repeat([self.d_inner, 1]).log()

     # D "skip" parameter
     self.D = Tensor.ones(self.d_inner) # Keep in fp32
@@ -215,7 +161,7 @@ class MambaMixer:
     xz = xz.reshape(xz.shape[0],xz.shape[1]//seqlen, seqlen).permute(1,0,2)
     if self.in_proj.bias is not None:
-      xz = xz + self.in_proj.bias.reshape((self.in_proj.bias.numel(), 1))
+      xz = xz + self.in_proj.bias.reshape((-1, 1))

     A = -self.A_log.exp()
     x, z = xz.chunk(2, dim=1)
@@ -228,41 +174,28 @@ class MambaMixer:
     dt, B, C = Tensor.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
     dt = self.dt_proj.weight @ dt.T
     dt = dt.reshape(dt.shape[0], dt.shape[1]//seqlen, seqlen).permute(1,0,2)
-    B = B.reshape(B.shape[0]//seqlen, seqlen, B.shape[1]).permute(0,2,1).contiguous()
-    C = C.reshape(C.shape[0]//seqlen, seqlen, C.shape[1]).permute(0,2,1).contiguous()
+    B = B.reshape(B.shape[0]//seqlen, seqlen, B.shape[1]).permute(0,2,1)
+    C = C.reshape(C.shape[0]//seqlen, seqlen, C.shape[1]).permute(0,2,1)

-    y = selective_scan_ref(
-      x,
-      dt,
-      A,
-      B,
-      C,
-      self.D,
-      z=z,
-      delta_bias=self.dt_proj.bias,
-      delta_softplus=True,
-      return_last_state=ssm_state is not None,
-    )
+    # TODO: actually implement selective_scan_fn
+    y = selective_scan_ref(x, dt, A, B, C, self.D, z=z, delta_bias=self.dt_proj.bias, delta_softplus=True,
+                           return_last_state=ssm_state is not None)
     if ssm_state is not None:
       y, last_state = y
       ssm_state.assign(last_state)
     y = y.permute(0,2,1)
     out = self.out_proj(y)
     return out

-  def step(self, hidden_states, conv_state, ssm_state):
-    assert (
-      hidden_states.shape[1] == 1
-    ), f"Only support decoding with 1 token at a time for now, attempted {hidden_states.shape[1]}"
+  def step(self, hidden_states: Tensor, conv_state: Tensor, ssm_state: Tensor):
+    assert hidden_states.shape[1] == 1, f"Only support decoding with 1 token at a time for now, attempted {hidden_states.shape[1]}"
     xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
     x, z = xz.chunk(2, dim=-1) # (B D)

     # Conv step
     conv_state.assign(conv_state[:, :, 1:].cat(x.unsqueeze(-1), dim=-1))
-    x = (conv_state * self.conv1d.weight.reshape(self.conv1d.weight.shape[0],self.conv1d.weight.shape[2])).sum(-1)
+    x = (conv_state * self.conv1d.weight.squeeze(1)).sum(-1)
     if self.conv1d.bias is not None:
       x = x + self.conv1d.bias
     x = x.swish()
@@ -275,13 +208,11 @@ class MambaMixer:
     dt = self.dt_proj.weight @ dt.T
     A = -self.A_log.exp()

     # SSM step
     dt = (dt + self.dt_proj.bias.unsqueeze(-1)).softplus()
-    # TODO: Tensor.einsum?
     dA = Tensor.einsum("db,dn->bdn", dt, A).exp()
     dB = Tensor.einsum("db,bn->bdn", dt, B)
-    ssm_state.assign(ssm_state * dA + x.reshape(x.shape[0],x.shape[1], 1) * dB)
+    ssm_state.assign(ssm_state * dA + x.unsqueeze(-1) * dB)
     y = Tensor.einsum("bdn,bn->bd", ssm_state, C)
     y = y + self.D * x
     y = y * z.swish() # (B D)
@@ -289,63 +220,34 @@ class MambaMixer:
     out = self.out_proj(y)
     return out.unsqueeze(1), conv_state, ssm_state

-  def _get_states_from_cache(
-    self, inference_params, batch_size, initialize_states=False
-  ):
+  def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
     assert self.layer_idx is not None
     if self.layer_idx not in inference_params.key_value_memory_dict:
-      batch_shape = (batch_size,)
-      conv_state = Tensor.zeros(
-        batch_size, self.dim * self.expand, self.d_conv
-      ).contiguous().realize()
-      ssm_state = Tensor.zeros(
-        batch_size, self.dim * self.expand, self.d_state
-      ).realize()
+      conv_state = Tensor.zeros(batch_size, self.dim * self.expand, self.d_conv).contiguous().realize()
+      ssm_state = Tensor.zeros(batch_size, self.dim * self.expand, self.d_state).realize()
       inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
     else:
       conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
     return conv_state, ssm_state

 class MambaBlock:
-  def __init__(
-    self,
-    dim: int,
-    norm_eps: float = 1e-5,
-    rms_norm: bool = True,
-    layer_idx: Optional[int] = None,
-  ):
+  def __init__(self, dim: int, norm_eps: float = 1e-5, rms_norm: bool = True, layer_idx: Optional[int] = None):
     self.mixer = MambaMixer(dim, layer_idx=layer_idx)
     if rms_norm:
       self.norm = RMSNorm(dim, norm_eps)
     else:
       raise NotImplementedError

-  def __call__(
-    self,
-    hidden_states: Tensor,
-    residual: Optional[Tensor] = None,
-    inference_params=None,
-  ):
+  def __call__(self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None):
     residual = (hidden_states + residual) if residual is not None else hidden_states
     hidden_states = self.norm(residual)
     hidden_states = self.mixer(hidden_states, inference_params=inference_params)
     return hidden_states, residual

 class MambaBackbone:
-  def __init__(
-    self,
-    dim: int,
-    n_layers: int,
-    vocab_size: int,
-    rms_norm: bool = True,
-    norm_eps: float = 1e-5,
-  ):
+  def __init__(self, dim: int, n_layers: int, vocab_size: int, rms_norm: bool = True, norm_eps: float = 1e-5):
     self.embedding = nn.Embedding(vocab_size, dim)
-    self.layers = [
-      MambaBlock(dim, rms_norm=rms_norm, layer_idx=i) for i in range(n_layers)
-    ]
+    self.layers = [MambaBlock(dim, rms_norm=rms_norm, layer_idx=i) for i in range(n_layers)]
     if rms_norm:
       self.norm_f = RMSNorm(dim, norm_eps)
@@ -353,20 +255,15 @@ class MambaBackbone:
     hidden_states = self.embedding(input_ids)
     residual = None
     for layer in self.layers:
-      hidden_states, residual = layer(
-        hidden_states, residual, inference_params=inference_params
-      )
+      hidden_states, residual = layer(hidden_states, residual, inference_params=inference_params)
     residual = (hidden_states + residual) if residual is not None else hidden_states
     hidden_states = self.norm_f(residual)
     return hidden_states

 class Mamba:
-  def __init__(
-    self, dim: int, n_layers: int, vocab_size: int, pad_vocab_size_multiple: int = 1
-  ):
+  def __init__(self, dim: int, n_layers: int, vocab_size: int, pad_vocab_size_multiple: int = 1):
     if vocab_size % pad_vocab_size_multiple != 0:
       vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
@@ -400,7 +297,6 @@ class Mamba:
 class InferenceParams:
   """Inference parameters that are passed to the main model in order
   to efficienly calculate and store the context during inference."""
   max_seqlen: int
   max_batch_size: int
   seqlen_offset: int = 0
@@ -415,48 +311,29 @@ class InferenceParams:
     if self.lengths_per_sample is not None:
       self.lengths_per_sample.zero_()

-def generate(model,
-             tokenizer,
-             prompt: str,
-             n_tokens_to_gen: int = 10,
-             sample: bool = False,
-             top_k: int = None):
+def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 10, sample: bool = False, top_k: int = None):
   tks = tokenizer(prompt)["input_ids"]
-  while(len(tks)<4):tks = [50279] + tks
+  while len(tks) < 4:
+    tks = [50279] + tks
+  # TODO: sampling
   temperature = 0.5
   start_pos = 0
   inference_params = InferenceParams(max_seqlen=1, max_batch_size=1, seqlen_offset=0)
-  for i in tqdm(range(n_tokens_to_gen), desc="Speed Gen"):
+  for _ in tqdm(range(n_tokens_to_gen), desc="Speed Gen"):
     logits = model(Tensor([tks[start_pos:]]), inference_params, start_pos, jit=False)
     inference_params.seqlen_offset = len(tks)
-    tok = (logits[:, -1, :]).softmax().argmax(axis=-1).item()
+    tok = logits[:, -1, :].argmax(axis=-1).item()
     start_pos = len(tks)
     tks.append(tok)

   output_completions = ''.join([tokenizer.decode(output) for output in tks])
   return output_completions

 if __name__ == "__main__":
-  TORCHOUTPUT = '''Why is gravity \nso important?\nBecause it's the only'''
-
-  parser = argparse.ArgumentParser(
-    description="Run Mamba in tinygrad",
-    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-  )
-  parser.add_argument(
-    "--prompt", type=str, default='Why is gravity ', help="Prompt for LLM completion"
-  )
-  parser.add_argument(
-    "--size",
-    type=str,
-    default="370m",
-    help=f"Size of model to use [{', '.join([k for k in MODELS.keys()])}]",
-  )
-  parser.add_argument(
-    "--n_tokens",
-    type=int,
-    default=10,
-    help="Number of tokens to generate",
-  )
+  parser = argparse.ArgumentParser(description="Run Mamba in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+  parser.add_argument("--prompt", type=str, default="Why is gravity ", help="Prompt for LLM completion")
+  parser.add_argument("--size", type=str, default="370m",
+                      help=f"Size of model to use [{', '.join([k for k in MODELS.keys()])}]")
+  parser.add_argument("--n_tokens", type=int, default=10, help="Number of tokens to generate")
   args = parser.parse_args()

   tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
@@ -467,4 +344,5 @@ if __name__ == "__main__":
   tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks)
   print(tinyoutput)
   print('TIME: ', time.time() - s)
+  TORCHOUTPUT = "Why is gravity \nso important?\nBecause it's the only"
   print('Outputs Match:', tinyoutput == TORCHOUTPUT)
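
For reference, with the defaults above the cleaned-up example can be run as follows (the path is an assumption based on the usual tinygrad layout, e.g. examples/mamba.py; it is not shown in this diff):

  python3 examples/mamba.py --size 370m --prompt "Why is gravity " --n_tokens 10

This fetches the 370m weights from the state-spaces repo on Hugging Face, generates 10 tokens greedily with argmax (sampling is still a TODO), and compares the completion against the hard-coded torch reference string.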