diff --git a/examples/mamba.py b/examples/mamba.py
index c016e39c18..1282f07a1a 100644
--- a/examples/mamba.py
+++ b/examples/mamba.py
@@ -1,62 +1,29 @@
-import os, sys, math, argparse
-from tqdm import tqdm
+import os, sys, math, argparse, time
 sys.path.append(os.getcwd())
-from typing import Any, Optional
+from typing import Any, Optional, Dict
 from dataclasses import dataclass, field
-import time
-from tinygrad import Tensor, dtypes, nn
-from tinygrad.engine.jit import TinyJit
-from tinygrad.helpers import fetch
-from tinygrad.nn.state import get_state_dict, load_state_dict, torch_load
+from tinygrad import Tensor, TinyJit, nn
+from tinygrad.helpers import fetch
+from tinygrad.nn.state import load_state_dict, torch_load
 from extra.models.llama import RMSNorm
+
+from tqdm import tqdm
 from transformers import AutoTokenizer
-# from einops import rearrange, repeat
-import numpy as np
 
 MODELS = {
-  "130m": {
-    "dim": 768,
-    "n_layers": 24,
-    "vocab_size": 50277,
-    "pad_vocab_size_multiple": 8,
-  },
-  "370m": {
-    "dim": 1024,
-    "n_layers": 48,
-    "vocab_size": 50277,
-    "pad_vocab_size_multiple": 8,
-  },
-  "790m": {
-    "dim": 1536,
-    "n_layers": 48,
-    "vocab_size": 50277,
-    "pad_vocab_size_multiple": 8,
-  },
-  "1.4b": {
-    "dim": 2048,
-    "n_layer": 48,
-    "vocab_size": 50277,
-    "pad_vocab_size_multiple": 8,
-  },
-  "2.8b": {
-    "dim": 2560,
-    "n_layer": 64,
-    "vocab_size": 50277,
-    "pad_vocab_size_multiple": 8,
-  },
+  "130m": {"dim": 768, "n_layers": 24, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
+  "370m": {"dim": 1024, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
+  "790m": {"dim": 1536, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
+  "1.4b": {"dim": 2048, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
+  "2.8b": {"dim": 2560, "n_layers": 64, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
 }
 
-
-def fetch_weights(model_name: str):
-  if model_name not in MODELS.keys():
-    raise Exception(f"Requested unknown mamba model: {model_name}")
-  downloaded = fetch(
-    f"https://huggingface.co/state-spaces/mamba-{model_name}/resolve/main/pytorch_model.bin?download=true"
-  )
-  weights = torch_load(downloaded)
-  return weights
-
+def fetch_weights(model_name: str) -> Dict[str, Tensor]:
+  if model_name not in MODELS:
+    raise ValueError(f"Requested unknown mamba model: {model_name}")
+  downloaded = fetch(f"https://huggingface.co/state-spaces/mamba-{model_name}/resolve/main/pytorch_model.bin?download=true")
+  return torch_load(downloaded)
 
 def selective_scan_ref(
   u,
@@ -83,7 +50,6 @@ def selective_scan_ref(
     out: r(B D L)
     last_state (optional): r(B D dstate) or c(B D dstate)
   """
-  dtype_in = u.dtype
   u = u.float()
   delta = delta.float()
   if delta_bias is not None:
@@ -102,10 +68,10 @@
   if len(B.shape) == 3:
     deltaB_u = Tensor.einsum("bdl,bnl,bdl->bdln", delta, B, u)
   else:
-    B = B.repeat((1,dim//B.shape[1],1,1))
+    B = B.repeat((1, dim // B.shape[1], 1, 1))
     deltaB_u = Tensor.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
   if is_variable_C and len(C.shape) == 4:
-    C = C.repeat((1,dim//C.shape[1],1,1))
+    C = C.repeat((1, dim // C.shape[1], 1, 1))
   last_state = None
   for i in range(u.shape[2]):
     x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
@@ -120,12 +86,11 @@
       last_state = x
     ys.append(y)
   y = Tensor.stack(ys, dim=2)  # (batch dim L)
-  out = y if D is None else y + u * D.reshape((D.numel(), 1))
+  out = y if D is None else y + u * D.reshape((-1, 1))
   if z is not None:
     out = out * z.silu()
   return out if not return_last_state else (out, last_state)
 
-
 class MambaMixer:
   def __init__(
     self,
@@ -147,20 +112,14 @@ class MambaMixer:
     self.d_state = d_state
     self.d_conv = d_conv
     self.expand = expand
-    self.d_inner = int(self.expand * self.dim)
+    self.d_inner = self.expand * self.dim
     self.dt_rank = math.ceil(self.dim / 16) if dt_rank == "auto" else dt_rank
     self.layer_idx = layer_idx
 
     self.in_proj = nn.Linear(self.dim, self.d_inner * 2, bias=bias)
 
-    self.conv1d = nn.Conv1d(
-      in_channels=self.d_inner,
-      out_channels=self.d_inner,
-      bias=conv_bias,
-      kernel_size=d_conv,
-      groups=self.d_inner,
-      padding=d_conv - 1,
-    )
+    self.conv1d = nn.Conv1d(in_channels=self.d_inner, out_channels=self.d_inner, bias=conv_bias,
+                            kernel_size=d_conv, groups=self.d_inner, padding=d_conv-1)
 
     self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
     self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
@@ -170,30 +129,17 @@
     if dt_init == "constant":
       self.dt_proj.weight = Tensor.full(self.dt_proj.weight.shape, dt_init_std)
     elif dt_init == "random":
-      self.dt_proj.weight = Tensor.uniform(
-        self.dt_proj.weight.shape, low=-dt_init_std, high=dt_init_std
-      )
+      self.dt_proj.weight = Tensor.uniform(self.dt_proj.weight.shape, low=-dt_init_std, high=dt_init_std)
     else:
       raise NotImplementedError
 
-    dt = (
-      (
-        Tensor.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min))
-        + math.log(dt_min)
-      )
-      .exp()
-      .maximum(dt_init_floor)
-    )
-    inv_dt = (
-      dt + (-((-dt).exp() - Tensor.ones(*dt.shape))).log()
-    )
+    dt = Tensor.uniform(self.d_inner, low=math.log(dt_min), high=math.log(dt_max)).exp().maximum(dt_init_floor)
+    inv_dt = dt + (1 - (-dt).exp()).log()
     self.dt_proj.bias.assign(inv_dt)
 
     # S4D real initialization
-    self.A_log = (
-      Tensor.arange(1, self.d_state + 1).repeat([self.d_inner, 1]).contiguous().log()
-    )
+    self.A_log = Tensor.arange(1, self.d_state+1).repeat([self.d_inner, 1]).log()
 
     # D "skip" parameter
     self.D = Tensor.ones(self.d_inner)  # Keep in fp32
@@ -215,7 +161,7 @@ class MambaMixer:
 
     xz = xz.reshape(xz.shape[0],xz.shape[1]//seqlen, seqlen).permute(1,0,2)
     if self.in_proj.bias is not None:
-      xz = xz + self.in_proj.bias.reshape((self.in_proj.bias.numel(), 1))
+      xz = xz + self.in_proj.bias.reshape((-1, 1))
 
     A = -self.A_log.exp()
     x, z = xz.chunk(2, dim=1)
@@ -228,41 +174,28 @@
     dt, B, C = Tensor.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
     dt = self.dt_proj.weight @ dt.T
     dt = dt.reshape(dt.shape[0], dt.shape[1]//seqlen, seqlen).permute(1,0,2)
-    B = B.reshape(B.shape[0]//seqlen, seqlen, B.shape[1]).permute(0,2,1).contiguous()
-    C = C.reshape(C.shape[0]//seqlen, seqlen, C.shape[1]).permute(0,2,1).contiguous()
-
-    y = selective_scan_ref(  # TODO: actually implement selective_scan_fn
-      x,
-      dt,
-      A,
-      B,
-      C,
-      self.D,
-      z=z,
-      delta_bias=self.dt_proj.bias,
-      delta_softplus=True,
-      return_last_state=ssm_state is not None,
-    )
+    B = B.reshape(B.shape[0]//seqlen, seqlen, B.shape[1]).permute(0,2,1)
+    C = C.reshape(C.shape[0]//seqlen, seqlen, C.shape[1]).permute(0,2,1)
+    # TODO: actually implement selective_scan_fn
+    y = selective_scan_ref(x, dt, A, B, C, self.D, z=z, delta_bias=self.dt_proj.bias, delta_softplus=True,
+                           return_last_state=ssm_state is not None)
     if ssm_state is not None:
       y, last_state = y
       ssm_state.assign(last_state)
     y = y.permute(0,2,1)
     out = self.out_proj(y)
-
     return out
 
-  def step(self, hidden_states, conv_state, ssm_state):
-    assert (
-      hidden_states.shape[1] == 1
-    ), f"Only support decoding with 1 token at a time for now, attempted {hidden_states.shape[1]}"
+  def step(self, hidden_states: Tensor, conv_state: Tensor, ssm_state: Tensor):
+    assert hidden_states.shape[1] == 1, f"Only support decoding with 1 token at a time for now, attempted {hidden_states.shape[1]}"
     xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
     x, z = xz.chunk(2, dim=-1)  # (B D)
 
     # Conv step
     conv_state.assign(conv_state[:, :, 1:].cat(x.unsqueeze(-1), dim=-1))
-    x = (conv_state * self.conv1d.weight.reshape(self.conv1d.weight.shape[0],self.conv1d.weight.shape[2])).sum(-1)
+    x = (conv_state * self.conv1d.weight.squeeze(1)).sum(-1)
     if self.conv1d.bias is not None:
       x = x + self.conv1d.bias
     x = x.swish()
@@ -275,13 +208,11 @@ class MambaMixer:
     dt = self.dt_proj.weight @ dt.T
     A = -self.A_log.exp()
 
-    # SSM step
     dt = (dt + self.dt_proj.bias.unsqueeze(-1)).softplus()
-    # TODO: Tensor.einsum?
     dA = Tensor.einsum("db,dn->bdn", dt, A).exp()
     dB = Tensor.einsum("db,bn->bdn", dt, B)
 
-    ssm_state.assign(ssm_state * dA + x.reshape(x.shape[0],x.shape[1], 1) * dB)
+    ssm_state.assign(ssm_state * dA + x.unsqueeze(-1) * dB)
     y = Tensor.einsum("bdn,bn->bd", ssm_state, C)
     y = y + self.D * x
     y = y * z.swish()  # (B D)
@@ -289,63 +220,34 @@ class MambaMixer:
     out = self.out_proj(y)
     return out.unsqueeze(1), conv_state, ssm_state
 
-  def _get_states_from_cache(
-    self, inference_params, batch_size, initialize_states=False
-  ):
+  def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
    assert self.layer_idx is not None
    if self.layer_idx not in inference_params.key_value_memory_dict:
-      batch_shape = (batch_size,)
-      conv_state = Tensor.zeros(
-        batch_size, self.dim * self.expand, self.d_conv
-      ).contiguous().realize()
-      ssm_state = Tensor.zeros(
-        batch_size, self.dim * self.expand, self.d_state
-      ).realize()
+      conv_state = Tensor.zeros(batch_size, self.dim * self.expand, self.d_conv).contiguous().realize()
+      ssm_state = Tensor.zeros(batch_size, self.dim * self.expand, self.d_state).realize()
      inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
    else:
      conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
    return conv_state, ssm_state
 
-
 class MambaBlock:
-  def __init__(
-    self,
-    dim: int,
-    norm_eps: float = 1e-5,
-    rms_norm: bool = True,
-    layer_idx: Optional[int] = None,
-  ):
+  def __init__(self, dim: int, norm_eps: float = 1e-5, rms_norm: bool = True, layer_idx: Optional[int] = None):
     self.mixer = MambaMixer(dim, layer_idx=layer_idx)
     if rms_norm:
       self.norm = RMSNorm(dim, norm_eps)
     else:
       raise NotImplementedError
 
-  def __call__(
-    self,
-    hidden_states: Tensor,
-    residual: Optional[Tensor] = None,
-    inference_params=None,
-  ):
+  def __call__(self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None):
     residual = (hidden_states + residual) if residual is not None else hidden_states
     hidden_states = self.norm(residual)
     hidden_states = self.mixer(hidden_states, inference_params=inference_params)
     return hidden_states, residual
 
-
 class MambaBackbone:
-  def __init__(
-    self,
-    dim: int,
-    n_layers: int,
-    vocab_size: int,
-    rms_norm: bool = True,
-    norm_eps: float = 1e-5,
-  ):
+  def __init__(self, dim: int, n_layers: int, vocab_size: int, rms_norm: bool = True, norm_eps: float = 1e-5):
     self.embedding = nn.Embedding(vocab_size, dim)
-    self.layers = [
-      MambaBlock(dim, rms_norm=rms_norm, layer_idx=i) for i in range(n_layers)
-    ]
+    self.layers = [MambaBlock(dim, rms_norm=rms_norm, layer_idx=i) for i in range(n_layers)]
 
     if rms_norm:
       self.norm_f = RMSNorm(dim, norm_eps)
@@ -353,20 +255,15 @@ class MambaBackbone:
     hidden_states = self.embedding(input_ids)
     residual = None
     for layer in self.layers:
-      hidden_states, residual = layer(
-        hidden_states, residual, inference_params=inference_params
-      )
+      hidden_states, residual = layer(hidden_states, residual, inference_params=inference_params)
     residual = (hidden_states + residual) if residual is not None else hidden_states
     hidden_states = self.norm_f(residual)
     return hidden_states
 
-
 class Mamba:
-  def __init__(
-    self, dim: int, n_layers: int, vocab_size: int, pad_vocab_size_multiple: int = 1
-  ):
+  def __init__(self, dim: int, n_layers: int, vocab_size: int, pad_vocab_size_multiple: int = 1):
     if vocab_size % pad_vocab_size_multiple != 0:
       vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
@@ -400,7 +297,6 @@ class Mamba:
 class InferenceParams:
   """Inference parameters that are passed to the main model in order to
   efficienly calculate and store the context during inference."""
-
   max_seqlen: int
   max_batch_size: int
   seqlen_offset: int = 0
@@ -415,48 +311,29 @@ class InferenceParams:
     if self.lengths_per_sample is not None:
       self.lengths_per_sample.zero_()
 
-def generate(model,
-             tokenizer,
-             prompt: str,
-             n_tokens_to_gen: int = 10,
-             sample: bool = False,
-             top_k: int = None):
+def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 10, sample: bool = False, top_k: int = None):
   tks = tokenizer(prompt)["input_ids"]
-  while(len(tks)<4):tks = [50279] + tks
+  while len(tks) < 4:
+    tks = [50279] + tks
+
+  # TODO: sampling
   temperature = 0.5
   start_pos = 0
   inference_params = InferenceParams(max_seqlen=1, max_batch_size=1, seqlen_offset=0)
-  for i in tqdm(range(n_tokens_to_gen), desc="Speed Gen"):
+  for _ in tqdm(range(n_tokens_to_gen), desc="Speed Gen"):
     logits = model(Tensor([tks[start_pos:]]), inference_params, start_pos, jit=False)
     inference_params.seqlen_offset = len(tks)
-    tok = (logits[:, -1, :]).softmax().argmax(axis=-1).item()
+    tok = logits[:, -1, :].argmax(axis=-1).item()
     start_pos = len(tks)
     tks.append(tok)
   output_completions = ''.join([tokenizer.decode(output) for output in tks])
   return output_completions
 
 if __name__ == "__main__":
-  TORCHOUTPUT = '''Why is gravity \nso important?\nBecause it's the only'''
-
-  parser = argparse.ArgumentParser(
-    description="Run Mamba in tinygrad",
-    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-  )
-  parser.add_argument(
-    "--prompt", type=str, default='Why is gravity ', help="Prompt for LLM completion"
-  )
-  parser.add_argument(
-    "--size",
-    type=str,
-    default="370m",
-    help=f"Size of model to use [{', '.join([k for k in MODELS.keys()])}]",
-  )
-  parser.add_argument(
-    "--n_tokens",
-    type=int,
-    default=10,
-    help="Number of tokens to generate",
-  )
+  parser = argparse.ArgumentParser(description="Run Mamba in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+  parser.add_argument("--prompt", type=str, default="Why is gravity ", help="Prompt for LLM completion")
+  parser.add_argument("--size", type=str, default="370m",
+                      help=f"Size of model to use [{', '.join([k for k in MODELS.keys()])}]")
+  parser.add_argument("--n_tokens", type=int, default=10, help="Number of tokens to generate")
   args = parser.parse_args()
 
   tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
@@ -467,4 +344,5 @@ if __name__ == "__main__":
   tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks)
   print(tinyoutput)
   print('TIME: ', time.time() - s)
+  TORCHOUTPUT = "Why is gravity \nso important?\nBecause it's the only"
   print('Outputs Match:', tinyoutput == TORCHOUTPUT)
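
For context only, not part of the patch: `selective_scan_ref` above evaluates the discretized state-space recurrence x[t] = exp(Δ[t]·A)·x[t-1] + Δ[t]·B·u[t], y[t] = C·x[t] + D·u[t] along the sequence. The sketch below is a minimal pure-Python version of that loop for a single channel with made-up sizes; the input-dependent B/C projections, the softplus on Δ, and the `z` silu gate used in the real code are deliberately left out.

```python
# Illustrative sketch only, not part of the patch. One channel, tiny made-up sizes.
import math

def scan_single_channel(u, delta, A, B, C, D):
  # u, delta: per-timestep input and step size (length L)
  # A, B, C: per-state parameters (length N); D: scalar skip connection
  N = len(A)
  x = [0.0] * N                        # hidden SSM state, starts at zero
  ys = []
  for t in range(len(u)):
    for n in range(N):
      dA = math.exp(delta[t] * A[n])   # discretized state transition exp(delta*A)
      dBu = delta[t] * B[n] * u[t]     # discretized input term delta*B*u
      x[n] = dA * x[n] + dBu
    ys.append(sum(C[n] * x[n] for n in range(N)) + D * u[t])  # y = C.x + D*u
  return ys

print(scan_single_channel(u=[1.0, 0.5, -0.2], delta=[0.1, 0.1, 0.1],
                          A=[-1.0, -2.0], B=[1.0, 0.5], C=[0.3, 0.7], D=1.0))
```

With the flags defined in the argparse block above, a run from the tinygrad repo root should look something like `python3 examples/mamba.py --size 370m --n_tokens 10 --prompt "Why is gravity "`.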