mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add support for sharded Falcon model
This commit is contained in:
143
apps/language_models/src/model_wrappers/falcon_sharded_model.py
Normal file
143
apps/language_models/src/model_wrappers/falcon_sharded_model.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import torch
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
class WordEmbeddingsLayer(torch.nn.Module):
|
||||
def __init__(self, word_embedding_layer):
|
||||
super().__init__()
|
||||
self.model = word_embedding_layer
|
||||
|
||||
def forward(self, input_ids):
|
||||
output = self.model.forward(input=input_ids)
|
||||
return output
|
||||
|
||||
|
||||
class CompiledWordEmbeddingsLayer(torch.nn.Module):
|
||||
def __init__(self, compiled_word_embedding_layer):
|
||||
super().__init__()
|
||||
self.model = compiled_word_embedding_layer
|
||||
|
||||
def forward(self, input_ids):
|
||||
input_ids = input_ids.detach().numpy()
|
||||
new_input_ids = self.model("forward", input_ids)
|
||||
new_input_ids = new_input_ids.reshape(
|
||||
[1, new_input_ids.shape[0], new_input_ids.shape[1]]
|
||||
)
|
||||
return torch.tensor(new_input_ids)
|
||||
|
||||
|
||||
class LNFEmbeddingLayer(torch.nn.Module):
|
||||
def __init__(self, ln_f):
|
||||
super().__init__()
|
||||
self.model = ln_f
|
||||
|
||||
def forward(self, hidden_states):
|
||||
output = self.model.forward(input=hidden_states)
|
||||
return output
|
||||
|
||||
|
||||
class CompiledLNFEmbeddingLayer(torch.nn.Module):
|
||||
def __init__(self, ln_f):
|
||||
super().__init__()
|
||||
self.model = ln_f
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = hidden_states.detach().numpy()
|
||||
new_hidden_states = self.model("forward", (hidden_states,))
|
||||
|
||||
return torch.tensor(new_hidden_states)
|
||||
|
||||
|
||||
class LMHeadEmbeddingLayer(torch.nn.Module):
|
||||
def __init__(self, embedding_layer):
|
||||
super().__init__()
|
||||
self.model = embedding_layer
|
||||
|
||||
def forward(self, hidden_states):
|
||||
output = self.model.forward(input=hidden_states)
|
||||
return output
|
||||
|
||||
|
||||
class CompiledLMHeadEmbeddingLayer(torch.nn.Module):
|
||||
def __init__(self, lm_head):
|
||||
super().__init__()
|
||||
self.model = lm_head
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = hidden_states.detach().numpy()
|
||||
new_hidden_states = self.model("forward", (hidden_states,))
|
||||
return torch.tensor(new_hidden_states)
|
||||
|
||||
|
||||
class DecoderLayer(torch.nn.Module):
|
||||
def __init__(self, decoder_layer_model):
|
||||
super().__init__()
|
||||
self.model = decoder_layer_model
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
output = self.model.forward(
|
||||
hidden_states=hidden_states,
|
||||
alibi=None,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=True,
|
||||
)
|
||||
return (output[0], output[1][0], output[1][1])
|
||||
|
||||
|
||||
class CompiledDecoderLayer(torch.nn.Module):
|
||||
def __init__(self, shark_decoder_layer_module):
|
||||
super().__init__()
|
||||
self.model = shark_decoder_layer_module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
alibi: torch.Tensor = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
hidden_states = hidden_states.to(torch.float32).detach().numpy()
|
||||
attention_mask = attention_mask.detach().numpy()
|
||||
|
||||
if alibi is not None or layer_past is not None:
|
||||
raise ValueError("Past Key Values and alibi should be None")
|
||||
else:
|
||||
new_hidden_states, pkv1, pkv2 = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
),
|
||||
)
|
||||
return tuple(
|
||||
[
|
||||
torch.tensor(new_hidden_states),
|
||||
torch.tensor(pkv1),
|
||||
torch.tensor(pkv2),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class ShardedFalconModel:
|
||||
def __init__(self, model, layers, word_embeddings, ln_f, lm_head):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.model.transformer.h = torch.nn.modules.container.ModuleList(
|
||||
layers
|
||||
)
|
||||
self.model.transformer.word_embeddings = word_embeddings
|
||||
self.model.transformer.ln_f = ln_f
|
||||
self.model.lm_head = lm_head
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
):
|
||||
return self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
).logits[:, -1, :]
|
||||
@@ -1,4 +1,15 @@
|
||||
from apps.language_models.src.model_wrappers.falcon_model import FalconModel
|
||||
from apps.language_models.src.model_wrappers.falcon_sharded_model import (
|
||||
WordEmbeddingsLayer,
|
||||
CompiledWordEmbeddingsLayer,
|
||||
LNFEmbeddingLayer,
|
||||
CompiledLNFEmbeddingLayer,
|
||||
LMHeadEmbeddingLayer,
|
||||
CompiledLMHeadEmbeddingLayer,
|
||||
DecoderLayer,
|
||||
CompiledDecoderLayer,
|
||||
ShardedFalconModel,
|
||||
)
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from apps.language_models.utils import (
|
||||
get_vmfb_from_path,
|
||||
@@ -67,9 +78,16 @@ parser.add_argument(
|
||||
default=None,
|
||||
help="Specify your own huggingface authentication token for falcon-180B model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--sharded",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Run model as sharded",
|
||||
)
|
||||
|
||||
|
||||
class Falcon(SharkLLMBase):
|
||||
class ShardedFalcon(SharkLLMBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
@@ -85,6 +103,474 @@ class Falcon(SharkLLMBase):
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
print("hf_model_path: ", self.hf_model_path)
|
||||
|
||||
if "40b" in self.model_name:
|
||||
raise NotImplementedError(
|
||||
"Sharded Falcon not supported for 40b variant"
|
||||
)
|
||||
|
||||
if (
|
||||
"180b" in self.model_name
|
||||
and precision != "int4"
|
||||
and hf_auth_token == None
|
||||
):
|
||||
raise ValueError(
|
||||
""" HF auth token required for falcon-180b. Pass it using
|
||||
--hf_auth_token flag. You can ask for the access to the model
|
||||
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
|
||||
)
|
||||
self.hf_auth_token = hf_auth_token
|
||||
self.max_padding_length = 100
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.falcon_vmfb_path = falcon_vmfb_path
|
||||
self.falcon_mlir_path = falcon_mlir_path
|
||||
self.debug = debug
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.src_model = self.get_src_model()
|
||||
self.shark_model = self.compile()
|
||||
|
||||
def get_tokenizer(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path,
|
||||
trust_remote_code=True,
|
||||
token=self.hf_auth_token,
|
||||
)
|
||||
tokenizer.padding_side = "left"
|
||||
tokenizer.pad_token_id = 11
|
||||
return tokenizer
|
||||
|
||||
def get_src_model(self):
|
||||
print("Loading src model: ", self.model_name)
|
||||
kwargs = {
|
||||
"torch_dtype": torch.float,
|
||||
"trust_remote_code": True,
|
||||
"token": self.hf_auth_token,
|
||||
}
|
||||
if self.precision == "int4":
|
||||
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
|
||||
kwargs["quantization_config"] = quantization_config
|
||||
kwargs["load_gptq_on_cpu"] = True
|
||||
kwargs["device_map"] = "cpu" if self.device == "cpu" else "cuda:0"
|
||||
falcon_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
)
|
||||
if self.precision == "int4":
|
||||
falcon_model = falcon_model.to(torch.float32)
|
||||
return falcon_model
|
||||
|
||||
def compile_layer(self, layer, falconCompileInput, layer_id):
|
||||
self.falcon_mlir_path = Path(
|
||||
f"falcon_{args.falcon_variant_to_use}_layer_{layer_id}_{self.precision}.mlir"
|
||||
)
|
||||
self.falcon_vmfb_path = Path(
|
||||
f"falcon_{args.falcon_variant_to_use}_layer_{layer_id}_{self.precision}_{self.device}.vmfb"
|
||||
)
|
||||
|
||||
if args.use_precompiled_model:
|
||||
if not self.falcon_vmfb_path.exists():
|
||||
# Downloading VMFB from shark_tank
|
||||
download_public_file(
|
||||
"gs://shark_tank/falcon/sharded/vmfb/"
|
||||
+ str(self.falcon_vmfb_path),
|
||||
self.falcon_vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
vmfb = get_vmfb_from_path(
|
||||
self.falcon_vmfb_path, self.device, "linalg"
|
||||
)
|
||||
if vmfb is not None:
|
||||
return vmfb
|
||||
|
||||
print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
|
||||
if self.falcon_mlir_path.exists():
|
||||
print(f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}")
|
||||
with open(self.falcon_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
mlir_generated = False
|
||||
print(
|
||||
f"[DEBUG] mlir not found at {self.falcon_mlir_path.absolute()}"
|
||||
)
|
||||
if args.load_mlir_from_shark_tank:
|
||||
# Downloading MLIR from shark_tank
|
||||
print(f"[DEBUG] Trying to download mlir from shark_tank")
|
||||
download_public_file(
|
||||
"gs://shark_tank/falcon/sharded/mlir/"
|
||||
+ str(self.falcon_mlir_path),
|
||||
self.falcon_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if self.falcon_mlir_path.exists():
|
||||
print(
|
||||
f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}"
|
||||
)
|
||||
with open(self.falcon_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
mlir_generated = True
|
||||
|
||||
if not mlir_generated:
|
||||
print(f"[DEBUG] generating MLIR locally")
|
||||
if layer_id == "word_embeddings":
|
||||
f16_input_mask = [False]
|
||||
elif layer_id in ["ln_f", "lm_head"]:
|
||||
f16_input_mask = [True]
|
||||
elif type(layer_id) == int:
|
||||
f16_input_mask = [True, False]
|
||||
else:
|
||||
raise ValueError("Unsupported layer: ", layer_id)
|
||||
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
ts_graph = import_with_fx(
|
||||
layer,
|
||||
falconCompileInput,
|
||||
is_f16=True,
|
||||
f16_input_mask=f16_input_mask,
|
||||
mlir_type="torchscript",
|
||||
is_gptq=True,
|
||||
)
|
||||
del layer
|
||||
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
falconCompileInput,
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
del ts_graph
|
||||
|
||||
print(f"[DEBUG] converting to bytecode")
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
del module
|
||||
|
||||
f_ = open(self.falcon_mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
print("Saved falcon mlir at ", str(self.falcon_mlir_path))
|
||||
f_.close()
|
||||
del bytecode
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=self.falcon_mlir_path,
|
||||
device=self.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
self.falcon_vmfb_path.parent.absolute(),
|
||||
self.falcon_vmfb_path.stem,
|
||||
extra_args=[
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
]
|
||||
+ [
|
||||
"--iree-llvmcpu-use-fast-min-max-ops",
|
||||
]
|
||||
if self.precision == "int4"
|
||||
else [],
|
||||
debug=self.debug,
|
||||
)
|
||||
print("Saved falcon vmfb at ", str(path))
|
||||
shark_module.load_module(path)
|
||||
|
||||
return shark_module
|
||||
|
||||
def compile(self):
|
||||
sample_input_ids = torch.zeros([100], dtype=torch.int64)
|
||||
sample_attention_mask = torch.zeros(
|
||||
[1, 1, 100, 100], dtype=torch.float32
|
||||
)
|
||||
if "7b" in self.model_name:
|
||||
num_in_features = 4544
|
||||
else:
|
||||
num_in_features = 14848
|
||||
sample_attention_mask.to(dtype=torch.bool)
|
||||
|
||||
sample_hidden_states = torch.zeros(
|
||||
[1, 100, num_in_features], dtype=torch.float32
|
||||
)
|
||||
|
||||
lm_head = LMHeadEmbeddingLayer(self.src_model.lm_head)
|
||||
print("Compiling Layer lm_head")
|
||||
shark_lm_head = self.compile_layer(
|
||||
lm_head, [sample_hidden_states], "lm_head"
|
||||
)
|
||||
shark_lm_head = CompiledLMHeadEmbeddingLayer(shark_lm_head)
|
||||
|
||||
word_embedding = WordEmbeddingsLayer(
|
||||
self.src_model.transformer.word_embeddings
|
||||
)
|
||||
print("Compiling Layer word_embeddings")
|
||||
shark_word_embedding = self.compile_layer(
|
||||
word_embedding, [sample_input_ids], "word_embeddings"
|
||||
)
|
||||
shark_word_embedding = CompiledWordEmbeddingsLayer(
|
||||
shark_word_embedding
|
||||
)
|
||||
|
||||
ln_f = LNFEmbeddingLayer(self.src_model.transformer.ln_f)
|
||||
print("Compiling Layer ln_f")
|
||||
shark_ln_f = self.compile_layer(ln_f, [sample_hidden_states], "ln_f")
|
||||
shark_ln_f = CompiledLNFEmbeddingLayer(shark_ln_f)
|
||||
|
||||
shark_layers = []
|
||||
for i in range(len(self.src_model.transformer.h)):
|
||||
print("Compiling Layer {}".format(i))
|
||||
layer_i = self.src_model.transformer.h[i]
|
||||
pytorch_layer_i = DecoderLayer(layer_i)
|
||||
shark_module = self.compile_layer(
|
||||
pytorch_layer_i,
|
||||
[sample_hidden_states, sample_attention_mask],
|
||||
i,
|
||||
)
|
||||
shark_layer_i = CompiledDecoderLayer(shark_module)
|
||||
shark_layers.append(shark_layer_i)
|
||||
|
||||
sharded_model = ShardedFalconModel(
|
||||
self.src_model,
|
||||
shark_layers,
|
||||
shark_word_embedding,
|
||||
shark_ln_f,
|
||||
shark_lm_head,
|
||||
)
|
||||
return sharded_model
|
||||
|
||||
def generate(self, prompt):
|
||||
model_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.max_padding_length,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
model_inputs["prompt_text"] = prompt
|
||||
|
||||
input_ids = model_inputs["input_ids"]
|
||||
attention_mask = model_inputs.get("attention_mask", None)
|
||||
|
||||
# Allow empty prompts
|
||||
if input_ids.shape[1] == 0:
|
||||
input_ids = None
|
||||
attention_mask = None
|
||||
in_b = 1
|
||||
else:
|
||||
in_b = input_ids.shape[0]
|
||||
|
||||
generate_kwargs = {
|
||||
"max_length": self.max_num_tokens,
|
||||
"do_sample": True,
|
||||
"top_k": 10,
|
||||
"num_return_sequences": 1,
|
||||
"eos_token_id": 11,
|
||||
}
|
||||
generate_kwargs["input_ids"] = input_ids
|
||||
generate_kwargs["attention_mask"] = attention_mask
|
||||
generation_config_ = GenerationConfig.from_model_config(
|
||||
self.src_model.config
|
||||
)
|
||||
generation_config = copy.deepcopy(generation_config_)
|
||||
model_kwargs = generation_config.update(**generate_kwargs)
|
||||
|
||||
logits_processor = LogitsProcessorList()
|
||||
stopping_criteria = StoppingCriteriaList()
|
||||
|
||||
eos_token_id = generation_config.eos_token_id
|
||||
generation_config.pad_token_id = eos_token_id
|
||||
|
||||
(
|
||||
inputs_tensor,
|
||||
model_input_name,
|
||||
model_kwargs,
|
||||
) = self.src_model._prepare_model_inputs(
|
||||
None, generation_config.bos_token_id, model_kwargs
|
||||
)
|
||||
batch_size = inputs_tensor.shape[0]
|
||||
|
||||
model_kwargs["output_attentions"] = generation_config.output_attentions
|
||||
model_kwargs[
|
||||
"output_hidden_states"
|
||||
] = generation_config.output_hidden_states
|
||||
model_kwargs["use_cache"] = generation_config.use_cache
|
||||
|
||||
input_ids = (
|
||||
inputs_tensor
|
||||
if model_input_name == "input_ids"
|
||||
else model_kwargs.pop("input_ids")
|
||||
)
|
||||
|
||||
self.logits_processor = self.src_model._get_logits_processor(
|
||||
generation_config=generation_config,
|
||||
input_ids_seq_length=input_ids.shape[-1],
|
||||
encoder_input_ids=inputs_tensor,
|
||||
prefix_allowed_tokens_fn=None,
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
|
||||
self.stopping_criteria = self.src_model._get_stopping_criteria(
|
||||
generation_config=generation_config,
|
||||
stopping_criteria=stopping_criteria,
|
||||
)
|
||||
|
||||
self.logits_warper = self.src_model._get_logits_warper(
|
||||
generation_config
|
||||
)
|
||||
|
||||
(
|
||||
self.input_ids,
|
||||
self.model_kwargs,
|
||||
) = self.src_model._expand_inputs_for_generation(
|
||||
input_ids=input_ids,
|
||||
expand_size=generation_config.num_return_sequences, # 1
|
||||
is_encoder_decoder=self.src_model.config.is_encoder_decoder, # False
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
self.eos_token_id_tensor = (
|
||||
torch.tensor(eos_token_id) if eos_token_id is not None else None
|
||||
)
|
||||
|
||||
self.pad_token_id = generation_config.pad_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
output_scores = generation_config.output_scores # False
|
||||
output_attentions = generation_config.output_attentions # False
|
||||
output_hidden_states = generation_config.output_hidden_states # False
|
||||
return_dict_in_generate = (
|
||||
generation_config.return_dict_in_generate # False
|
||||
)
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
self.scores = (
|
||||
() if (return_dict_in_generate and output_scores) else None
|
||||
)
|
||||
decoder_attentions = (
|
||||
() if (return_dict_in_generate and output_attentions) else None
|
||||
)
|
||||
cross_attentions = (
|
||||
() if (return_dict_in_generate and output_attentions) else None
|
||||
)
|
||||
decoder_hidden_states = (
|
||||
() if (return_dict_in_generate and output_hidden_states) else None
|
||||
)
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
self.unfinished_sequences = torch.ones(
|
||||
input_ids.shape[0], dtype=torch.long, device=input_ids.device
|
||||
)
|
||||
|
||||
all_text = prompt
|
||||
|
||||
for i in range(self.max_num_tokens - 1):
|
||||
next_token = self.generate_new_token()
|
||||
new_word = self.tokenizer.decode(
|
||||
next_token.cpu().numpy(),
|
||||
add_special_tokens=False,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
|
||||
all_text = all_text + new_word
|
||||
|
||||
print(f"{new_word}", end="", flush=True)
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if self.eos_token_id_tensor is not None:
|
||||
self.unfinished_sequences = self.unfinished_sequences.mul(
|
||||
next_token.tile(self.eos_token_id_tensor.shape[0], 1)
|
||||
.ne(self.eos_token_id_tensor.unsqueeze(1))
|
||||
.prod(dim=0)
|
||||
)
|
||||
# stop when each sentence is finished
|
||||
if (
|
||||
self.unfinished_sequences.max() == 0
|
||||
or self.stopping_criteria(input_ids, self.scores)
|
||||
):
|
||||
break
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
return all_text
|
||||
|
||||
def generate_new_token(self):
|
||||
model_inputs = self.src_model.prepare_inputs_for_generation(
|
||||
self.input_ids, **self.model_kwargs
|
||||
)
|
||||
outputs = self.shark_model.forward(
|
||||
input_ids=model_inputs["input_ids"],
|
||||
attention_mask=model_inputs["attention_mask"],
|
||||
)
|
||||
if self.precision in ["fp16", "int4"]:
|
||||
outputs = outputs.to(dtype=torch.float32)
|
||||
next_token_logits = outputs
|
||||
|
||||
# pre-process distribution
|
||||
next_token_scores = self.logits_processor(
|
||||
self.input_ids, next_token_logits
|
||||
)
|
||||
next_token_scores = self.logits_warper(
|
||||
self.input_ids, next_token_scores
|
||||
)
|
||||
|
||||
# sample
|
||||
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
|
||||
|
||||
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
|
||||
# finished sentences should have their next token be a padding token
|
||||
if self.eos_token_id is not None:
|
||||
if self.pad_token_id is None:
|
||||
raise ValueError(
|
||||
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
|
||||
)
|
||||
next_token = (
|
||||
next_token * self.unfinished_sequences
|
||||
+ self.pad_token_id * (1 - self.unfinished_sequences)
|
||||
)
|
||||
|
||||
self.input_ids = torch.cat(
|
||||
[self.input_ids, next_token[:, None]], dim=-1
|
||||
)
|
||||
|
||||
self.model_kwargs["past_key_values"] = None
|
||||
if "attention_mask" in self.model_kwargs:
|
||||
attention_mask = self.model_kwargs["attention_mask"]
|
||||
self.model_kwargs["attention_mask"] = torch.cat(
|
||||
[
|
||||
attention_mask,
|
||||
attention_mask.new_ones((attention_mask.shape[0], 1)),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
self.input_ids = self.input_ids[:, 1:]
|
||||
self.model_kwargs["attention_mask"] = self.model_kwargs[
|
||||
"attention_mask"
|
||||
][:, 1:]
|
||||
|
||||
return next_token
|
||||
|
||||
|
||||
class UnshardedFalcon(SharkLLMBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path="tiiuae/falcon-7b-instruct",
|
||||
hf_auth_token: str = "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",
|
||||
max_num_tokens=150,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
falcon_mlir_path=None,
|
||||
falcon_vmfb_path=None,
|
||||
debug=False,
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
print("hf_model_path: ", self.hf_model_path)
|
||||
|
||||
if "180b" in self.model_name and hf_auth_token == None:
|
||||
raise ValueError(
|
||||
""" HF auth token required for falcon-180b. Pass it using
|
||||
@@ -519,14 +1005,22 @@ if __name__ == "__main__":
|
||||
"tiiuae/falcon-" + args.falcon_variant_to_use + "-instruct"
|
||||
)
|
||||
|
||||
falcon = Falcon(
|
||||
model_name="falcon_" + args.falcon_variant_to_use,
|
||||
hf_model_path=hf_model_path_value,
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
falcon_mlir_path=falcon_mlir_path,
|
||||
falcon_vmfb_path=falcon_vmfb_path,
|
||||
)
|
||||
if not args.sharded:
|
||||
falcon = UnshardedFalcon(
|
||||
model_name="falcon_" + args.falcon_variant_to_use,
|
||||
hf_model_path=hf_model_path_value,
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
falcon_mlir_path=falcon_mlir_path,
|
||||
falcon_vmfb_path=falcon_vmfb_path,
|
||||
)
|
||||
else:
|
||||
falcon = ShardedFalcon(
|
||||
model_name="falcon_" + args.falcon_variant_to_use,
|
||||
hf_model_path=hf_model_path_value,
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
)
|
||||
|
||||
import gc
|
||||
|
||||
|
||||
Reference in New Issue
Block a user