From 4731c1a835a3847c11fde5aaa62943acccbd0ba1 Mon Sep 17 00:00:00 2001
From: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com>
Date: Fri, 12 May 2023 21:11:45 -0500
Subject: [PATCH] prevent loading tokenizer on import (#1432)

Also adds the sentencepiece dependency for the exe build.

Moved the vicuna imports to after an if statement; in general we should
avoid importing files that load whole models as global variables.
---
 .../scripts/sharded_vicuna_fp32.py          | 12 ++++++-----
 apps/language_models/scripts/stablelm.py    | 20 ++++++++++---------
 apps/stable_diffusion/shark_sd.spec         |  1 +
 apps/stable_diffusion/web/ui/stablelm_ui.py | 13 ++++++------
 requirements.txt                            |  1 +
 5 files changed, 27 insertions(+), 20 deletions(-)

diff --git a/apps/language_models/scripts/sharded_vicuna_fp32.py b/apps/language_models/scripts/sharded_vicuna_fp32.py
index 931956bc..b20c6e01 100644
--- a/apps/language_models/scripts/sharded_vicuna_fp32.py
+++ b/apps/language_models/scripts/sharded_vicuna_fp32.py
@@ -399,10 +399,11 @@ def compile_vicuna_layer(
     return ts_g
 
 
-path = "TheBloke/vicuna-7B-1.1-HF"
-kwargs = {"torch_dtype": torch.float}
-vicuna_model = AutoModelForCausalLM.from_pretrained(path, **kwargs)
-tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
+def get_model_and_tokenizer(path="TheBloke/vicuna-7B-1.1-HF"):
+    kwargs = {"torch_dtype": torch.float}
+    vicuna_model = AutoModelForCausalLM.from_pretrained(path, **kwargs)
+    tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
+    return vicuna_model, tokenizer
 
 
 def compile_to_vmfb(inputs, layers, is_first=True):
@@ -577,7 +578,7 @@ def get_sharded_model():
     # SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an increadibly hacky proccess
     # please don't change it
     SAMPLE_INPUT_LEN = 137
-    global vicuna_model
+    vicuna_model = get_model_and_tokenizer()[0]
 
     placeholder_input0 = (
         torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
@@ -611,6 +612,7 @@ if __name__ == "__main__":
     prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
     prologue_prompt = "ASSISTANT:\n"
     sharded_model = get_sharded_model()
+    tokenizer = get_model_and_tokenizer()[1]
     past_key_values = None
     while True:
         print("\n\n")
diff --git a/apps/language_models/scripts/stablelm.py b/apps/language_models/scripts/stablelm.py
index c481a20a..ac57b54f 100644
--- a/apps/language_models/scripts/stablelm.py
+++ b/apps/language_models/scripts/stablelm.py
@@ -24,12 +24,6 @@ from shark.shark_inference import SharkInference
 from pathlib import Path
 
 
-model_path = "stabilityai/stablelm-tuned-alpha-3b"
-tok = AutoTokenizer.from_pretrained(model_path)
-tok.add_special_tokens({"pad_token": ""})
-print(f"Sucessfully loaded the tokenizer to the memory")
-
-
 class StopOnTokens(StoppingCriteria):
     def __call__(
         self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
@@ -246,9 +240,12 @@ system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
 """
 
 
-input_ids = torch.randint(3, (1, 256))
-attention_mask = torch.randint(3, (1, 256))
-sharkModel = 0
+def get_tokenizer():
+    model_path = "stabilityai/stablelm-tuned-alpha-3b"
+    tok = AutoTokenizer.from_pretrained(model_path)
+    tok.add_special_tokens({"pad_token": ""})
+    print("Successfully loaded the tokenizer to memory")
+    return tok
 
 
 # sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
@@ -263,7 +260,12 @@ def generate(
     num_beams,
     stopping_criteria,
     sharkStableLM,
+    tok=None,
+    input_ids=torch.randint(3, (1, 256)),
+    attention_mask=torch.randint(3, (1, 256)),
 ):
+    if tok is None:
+        tok = get_tokenizer()
     # Construct the input message string for the model by concatenating the current system message and conversation history
     # Tokenize the messages string
     # sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
diff --git a/apps/stable_diffusion/shark_sd.spec b/apps/stable_diffusion/shark_sd.spec
index 18040048..8e75d1a4 100644
--- a/apps/stable_diffusion/shark_sd.spec
+++ b/apps/stable_diffusion/shark_sd.spec
@@ -31,6 +31,7 @@ datas += collect_data_files('google-cloud-storage')
 datas += collect_data_files('shark')
 datas += collect_data_files('tkinter')
 datas += collect_data_files('webview')
+datas += collect_data_files('sentencepiece')
 datas += [
     ( 'src/utils/resources/prompts.json', 'resources' ),
     ( 'src/utils/resources/model_db.json', 'resources' ),
diff --git a/apps/stable_diffusion/web/ui/stablelm_ui.py b/apps/stable_diffusion/web/ui/stablelm_ui.py
index 04642dd6..bb2a7503 100644
--- a/apps/stable_diffusion/web/ui/stablelm_ui.py
+++ b/apps/stable_diffusion/web/ui/stablelm_ui.py
@@ -5,8 +5,7 @@ from apps.language_models.scripts.stablelm import (
     compile_stableLM,
     StopOnTokens,
     generate,
-    sharkModel,
-    tok,
+    get_tokenizer,
     StableLMModel,
 )
 from transformers import (
@@ -15,10 +14,6 @@ from transformers import (
     StoppingCriteriaList,
 )
 from apps.stable_diffusion.web.ui.utils import available_devices
-from apps.language_models.scripts.sharded_vicuna_fp32 import (
-    tokenizer,
-    get_sharded_model,
-)
 
 start_message = """<|SYSTEM|># StableLM Tuned (Alpha version)
 - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
@@ -49,6 +44,11 @@ def chat(curr_system_message, history, model):
     global sharded_model
     global past_key_values
     if "vicuna" in model:
+        from apps.language_models.scripts.sharded_vicuna_fp32 import (
+            tokenizer,
+            get_sharded_model,
+        )
+
         SAMPLE_INPUT_LEN = 137
         curr_system_message = start_message_vicuna
         if sharded_model == 0:
@@ -100,6 +100,7 @@ def chat(curr_system_message, history, model):
     global sharkModel
     print("In chat")
     if sharkModel == 0:
+        tok = get_tokenizer()
         # sharkModel = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/disk/phaneesh/stablelm_3b_f32_cuda_2048_newflags.vmfb")
         m = AutoModelForCausalLM.from_pretrained(
             "stabilityai/stablelm-tuned-alpha-3b", torch_dtype=torch.float32
diff --git a/requirements.txt b/requirements.txt
index e2a91e44..bd8332e0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -28,6 +28,7 @@ scikit-image
 pytorch_lightning # for runwayml models
 tk
 pywebview
+sentencepiece
 
 # Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
 pefile
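
For context on the "avoid importing files that load whole models as global
variables" point in the commit message, here is a minimal sketch of the
lazy-loading pattern, assuming the same TheBloke/vicuna-7B-1.1-HF checkpoint
and transformers APIs used in the patch. The lru_cache memoization is an
illustrative addition rather than something the patch itself does; it keeps
back-to-back calls such as get_model_and_tokenizer()[0] and
get_model_and_tokenizer()[1] from loading the checkpoint twice.

# Illustrative sketch only -- not part of the patch.
from functools import lru_cache

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


@lru_cache(maxsize=None)
def get_model_and_tokenizer(path="TheBloke/vicuna-7B-1.1-HF"):
    # Heavy work happens on the first call, not when this module is imported,
    # and the result is cached so later calls reuse the same objects.
    model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float)
    tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
    return model, tokenizer


if __name__ == "__main__":
    model, _ = get_model_and_tokenizer()      # loads the checkpoint once
    _, tokenizer = get_model_and_tokenizer()  # cache hit, nothing is reloaded

The same idea would apply to get_tokenizer() in stablelm.py if repeated chat
calls turn out to re-create the tokenizer.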