prevent loading tokenizer on import (#1432)

Also adds the sentencepiece dependency for the exe build and moves the
vicuna imports to after an if statement. In general we should avoid
importing files that load whole models as global variables.
Daniel Garvey, 2023-05-12 21:11:45 -05:00 (committed by GitHub)
parent 4c07e47e8c
commit 4731c1a835
5 changed files with 27 additions and 20 deletions
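
The pattern being removed is the one named in the message: a module-level from_pretrained() call runs the moment any other file imports the script, so the web UI's top-level import of the vicuna module paid for the whole model load. The sketch below is illustrative only, not the project's exact code, and the lru_cache memoization is an extra suggestion that is not part of this commit; it just shows the before/after shape the changed files follow.

    # Before: executed at import time, so every importer loaded the model.
    #   vicuna_model = AutoModelForCausalLM.from_pretrained(path, **kwargs)
    #
    # After: wrap the load in a getter so it only runs when a caller asks for it.
    from functools import lru_cache

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer


    @lru_cache(maxsize=1)  # optional memoization; not part of this commit
    def get_model_and_tokenizer(path="TheBloke/vicuna-7B-1.1-HF"):
        kwargs = {"torch_dtype": torch.float}
        model = AutoModelForCausalLM.from_pretrained(path, **kwargs)
        tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
        return model, tokenizer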


@@ -399,10 +399,11 @@ def compile_vicuna_layer(
     return ts_g


-path = "TheBloke/vicuna-7B-1.1-HF"
-kwargs = {"torch_dtype": torch.float}
-vicuna_model = AutoModelForCausalLM.from_pretrained(path, **kwargs)
-tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
+def get_model_and_tokenizer(path="TheBloke/vicuna-7B-1.1-HF"):
+    kwargs = {"torch_dtype": torch.float}
+    vicuna_model = AutoModelForCausalLM.from_pretrained(path, **kwargs)
+    tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
+    return vicuna_model, tokenizer


 def compile_to_vmfb(inputs, layers, is_first=True):
@@ -577,7 +578,7 @@ def get_sharded_model():
     # SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an incredibly hacky process
     # please don't change it
     SAMPLE_INPUT_LEN = 137
-    global vicuna_model
+    vicuna_model = get_model_and_tokenizer()[0]
     placeholder_input0 = (
         torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
@@ -611,6 +612,7 @@ if __name__ == "__main__":
     prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
     prologue_prompt = "ASSISTANT:\n"
     sharded_model = get_sharded_model()
+    tokenizer = get_model_and_tokenizer()[1]
     past_key_values = None
     while True:
         print("\n\n")
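
After this change, importing sharded_vicuna_fp32 is cheap; the weights and tokenizer are fetched only when the new helper is called. A minimal usage sketch (the calling script below is hypothetical):

    from apps.language_models.scripts.sharded_vicuna_fp32 import get_model_and_tokenizer

    # Importing the module above no longer triggers a model load; this call does
    # the work, defaulting to TheBloke/vicuna-7B-1.1-HF.
    vicuna_model, tokenizer = get_model_and_tokenizer()
    prompt_ids = tokenizer("Hello!", return_tensors="pt").input_ids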


@@ -24,12 +24,6 @@ from shark.shark_inference import SharkInference
 from pathlib import Path


-model_path = "stabilityai/stablelm-tuned-alpha-3b"
-tok = AutoTokenizer.from_pretrained(model_path)
-tok.add_special_tokens({"pad_token": "<PAD>"})
-print(f"Successfully loaded the tokenizer to the memory")
-
-
 class StopOnTokens(StoppingCriteria):
     def __call__(
         self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
@@ -246,9 +240,12 @@ system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
""" """
input_ids = torch.randint(3, (1, 256)) def get_tokenizer():
attention_mask = torch.randint(3, (1, 256)) model_path = "stabilityai/stablelm-tuned-alpha-3b"
sharkModel = 0 tok = AutoTokenizer.from_pretrained(model_path)
tok.add_special_tokens({"pad_token": "<PAD>"})
print(f"Sucessfully loaded the tokenizer to the memory")
return tok
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256") # sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
@@ -263,7 +260,12 @@ def generate(
     num_beams,
     stopping_criteria,
     sharkStableLM,
+    tok=None,
+    input_ids=torch.randint(3, (1, 256)),
+    attention_mask=torch.randint(3, (1, 256)),
 ):
+    if tok == None:
+        tok = get_tokenizer()
     # Construct the input message string for the model by concatenating the current system message and conversation history
     # Tokenize the messages string
     # sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
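
The StableLM script gets the same treatment: get_tokenizer() replaces the import-time tokenizer, and generate() now accepts it through the new tok parameter, falling back to building one itself. A small illustrative sketch (the calling code is hypothetical) of building the tokenizer once and reusing it instead of relying on the tok=None fallback:

    from apps.language_models.scripts.stablelm import get_tokenizer

    # Built on first use instead of at import time; reuse one instance across
    # generate() calls rather than letting each call rebuild it.
    tok = get_tokenizer()
    batch = tok("What is StableLM?", return_tensors="pt")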


@@ -31,6 +31,7 @@ datas += collect_data_files('google-cloud-storage')
 datas += collect_data_files('shark')
 datas += collect_data_files('tkinter')
 datas += collect_data_files('webview')
+datas += collect_data_files('sentencepiece')
 datas += [
     ( 'src/utils/resources/prompts.json', 'resources' ),
     ( 'src/utils/resources/model_db.json', 'resources' ),
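
The spec change mirrors the new dependency: collect_data_files asks PyInstaller's hook utilities for a package's non-code files so they end up inside the frozen exe. Roughly, it behaves as in this illustrative snippet:

    from PyInstaller.utils.hooks import collect_data_files

    # Returns (absolute_source_path, destination_dir) pairs; the spec appends
    # them to `datas` so the bundled exe ships sentencepiece's files.
    extra = collect_data_files('sentencepiece')
    for src, dest in extra:
        print(src, '->', dest)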


@@ -5,8 +5,7 @@ from apps.language_models.scripts.stablelm import (
     compile_stableLM,
     StopOnTokens,
     generate,
-    sharkModel,
-    tok,
+    get_tokenizer,
     StableLMModel,
 )
 from transformers import (
@@ -15,10 +14,6 @@ from transformers import (
     StoppingCriteriaList,
 )
 from apps.stable_diffusion.web.ui.utils import available_devices
-from apps.language_models.scripts.sharded_vicuna_fp32 import (
-    tokenizer,
-    get_sharded_model,
-)

 start_message = """<|SYSTEM|># StableLM Tuned (Alpha version)
 - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
@@ -49,6 +44,11 @@ def chat(curr_system_message, history, model):
     global sharded_model
     global past_key_values
     if "vicuna" in model:
+        from apps.language_models.scripts.sharded_vicuna_fp32 import (
+            tokenizer,
+            get_sharded_model,
+        )
         SAMPLE_INPUT_LEN = 137
         curr_system_message = start_message_vicuna
         if sharded_model == 0:
@@ -100,6 +100,7 @@ def chat(curr_system_message, history, model):
     global sharkModel
     print("In chat")
     if sharkModel == 0:
+        tok = get_tokenizer()
         # sharkModel = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/disk/phaneesh/stablelm_3b_f32_cuda_2048_newflags.vmfb")
         m = AutoModelForCausalLM.from_pretrained(
             "stabilityai/stablelm-tuned-alpha-3b", torch_dtype=torch.float32


@@ -28,6 +28,7 @@ scikit-image
 pytorch_lightning # for runwayml models
 tk
 pywebview
+sentencepiece
 # Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
 pefile