From 9b5e15db6ec76868a6e179e3e68494a4d44d5191 Mon Sep 17 00:00:00 2001
From: reddyn12 <72528507+reddyn12@users.noreply.github.com>
Date: Thu, 28 Mar 2024 20:49:12 -0400
Subject: [PATCH] Mamba Implementation (#3456)

* first commit

* state back to orig

* mamba comparisions

* rm file

* rename file

* use Tensor.einsum and mke default model 370M

* Cleaned code and made a comparision test

* Simplyfy pull request. Only has 1 mamba implementation now.

* Update prompt

* rm whitespaces

* last space

* remove Einops dependency

* rm unused code

* add tests

* rm print statement

* rm imports

* skip CLANG

* Update skipIf description

* skip model test in CI and add CLANG fix

* rm Device import

* don't be stupid

* Fix conv assign

When the prompt is too short, the logic for conv_state assign messes up. This can be fixed when padding the tokenized array to min length of 4. I padded using the empty string token, but idk if proper practice is to use the PAD token

* fix p1

* temp

* fix jit import

---------

Co-authored-by: schlimeszn
Co-authored-by: reddyn
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
---
 .gitignore                |   1 +
 examples/mamba.py         | 481 ++++++++++++++++++++++++++++++++++++++
 test/models/test_mamba.py |  24 ++
 3 files changed, 506 insertions(+)
 create mode 100644 examples/mamba.py
 create mode 100644 test/models/test_mamba.py

diff --git a/.gitignore b/.gitignore
index 7f5d7c5362..272157cd78 100644
--- a/.gitignore
+++ b/.gitignore
@@ -28,6 +28,7 @@ extra/datasets/kits19/
 extra/datasets/squad/
 extra/datasets/img_align_celeba*
 extra/datasets/open-images-v6-mlperf
+extra/datasets/open-images-v6TEST
 extra/datasets/kits/
 extra/datasets/COCO/
 extra/datasets/audio*
diff --git a/examples/mamba.py b/examples/mamba.py
new file mode 100644
index 0000000000..c016e39c18
--- /dev/null
+++ b/examples/mamba.py
@@ -0,0 +1,481 @@
+import os, sys, math, argparse
+from tqdm import tqdm
+sys.path.append(os.getcwd())
+from typing import Any, Optional
+from dataclasses import dataclass, field
+import time
+from tinygrad import Tensor, dtypes, nn
+from tinygrad.engine.jit import TinyJit
+from tinygrad.helpers import fetch
+from tinygrad.nn.state import get_state_dict, load_state_dict, torch_load
+
+from extra.models.llama import RMSNorm
+from transformers import AutoTokenizer
+# from einops import rearrange, repeat
+import numpy as np
+
+MODELS = {
+  "130m": {
+    "dim": 768,
+    "n_layers": 24,
+    "vocab_size": 50277,
+    "pad_vocab_size_multiple": 8,
+  },
+  "370m": {
+    "dim": 1024,
+    "n_layers": 48,
+    "vocab_size": 50277,
+    "pad_vocab_size_multiple": 8,
+  },
+  "790m": {
+    "dim": 1536,
+    "n_layers": 48,
+    "vocab_size": 50277,
+    "pad_vocab_size_multiple": 8,
+  },
+  "1.4b": {
+    "dim": 2048,
+    "n_layers": 48,
+    "vocab_size": 50277,
+    "pad_vocab_size_multiple": 8,
+  },
+  "2.8b": {
+    "dim": 2560,
+    "n_layers": 64,
+    "vocab_size": 50277,
+    "pad_vocab_size_multiple": 8,
+  },
+}
+
+
+def fetch_weights(model_name: str):
+  """Download the pretrained checkpoint for `model_name` (a key of MODELS) and load it with torch_load."""
+  if model_name not in MODELS:
+    raise ValueError(f"Requested unknown mamba model: {model_name}")
+  downloaded = fetch(
+    f"https://huggingface.co/state-spaces/mamba-{model_name}/resolve/main/pytorch_model.bin?download=true"
+  )
+  weights = torch_load(downloaded)
+  return weights
+
+
+def selective_scan_ref(
+  u,
+  delta,
+  A,
+  B,
+  C,
+  D=None,
+  z=None,
+  delta_bias=None,
+  delta_softplus=False,
+  return_last_state=False,
+):
+  """
+  Reference selective scan: applies the state update one timestep at a time in a Python loop.
+
+  u: r(B D L)
+  delta: r(B D L)
+  A: c(D N) or r(D N)
+  B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
+  C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
+  D: r(D)
+  z: r(B D L)
+  delta_bias: r(D), fp32
+
+  out: r(B D L)
+  last_state (optional): r(B D dstate) or c(B D dstate)
+  """
+  u = u.float()
+  delta = delta.float()
+  if delta_bias is not None:
+    delta = delta + delta_bias[..., None].float()
+  if delta_softplus:
+    delta = delta.softplus()
+  batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
+  is_variable_B = len(B.shape) >= 3
+  is_variable_C = len(C.shape) >= 3
+  x = Tensor.zeros(batch, dim, dstate)
+  ys = []
+  deltaA = Tensor.einsum("bdl,dn->bdln", delta, A).exp()
+  if not is_variable_B:
+    deltaB_u = Tensor.einsum("bdl,dn,bdl->bdln", delta, B, u)
+  else:
+    if len(B.shape) == 3:
+      deltaB_u = Tensor.einsum("bdl,bnl,bdl->bdln", delta, B, u)
+    else:
+      B = B.repeat((1,dim//B.shape[1],1,1))
+      deltaB_u = Tensor.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
+  if is_variable_C and len(C.shape) == 4:
+    C = C.repeat((1,dim//C.shape[1],1,1))
+  last_state = None
+  for i in range(u.shape[2]):
+    x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
+    if not is_variable_C:
+      y = Tensor.einsum("bdn,dn->bd", x, C)
+    else:
+      if len(C.shape) == 3:
+        y = Tensor.einsum("bdn,bn->bd", x, C[:, :, i])
+      else:
+        y = Tensor.einsum("bdn,bdn->bd", x, C[:, :, :, i])
+    if i == u.shape[2] - 1:
+      last_state = x
+    ys.append(y)
+  y = Tensor.stack(ys, dim=2) # (batch dim L)
+  out = y if D is None else y + u * D.reshape((D.numel(), 1))
+  if z is not None:
+    out = out * z.silu()
+  return out if not return_last_state else (out, last_state)
+
+
+class MambaMixer:
+  """Mamba mixer layer: input projection, depthwise causal conv1d, selective SSM scan and output projection."""
+  def __init__(
+    self,
+    dim,
+    d_state=16,
+    d_conv=4,
+    expand=2,
+    dt_rank="auto",
+    dt_min=0.001,
+    dt_max=0.1,
+    dt_init="random",
+    dt_scale=1.0,
+    dt_init_floor=1e-4,
+    conv_bias=True,
+    bias=False,
+    layer_idx=None,
+  ):
+    self.dim = dim
+    self.d_state = d_state
+    self.d_conv = d_conv
+    self.expand = expand
+    self.d_inner = int(self.expand * self.dim)
+    self.dt_rank = math.ceil(self.dim / 16) if dt_rank == "auto" else dt_rank
+    self.layer_idx = layer_idx
+
+    self.in_proj = nn.Linear(self.dim, self.d_inner * 2, bias=bias)
+
+    self.conv1d = nn.Conv1d(
+      in_channels=self.d_inner,
+      out_channels=self.d_inner,
+      bias=conv_bias,
+      kernel_size=d_conv,
+      groups=self.d_inner,
+      padding=d_conv - 1,
+    )
+
+    self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
+    self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
+
+    # Initialize special dt projection to preserve variance at initialization
+    dt_init_std = self.dt_rank**-0.5 * dt_scale
+    if dt_init == "constant":
+      self.dt_proj.weight = Tensor.full(self.dt_proj.weight.shape, dt_init_std)
+    elif dt_init == "random":
+      self.dt_proj.weight = Tensor.uniform(
+        self.dt_proj.weight.shape, low=-dt_init_std, high=dt_init_std
+      )
+    else:
+      raise NotImplementedError
+
+    dt = (
+      (
+        Tensor.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min))
+        + math.log(dt_min)
+      )
+      .exp()
+      .maximum(dt_init_floor)
+    )
+    # Inverse of softplus (dt + log(1 - exp(-dt))), so that softplus(dt_proj.bias) == dt
+    inv_dt = (
+      dt + (-((-dt).exp() - Tensor.ones(*dt.shape))).log()
+    )
+
+    self.dt_proj.bias.assign(inv_dt)
+
+    # S4D real initialization
+    self.A_log = (
+      Tensor.arange(1, self.d_state + 1).repeat([self.d_inner, 1]).contiguous().log()
+    )
+
+    # D "skip" parameter
+    self.D = Tensor.ones(self.d_inner) # Keep in fp32
+
+    self.out_proj = nn.Linear(self.d_inner, self.dim, bias=bias)
+
+  def __call__(self, hidden_states: Tensor, inference_params=None):
+    """Mix (batch, seqlen, dim) hidden states; once the cache is warm (seqlen_offset > 0) only the last token runs, via step()."""
+    batch, seqlen, dim = hidden_states.shape
+
+    conv_state, ssm_state = None, None
+    if inference_params is not None:
+      conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
+      if inference_params.seqlen_offset > 0:
+        # The states are updated inplace
+        out, _, _ = self.step(hidden_states[:, -1:, :], conv_state, ssm_state)
+        return out
+
+    xz = self.in_proj.weight @ hidden_states.permute(2,0,1).reshape(hidden_states.shape[2],hidden_states.shape[1]*hidden_states.shape[0])
+    xz = xz.reshape(xz.shape[0],xz.shape[1]//seqlen, seqlen).permute(1,0,2)
+
+    if self.in_proj.bias is not None:
+      xz = xz + self.in_proj.bias.reshape((self.in_proj.bias.numel(), 1))
+
+    A = -self.A_log.exp()
+    x, z = xz.chunk(2, dim=1)
+    # Compute short convolution
+    if conv_state is not None:
+      conv_state.assign(x[:, :, -self.d_conv :]) # Update state (B D W)
+    x = self.conv1d(x)[..., :seqlen].swish()
+
+    x_dbl = self.x_proj(x.permute(0,2,1).reshape(x.shape[0]*x.shape[2], x.shape[1]))
+    dt, B, C = Tensor.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
+    dt = self.dt_proj.weight @ dt.T
+    dt = dt.reshape(dt.shape[0], dt.shape[1]//seqlen, seqlen).permute(1,0,2)
+    B = B.reshape(B.shape[0]//seqlen, seqlen, B.shape[1]).permute(0,2,1).contiguous()
+    C = C.reshape(C.shape[0]//seqlen, seqlen, C.shape[1]).permute(0,2,1).contiguous()
+
+    y = selective_scan_ref( # TODO: actually implement selective_scan_fn
+      x,
+      dt,
+      A,
+      B,
+      C,
+      self.D,
+      z=z,
+      delta_bias=self.dt_proj.bias,
+      delta_softplus=True,
+      return_last_state=ssm_state is not None,
+    )
+
+    if ssm_state is not None:
+      y, last_state = y
+      ssm_state.assign(last_state)
+
+    y = y.permute(0,2,1)
+    out = self.out_proj(y)
+
+    return out
+
+  def step(self, hidden_states, conv_state, ssm_state):
+    """Single-token decode step: assigns new values into conv_state and ssm_state, returns (out, conv_state, ssm_state)."""
+    assert (
+      hidden_states.shape[1] == 1
+    ), f"Only support decoding with 1 token at a time for now, attempted {hidden_states.shape[1]}"
+    xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
+    x, z = xz.chunk(2, dim=-1) # (B D)
+
+    # Conv step
+    conv_state.assign(conv_state[:, :, 1:].cat(x.unsqueeze(-1), dim=-1))
+    x = (conv_state * self.conv1d.weight.reshape(self.conv1d.weight.shape[0],self.conv1d.weight.shape[2])).sum(-1)
+    if self.conv1d.bias is not None:
+      x = x + self.conv1d.bias
+    x = x.swish()
+
+    x_db = self.x_proj(x) # (B dt_rank+2*d_state)
+    dt = x_db[:, : self.dt_rank]
+    B = x_db[:, self.dt_rank : (self.dt_rank + self.d_state)]
+    C = x_db[:, (self.dt_rank + self.d_state) :]
+    # Don't add dt_bias here
+    dt = self.dt_proj.weight @ dt.T
+    A = -self.A_log.exp()
+
+
+    # SSM step
+    dt = (dt + self.dt_proj.bias.unsqueeze(-1)).softplus()
+    # TODO: Tensor.einsum?
+    dA = Tensor.einsum("db,dn->bdn", dt, A).exp()
+    dB = Tensor.einsum("db,bn->bdn", dt, B)
+    ssm_state.assign(ssm_state * dA + x.reshape(x.shape[0],x.shape[1], 1) * dB)
+    y = Tensor.einsum("bdn,bn->bd", ssm_state, C)
+    y = y + self.D * x
+    y = y * z.swish() # (B D)
+
+    out = self.out_proj(y)
+    return out.unsqueeze(1), conv_state, ssm_state
+
+  def _get_states_from_cache(
+    self, inference_params, batch_size, initialize_states=False
+  ):
+    assert self.layer_idx is not None
+    if self.layer_idx not in inference_params.key_value_memory_dict:
+      # First use of this layer: allocate zeroed conv and SSM states and cache them by layer_idx.
+      conv_state = Tensor.zeros(
+        batch_size, self.dim * self.expand, self.d_conv
+      ).contiguous().realize()
+      ssm_state = Tensor.zeros(
+        batch_size, self.dim * self.expand, self.d_state
+      ).realize()
+      inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
+    else:
+      conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
+    return conv_state, ssm_state
+
+
+class MambaBlock:
+  """Residual block: adds the incoming residual, normalizes the sum, then applies the MambaMixer."""
+  def __init__(
+    self,
+    dim: int,
+    norm_eps: float = 1e-5,
+    rms_norm: bool = True,
+    layer_idx: Optional[int] = None,
+  ):
+    self.mixer = MambaMixer(dim, layer_idx=layer_idx)
+    if rms_norm:
+      self.norm = RMSNorm(dim, norm_eps)
+    else:
+      raise NotImplementedError
+
+  def __call__(
+    self,
+    hidden_states: Tensor,
+    residual: Optional[Tensor] = None,
+    inference_params=None,
+  ):
+    residual = (hidden_states + residual) if residual is not None else hidden_states
+    hidden_states = self.norm(residual)
+    hidden_states = self.mixer(hidden_states, inference_params=inference_params)
+    return hidden_states, residual
+
+
+class MambaBackbone:
+  """Token embedding, a stack of n_layers MambaBlocks and a final norm."""
+  def __init__(
+    self,
+    dim: int,
+    n_layers: int,
+    vocab_size: int,
+    rms_norm: bool = True,
+    norm_eps: float = 1e-5,
+  ):
+    self.embedding = nn.Embedding(vocab_size, dim)
+    self.layers = [
+      MambaBlock(dim, rms_norm=rms_norm, layer_idx=i) for i in range(n_layers)
+    ]
+    if rms_norm:
+      self.norm_f = RMSNorm(dim, norm_eps)
+
+  def __call__(self, input_ids: Tensor, inference_params=None) -> Any:
+    hidden_states = self.embedding(input_ids)
+    residual = None
+    for layer in self.layers:
+      hidden_states, residual = layer(
+        hidden_states, residual, inference_params=inference_params
+      )
+
+    residual = (hidden_states + residual) if residual is not None else hidden_states
+    hidden_states = self.norm_f(residual)
+
+    return hidden_states
+
+
+class Mamba:
+  """Mamba language model: MambaBackbone followed by a bias-free linear head over the padded vocabulary."""
+  def __init__(
+    self, dim: int, n_layers: int, vocab_size: int, pad_vocab_size_multiple: int = 1
+  ):
+    if vocab_size % pad_vocab_size_multiple != 0:
+      vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
+
+    self.backbone = MambaBackbone(dim, n_layers, vocab_size)
+    self.lm_head = nn.Linear(dim, vocab_size, bias=False)
+
+    self.forward_jit = TinyJit(self.forward)
+
+  def forward(self, input_ids, inference_params, num_last_tokens):
+    hidden_states = self.backbone(input_ids, inference_params=inference_params)
+    if num_last_tokens > 0:
+      hidden_states = hidden_states[:, -num_last_tokens:]
+    return self.lm_head(hidden_states).realize()
+
+  def __call__(self, input_ids, inference_params=None, num_last_tokens=0, jit=True):
+    if inference_params is None:
+      return self.forward(input_ids, inference_params, num_last_tokens)
+    # Only calls made with a warm cache (seqlen_offset > 0) use the jitted forward; the first pass runs un-jitted.
+    if jit and inference_params.seqlen_offset > 0:
+      return self.forward_jit(input_ids, inference_params, num_last_tokens)
+    else:
+      return self.forward(input_ids, inference_params, num_last_tokens)
+  @staticmethod
+  def from_pretrained(model_name: str):
+    weights = fetch_weights(model_name)
+    model = Mamba(**MODELS[model_name])
+    load_state_dict(model, weights)
+
+    return model
+
+@dataclass
+class InferenceParams:
+  """Inference parameters that are passed to the main model in order
+  to efficiently calculate and store the context during inference."""
+
+  max_seqlen: int
+  max_batch_size: int
+  seqlen_offset: int = 0
+  batch_size_offset: int = 0
+  key_value_memory_dict: dict = field(default_factory=dict)
+  lengths_per_sample: Optional[Tensor] = None
+
+  def reset(self, max_seqlen, max_batch_size):
+    self.max_seqlen = max_seqlen
+    self.max_batch_size = max_batch_size
+    self.seqlen_offset = 0
+    if self.lengths_per_sample is not None:
+      self.lengths_per_sample.zero_()
+
+def generate(model,
+             tokenizer,
+             prompt: str,
+             n_tokens_to_gen: int = 10,
+             sample: bool = False,
+             top_k: Optional[int] = None):
+  """Greedily (argmax) decode n_tokens_to_gen tokens after prompt and return the decoded text; sample and top_k are currently unused."""
+  tks = tokenizer(prompt)["input_ids"]
+  while len(tks) < 4: tks = [50279] + tks  # left-pad short prompts to at least 4 tokens
+  temperature = 0.5
+  start_pos = 0
+  inference_params = InferenceParams(max_seqlen=1, max_batch_size=1, seqlen_offset=0)
+  for i in tqdm(range(n_tokens_to_gen), desc="Speed Gen"):
+    logits = model(Tensor([tks[start_pos:]]), inference_params, start_pos, jit=False)
+    inference_params.seqlen_offset = len(tks)
+    tok = (logits[:, -1, :]).softmax().argmax(axis=-1).item()
+    start_pos = len(tks)
+    tks.append(tok)
+  output_completions = ''.join([tokenizer.decode(output) for output in tks])
+  return output_completions
+
+if __name__ == "__main__":
+  TORCHOUTPUT = '''Why is gravity \nso important?\nBecause it's the only'''
+
+  parser = argparse.ArgumentParser(
+    description="Run Mamba in tinygrad",
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+  )
+  parser.add_argument(
+    "--prompt", type=str, default='Why is gravity ', help="Prompt for LLM completion"
+  )
+  parser.add_argument(
+    "--size",
+    type=str,
+    default="370m",
+    help=f"Size of model to use [{', '.join(MODELS)}]",
+  )
+  parser.add_argument(
+    "--n_tokens",
+    type=int,
+    default=10,
+    help="Number of tokens to generate",
+  )
+  args = parser.parse_args()
+
+  tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
+  model = Mamba.from_pretrained(args.size)
+  prompt = args.prompt
+  num_toks = args.n_tokens
+  s = time.time()
+  tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks)
+  print(tinyoutput)
+  print('TIME: ', time.time() - s)
+  print('Outputs Match:', tinyoutput == TORCHOUTPUT)
diff --git a/test/models/test_mamba.py b/test/models/test_mamba.py
new file mode 100644
index 0000000000..3213b7c6b9
--- /dev/null
+++ b/test/models/test_mamba.py
@@ -0,0 +1,24 @@
+import unittest
+from tinygrad.helpers import CI
+from examples.mamba import Mamba, generate
+from transformers import AutoTokenizer
+
+PROMPT = 'Why is gravity '
+TOKENIZER = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
+
+@unittest.skipIf(CI, "model is slow for CI")
+class TestMamba(unittest.TestCase):
+  def test_mamba_130M(self):
+    OUT_130M = '''Why is gravity \nnot a good idea?\n\nA:'''
+    model = Mamba.from_pretrained('130m')
+    tinyoutput = generate(model, TOKENIZER, PROMPT, n_tokens_to_gen=10)
+    self.assertEqual(OUT_130M, tinyoutput)
+    del model
+  def test_mamba_370M(self):
+    OUT_370M = '''Why is gravity \nso important?\nBecause it's the only'''
+    model = Mamba.from_pretrained('370m')
+    tinyoutput = generate(model, TOKENIZER, PROMPT, n_tokens_to_gen=10)
+    self.assertEqual(OUT_370M, tinyoutput)
+    del model
+if __name__ == '__main__':
+  unittest.main()