mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 06:58:11 -05:00)
cleanup mamba (#4004)
Make the code read nicer: clean up some movement methods and simplify the math. Known issues remain: the 790m, 1.4b, and 2.8b models do not really run, sampling is not implemented, the JIT is incorrect, and there is still some dead code, wrong code paths, and leftovers copied from the torch implementation.
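The "movement methods" cleanup below mostly swaps explicit shape arithmetic for equivalent shorthand: reshape((-1, 1)) for reshape((numel, 1)), unsqueeze(-1) for a reshape that appends an axis. A minimal sketch of those equivalences, with numpy standing in for tinygrad Tensors (the variable names are illustrative, not from the diff):

import numpy as np

D = np.ones(8)
x = np.random.randn(2, 4)

# reshape((numel, 1)) and reshape((-1, 1)) build the same column vector
assert (D.reshape((D.size, 1)) == D.reshape((-1, 1))).all()

# reshape(shape[0], shape[1], 1) and an appended axis are the same movement
assert (x.reshape(x.shape[0], x.shape[1], 1) == np.expand_dims(x, -1)).all()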
This commit is contained in:
@@ -1,62 +1,29 @@
-import os, sys, math, argparse
-from tqdm import tqdm
+import os, sys, math, argparse, time
 sys.path.append(os.getcwd())
-from typing import Any, Optional
+from typing import Any, Optional, Dict
 from dataclasses import dataclass, field
-import time
-from tinygrad import Tensor, dtypes, nn
-from tinygrad.engine.jit import TinyJit
-from tinygrad.helpers import fetch
-from tinygrad.nn.state import get_state_dict, load_state_dict, torch_load
+from tinygrad import Tensor, TinyJit, nn
+from tinygrad.helpers import fetch
+from tinygrad.nn.state import load_state_dict, torch_load
 from extra.models.llama import RMSNorm
+from tqdm import tqdm
 from transformers import AutoTokenizer
-# from einops import rearrange, repeat
-import numpy as np
 
 MODELS = {
-  "130m": {
-    "dim": 768,
-    "n_layers": 24,
-    "vocab_size": 50277,
-    "pad_vocab_size_multiple": 8,
-  },
-  "370m": {
-    "dim": 1024,
-    "n_layers": 48,
-    "vocab_size": 50277,
-    "pad_vocab_size_multiple": 8,
-  },
-  "790m": {
-    "dim": 1536,
-    "n_layers": 48,
-    "vocab_size": 50277,
-    "pad_vocab_size_multiple": 8,
-  },
-  "1.4b": {
-    "dim": 2048,
-    "n_layer": 48,
-    "vocab_size": 50277,
-    "pad_vocab_size_multiple": 8,
-  },
-  "2.8b": {
-    "dim": 2560,
-    "n_layer": 64,
-    "vocab_size": 50277,
-    "pad_vocab_size_multiple": 8,
-  },
+  "130m": {"dim": 768, "n_layers": 24, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
+  "370m": {"dim": 1024, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
+  "790m": {"dim": 1536, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
+  "1.4b": {"dim": 2048, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
+  "2.8b": {"dim": 2560, "n_layers": 64, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
 }
 
-def fetch_weights(model_name: str):
-  if model_name not in MODELS.keys():
-    raise Exception(f"Requested unknown mamba model: {model_name}")
-  downloaded = fetch(
-    f"https://huggingface.co/state-spaces/mamba-{model_name}/resolve/main/pytorch_model.bin?download=true"
-  )
-  weights = torch_load(downloaded)
-  return weights
+def fetch_weights(model_name: str) -> Dict[str, Tensor]:
+  if model_name not in MODELS:
+    raise ValueError(f"Requested unknown mamba model: {model_name}")
+  downloaded = fetch(f"https://huggingface.co/state-spaces/mamba-{model_name}/resolve/main/pytorch_model.bin?download=true")
+  return torch_load(downloaded)
 
 def selective_scan_ref(
   u,
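For orientation, the config table and loader above compose roughly as follows. This is a hedged sketch assuming the file's own MODELS, Mamba, and fetch_weights are in scope; running it downloads the checkpoint on first call:

cfg = MODELS["370m"]   # {"dim": 1024, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8}
model = Mamba(cfg["dim"], cfg["n_layers"], cfg["vocab_size"], cfg["pad_vocab_size_multiple"])
weights = fetch_weights("370m")  # Dict[str, Tensor] via torch_load
# fetch_weights("13b") would raise ValueError: Requested unknown mamba model: 13b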
@@ -83,7 +50,6 @@ def selective_scan_ref(
   out: r(B D L)
   last_state (optional): r(B D dstate) or c(B D dstate)
   """
-  dtype_in = u.dtype
   u = u.float()
   delta = delta.float()
   if delta_bias is not None:
@@ -102,10 +68,10 @@ def selective_scan_ref(
   if len(B.shape) == 3:
     deltaB_u = Tensor.einsum("bdl,bnl,bdl->bdln", delta, B, u)
   else:
-    B = B.repeat((1,dim//B.shape[1],1,1))
+    B = B.repeat((1, dim // B.shape[1], 1, 1))
     deltaB_u = Tensor.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
   if is_variable_C and len(C.shape) == 4:
-    C = C.repeat((1,dim//C.shape[1],1,1))
+    C = C.repeat((1, dim // C.shape[1], 1, 1))
   last_state = None
   for i in range(u.shape[2]):
     x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
@@ -120,12 +86,11 @@ def selective_scan_ref(
       last_state = x
     ys.append(y)
   y = Tensor.stack(ys, dim=2)  # (batch dim L)
-  out = y if D is None else y + u * D.reshape((D.numel(), 1))
+  out = y if D is None else y + u * D.reshape((-1, 1))
   if z is not None:
     out = out * z.silu()
   return out if not return_last_state else (out, last_state)
 
 
 class MambaMixer:
   def __init__(
     self,
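The loop in selective_scan_ref is a per-timestep linear recurrence: x_t = exp(Δ_t A) x_{t-1} + Δ_t B u_t, then y_t = C x_t, with D u_t added as the skip connection. A scalar sketch of the same update, with plain floats standing in for the (B, D, N) tensors (scan_1d is a hypothetical helper, not part of the diff):

import math

def scan_1d(u, delta, A, B, C, D):
  # x_t = exp(delta_t * A) * x_{t-1} + (delta_t * B) * u_t ; y_t = C * x_t + D * u_t
  x, ys = 0.0, []
  for u_t, d_t in zip(u, delta):
    x = math.exp(d_t * A) * x + d_t * B * u_t
    ys.append(C * x + D * u_t)
  return ys

print(scan_1d([1.0, 0.5, -0.2], [0.1, 0.1, 0.1], A=-1.0, B=1.0, C=1.0, D=1.0))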
@@ -147,20 +112,14 @@ class MambaMixer:
     self.d_state = d_state
     self.d_conv = d_conv
     self.expand = expand
-    self.d_inner = int(self.expand * self.dim)
+    self.d_inner = self.expand * self.dim
     self.dt_rank = math.ceil(self.dim / 16) if dt_rank == "auto" else dt_rank
     self.layer_idx = layer_idx
 
     self.in_proj = nn.Linear(self.dim, self.d_inner * 2, bias=bias)
 
-    self.conv1d = nn.Conv1d(
-      in_channels=self.d_inner,
-      out_channels=self.d_inner,
-      bias=conv_bias,
-      kernel_size=d_conv,
-      groups=self.d_inner,
-      padding=d_conv - 1,
-    )
+    self.conv1d = nn.Conv1d(in_channels=self.d_inner, out_channels=self.d_inner, bias=conv_bias,
+                            kernel_size=d_conv, groups=self.d_inner, padding=d_conv-1)
 
     self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
     self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
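With groups equal to channels and padding d_conv - 1, the Conv1d above is a depthwise causal convolution: each channel gets its own length-d_conv filter, and once the trailing d_conv - 1 outputs are trimmed, output t depends only on inputs up to t. A one-channel plain-Python sketch (causal_depthwise_1ch is a hypothetical helper for illustration):

def causal_depthwise_1ch(x, w):
  # left-pad with len(w)-1 zeros so output[t] depends only on x[:t+1]
  k = len(w)
  padded = [0.0] * (k - 1) + list(x)
  return [sum(w[j] * padded[t + j] for j in range(k)) for t in range(len(x))]

# a causal moving average: out[t] = 0.5*x[t-1] + 0.5*x[t]
print(causal_depthwise_1ch([1, 2, 3, 4], [0.0, 0.0, 0.5, 0.5]))  # [0.5, 1.5, 2.5, 3.5]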
@@ -170,30 +129,17 @@ class MambaMixer:
     if dt_init == "constant":
       self.dt_proj.weight = Tensor.full(self.dt_proj.weight.shape, dt_init_std)
     elif dt_init == "random":
-      self.dt_proj.weight = Tensor.uniform(
-        self.dt_proj.weight.shape, low=-dt_init_std, high=dt_init_std
-      )
+      self.dt_proj.weight = Tensor.uniform(self.dt_proj.weight.shape, low=-dt_init_std, high=dt_init_std)
     else:
       raise NotImplementedError
 
-    dt = (
-      (
-        Tensor.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min))
-        + math.log(dt_min)
-      )
-      .exp()
-      .maximum(dt_init_floor)
-    )
-    inv_dt = (
-      dt + (-((-dt).exp() - Tensor.ones(*dt.shape))).log()
-    )
+    dt = Tensor.uniform(self.d_inner, low=math.log(dt_min), high=math.log(dt_max)).exp().maximum(dt_init_floor)
+    inv_dt = dt + (1 - (-dt).exp()).log()
     self.dt_proj.bias.assign(inv_dt)
 
     # S4D real initialization
-    self.A_log = (
-      Tensor.arange(1, self.d_state + 1).repeat([self.d_inner, 1]).contiguous().log()
-    )
+    self.A_log = Tensor.arange(1, self.d_state+1).repeat([self.d_inner, 1]).log()
 
     # D "skip" parameter
     self.D = Tensor.ones(self.d_inner)  # Keep in fp32
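The joined-up dt initialization samples dt log-uniformly in [dt_min, dt_max], floors it, and stores softplus^-1(dt) in the bias, because step() later applies softplus to dt_proj.bias. The identity the new one-liner relies on is dt + log(1 - exp(-dt)) = log(exp(dt) - 1) = softplus^-1(dt). A scalar check (0.001, 0.1, and 1e-4 are assumed here to match the reference Mamba defaults):

import math, random

dt = math.exp(random.uniform(math.log(0.001), math.log(0.1)))  # log-uniform in [dt_min, dt_max]
dt = max(dt, 1e-4)                                             # dt_init_floor
inv_dt = dt + math.log(1 - math.exp(-dt))                      # softplus^-1(dt)
softplus = lambda v: math.log(1 + math.exp(v))
assert abs(softplus(inv_dt) - dt) < 1e-9                       # round-trips back to dt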
@@ -215,7 +161,7 @@ class MambaMixer:
     xz = xz.reshape(xz.shape[0],xz.shape[1]//seqlen, seqlen).permute(1,0,2)
 
     if self.in_proj.bias is not None:
-      xz = xz + self.in_proj.bias.reshape((self.in_proj.bias.numel(), 1))
+      xz = xz + self.in_proj.bias.reshape((-1, 1))
 
     A = -self.A_log.exp()
     x, z = xz.chunk(2, dim=1)
@@ -228,41 +174,28 @@ class MambaMixer:
     dt, B, C = Tensor.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
     dt = self.dt_proj.weight @ dt.T
     dt = dt.reshape(dt.shape[0], dt.shape[1]//seqlen, seqlen).permute(1,0,2)
-    B = B.reshape(B.shape[0]//seqlen, seqlen, B.shape[1]).permute(0,2,1).contiguous()
-    C = C.reshape(C.shape[0]//seqlen, seqlen, C.shape[1]).permute(0,2,1).contiguous()
+    B = B.reshape(B.shape[0]//seqlen, seqlen, B.shape[1]).permute(0,2,1)
+    C = C.reshape(C.shape[0]//seqlen, seqlen, C.shape[1]).permute(0,2,1)
 
-    y = selective_scan_ref(  # TODO: actually implement selective_scan_fn
-      x,
-      dt,
-      A,
-      B,
-      C,
-      self.D,
-      z=z,
-      delta_bias=self.dt_proj.bias,
-      delta_softplus=True,
-      return_last_state=ssm_state is not None,
-    )
+    # TODO: actually implement selective_scan_fn
+    y = selective_scan_ref(x, dt, A, B, C, self.D, z=z, delta_bias=self.dt_proj.bias, delta_softplus=True,
+                           return_last_state=ssm_state is not None)
     if ssm_state is not None:
       y, last_state = y
       ssm_state.assign(last_state)
 
     y = y.permute(0,2,1)
     out = self.out_proj(y)
 
     return out
 
-  def step(self, hidden_states, conv_state, ssm_state):
-    assert (
-      hidden_states.shape[1] == 1
-    ), f"Only support decoding with 1 token at a time for now, attempted {hidden_states.shape[1]}"
+  def step(self, hidden_states: Tensor, conv_state: Tensor, ssm_state: Tensor):
+    assert hidden_states.shape[1] == 1, f"Only support decoding with 1 token at a time for now, attempted {hidden_states.shape[1]}"
     xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
     x, z = xz.chunk(2, dim=-1)  # (B D)
 
     # Conv step
     conv_state.assign(conv_state[:, :, 1:].cat(x.unsqueeze(-1), dim=-1))
-    x = (conv_state * self.conv1d.weight.reshape(self.conv1d.weight.shape[0],self.conv1d.weight.shape[2])).sum(-1)
+    x = (conv_state * self.conv1d.weight.squeeze(1)).sum(-1)
     if self.conv1d.bias is not None:
       x = x + self.conv1d.bias
     x = x.swish()
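In the single-token step, conv_state is a rolling (B, D, d_conv) buffer: drop the oldest column, append the new x, and the depthwise conv collapses to an elementwise multiply-and-sum against the kernel. A one-channel sketch (conv_step is a hypothetical helper, not from the diff):

def conv_step(conv_state, x_new, w, bias=0.0):
  # shift the window left by one and append the newest input
  conv_state = conv_state[1:] + [x_new]
  # a depthwise conv at a single position is just a dot product with the kernel
  return conv_state, sum(s * wj for s, wj in zip(conv_state, w)) + bias

state = [0.0, 0.0, 0.0, 1.0]
state, y = conv_step(state, 2.0, [0.0, 0.0, 0.5, 0.5])
print(state, y)  # [0.0, 0.0, 1.0, 2.0] 1.5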
@@ -275,13 +208,11 @@ class MambaMixer:
     dt = self.dt_proj.weight @ dt.T
     A = -self.A_log.exp()
 
-
     # SSM step
     dt = (dt + self.dt_proj.bias.unsqueeze(-1)).softplus()
-    # TODO: Tensor.einsum?
     dA = Tensor.einsum("db,dn->bdn", dt, A).exp()
     dB = Tensor.einsum("db,bn->bdn", dt, B)
-    ssm_state.assign(ssm_state * dA + x.reshape(x.shape[0],x.shape[1], 1) * dB)
+    ssm_state.assign(ssm_state * dA + x.unsqueeze(-1) * dB)
     y = Tensor.einsum("bdn,bn->bd", ssm_state, C)
     y = y + self.D * x
     y = y * z.swish()  # (B D)
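The step einsums discretize A and B per batch element and update the recurrent state in place; x.unsqueeze(-1) broadcasts the input over the state dimension. A numpy sketch with the same subscripts and shapes (b=batch, d=d_inner, n=d_state; the toy sizes are arbitrary):

import numpy as np

b, d, n = 2, 4, 3
dt = np.random.rand(d, b)      # (d, b): layout after dt_proj.weight @ dt.T
A = -np.random.rand(d, n)
B, C = np.random.rand(b, n), np.random.rand(b, n)
x = np.random.rand(b, d)
state = np.zeros((b, d, n))

dA = np.exp(np.einsum("db,dn->bdn", dt, A))  # discretized state transition, (b, d, n)
dB = np.einsum("db,bn->bdn", dt, B)          # discretized input matrix, (b, d, n)
state = state * dA + x[..., None] * dB       # x[..., None] plays the role of x.unsqueeze(-1)
y = np.einsum("bdn,bn->bd", state, C)        # project the state out with C, (b, d)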
@@ -289,63 +220,34 @@ class MambaMixer:
     out = self.out_proj(y)
     return out.unsqueeze(1), conv_state, ssm_state
 
-  def _get_states_from_cache(
-    self, inference_params, batch_size, initialize_states=False
-  ):
+  def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
     assert self.layer_idx is not None
     if self.layer_idx not in inference_params.key_value_memory_dict:
-      batch_shape = (batch_size,)
-      conv_state = Tensor.zeros(
-        batch_size, self.dim * self.expand, self.d_conv
-      ).contiguous().realize()
-      ssm_state = Tensor.zeros(
-        batch_size, self.dim * self.expand, self.d_state
-      ).realize()
+      conv_state = Tensor.zeros(batch_size, self.dim * self.expand, self.d_conv).contiguous().realize()
+      ssm_state = Tensor.zeros(batch_size, self.dim * self.expand, self.d_state).realize()
       inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
     else:
       conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
     return conv_state, ssm_state
 
 
 class MambaBlock:
-  def __init__(
-    self,
-    dim: int,
-    norm_eps: float = 1e-5,
-    rms_norm: bool = True,
-    layer_idx: Optional[int] = None,
-  ):
+  def __init__(self, dim: int, norm_eps: float = 1e-5, rms_norm: bool = True, layer_idx: Optional[int] = None):
     self.mixer = MambaMixer(dim, layer_idx=layer_idx)
     if rms_norm:
       self.norm = RMSNorm(dim, norm_eps)
     else:
       raise NotImplementedError
 
-  def __call__(
-    self,
-    hidden_states: Tensor,
-    residual: Optional[Tensor] = None,
-    inference_params=None,
-  ):
+  def __call__(self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None):
     residual = (hidden_states + residual) if residual is not None else hidden_states
     hidden_states = self.norm(residual)
     hidden_states = self.mixer(hidden_states, inference_params=inference_params)
     return hidden_states, residual
 
 
 class MambaBackbone:
-  def __init__(
-    self,
-    dim: int,
-    n_layers: int,
-    vocab_size: int,
-    rms_norm: bool = True,
-    norm_eps: float = 1e-5,
-  ):
+  def __init__(self, dim: int, n_layers: int, vocab_size: int, rms_norm: bool = True, norm_eps: float = 1e-5):
     self.embedding = nn.Embedding(vocab_size, dim)
-    self.layers = [
-      MambaBlock(dim, rms_norm=rms_norm, layer_idx=i) for i in range(n_layers)
-    ]
+    self.layers = [MambaBlock(dim, rms_norm=rms_norm, layer_idx=i) for i in range(n_layers)]
     if rms_norm:
       self.norm_f = RMSNorm(dim, norm_eps)
 
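_get_states_from_cache lazily allocates one (conv_state, ssm_state) pair per layer, keyed by layer_idx in inference_params.key_value_memory_dict, and hands the same tensors back on every later step so assign() mutates them in place. The caching pattern, reduced to a sketch (get_states and the list shapes are illustrative only):

cache = {}  # stands in for inference_params.key_value_memory_dict

def get_states(layer_idx, d_conv, d_state):
  if layer_idx not in cache:
    cache[layer_idx] = ([0.0] * d_conv, [0.0] * d_state)  # freshly zeroed per-layer state
  return cache[layer_idx]

conv_state, ssm_state = get_states(0, 4, 16)
assert get_states(0, 4, 16) is cache[0]  # the same objects come back on reuse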
@@ -353,20 +255,15 @@ class MambaBackbone:
     hidden_states = self.embedding(input_ids)
     residual = None
     for layer in self.layers:
-      hidden_states, residual = layer(
-        hidden_states, residual, inference_params=inference_params
-      )
+      hidden_states, residual = layer(hidden_states, residual, inference_params=inference_params)
 
     residual = (hidden_states + residual) if residual is not None else hidden_states
     hidden_states = self.norm_f(residual)
 
     return hidden_states
 
 
 class Mamba:
-  def __init__(
-    self, dim: int, n_layers: int, vocab_size: int, pad_vocab_size_multiple: int = 1
-  ):
+  def __init__(self, dim: int, n_layers: int, vocab_size: int, pad_vocab_size_multiple: int = 1):
     if vocab_size % pad_vocab_size_multiple != 0:
       vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
 
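The constructor rounds vocab_size up to the next multiple of pad_vocab_size_multiple, so the 50277-token vocabulary becomes 50280 with a multiple of 8. The arithmetic, as a standalone check (pad_vocab is a hypothetical helper):

def pad_vocab(vocab_size, multiple):
  if vocab_size % multiple != 0:
    vocab_size += multiple - (vocab_size % multiple)
  return vocab_size

assert pad_vocab(50277, 8) == 50280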
@@ -400,7 +297,6 @@ class Mamba:
 class InferenceParams:
   """Inference parameters that are passed to the main model in order
   to efficiently calculate and store the context during inference."""
-
   max_seqlen: int
   max_batch_size: int
   seqlen_offset: int = 0
@@ -415,48 +311,29 @@ class InferenceParams:
     if self.lengths_per_sample is not None:
       self.lengths_per_sample.zero_()
 
-def generate(model,
-             tokenizer,
-             prompt: str,
-             n_tokens_to_gen: int = 10,
-             sample: bool = False,
-             top_k: int = None):
+def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 10, sample: bool = False, top_k: int = None):
   tks = tokenizer(prompt)["input_ids"]
-  while(len(tks)<4):tks = [50279] + tks
+  while len(tks) < 4:
+    tks = [50279] + tks
+  # TODO: sampling
   temperature = 0.5
   start_pos = 0
   inference_params = InferenceParams(max_seqlen=1, max_batch_size=1, seqlen_offset=0)
-  for i in tqdm(range(n_tokens_to_gen), desc="Speed Gen"):
+  for _ in tqdm(range(n_tokens_to_gen), desc="Speed Gen"):
     logits = model(Tensor([tks[start_pos:]]), inference_params, start_pos, jit=False)
     inference_params.seqlen_offset = len(tks)
-    tok = (logits[:, -1, :]).softmax().argmax(axis=-1).item()
+    tok = logits[:, -1, :].argmax(axis=-1).item()
     start_pos = len(tks)
     tks.append(tok)
   output_completions = ''.join([tokenizer.decode(output) for output in tks])
   return output_completions
 
 if __name__ == "__main__":
-  TORCHOUTPUT = '''Why is gravity \nso important?\nBecause it's the only'''
-
-  parser = argparse.ArgumentParser(
-    description="Run Mamba in tinygrad",
-    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-  )
-  parser.add_argument(
-    "--prompt", type=str, default='Why is gravity ', help="Prompt for LLM completion"
-  )
-  parser.add_argument(
-    "--size",
-    type=str,
-    default="370m",
-    help=f"Size of model to use [{', '.join([k for k in MODELS.keys()])}]",
-  )
-  parser.add_argument(
-    "--n_tokens",
-    type=int,
-    default=10,
-    help="Number of tokens to generate",
-  )
+  parser = argparse.ArgumentParser(description="Run Mamba in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+  parser.add_argument("--prompt", type=str, default="Why is gravity ", help="Prompt for LLM completion")
+  parser.add_argument("--size", type=str, default="370m",
+                      help=f"Size of model to use [{', '.join([k for k in MODELS.keys()])}]")
+  parser.add_argument("--n_tokens", type=int, default=10, help="Number of tokens to generate")
   args = parser.parse_args()
 
   tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
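generate() decodes greedily; temperature is set but never used, and the commit message flags sampling as unimplemented. Purely as an illustration of what the TODO could become (not part of this diff; sample_top_k is a hypothetical helper over a plain list of logits), temperature plus top-k sampling might look like:

import math, random

def sample_top_k(logits, temperature=0.5, k=10):
  # keep the k largest logits, softmax at the given temperature, then sample one index
  top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
  exps = [math.exp(logits[i] / temperature) for i in top]
  r, acc = random.random() * sum(exps), 0.0
  for idx, e in zip(top, exps):
    acc += e
    if r <= acc: return idx
  return top[-1]

print(sample_top_k([2.0, 1.0, 0.5, -1.0], k=2))  # returns 0 or 1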
@@ -467,4 +344,5 @@ if __name__ == "__main__":
   tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks)
   print(tinyoutput)
   print('TIME: ', time.time() - s)
+  TORCHOUTPUT = "Why is gravity \nso important?\nBecause it's the only"
   print('Outputs Match:', tinyoutput == TORCHOUTPUT)