Add support for StableLM-3B model (#2019)

* Add support for StableLM-3B model

* Add support for Quantized StableLM-3B model

* Update stablelm_pipeline.py
This commit is contained in:
Vivek Khandelwal
2023-12-12 22:39:50 +05:30
committed by GitHub
parent bf70e80d20
commit 3cc643b2de

View File

@@ -4,13 +4,49 @@ from transformers import AutoTokenizer, StoppingCriteria, AutoModelForCausalLM
from io import BytesIO
from pathlib import Path
from apps.language_models.utils import (
get_torch_mlir_module_bytecode,
get_vmfb_from_path,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.src.model_wrappers.stablelm_model import (
StableLMModel,
)
import argparse
parser = argparse.ArgumentParser(
prog="stablelm runner",
description="runs a StableLM model",
)
parser.add_argument(
"--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"]
)
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
parser.add_argument(
"--stablelm_vmfb_path", default=None, help="path to StableLM's vmfb"
)
parser.add_argument(
"--stablelm_mlir_path",
default=None,
help="path to StableLM's mlir file",
)
parser.add_argument(
"--use_precompiled_model",
default=True,
action=argparse.BooleanOptionalAction,
help="use the precompiled vmfb",
)
parser.add_argument(
"--load_mlir_from_shark_tank",
default=True,
action=argparse.BooleanOptionalAction,
help="download precompile mlir from shark tank",
)
parser.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication token for stablelm-3B model.",
)
class StopOnTokens(StoppingCriteria):
@@ -29,7 +65,7 @@ class SharkStableLM(SharkLLMBase):
self,
model_name,
hf_model_path="stabilityai/stablelm-tuned-alpha-3b",
max_num_tokens=512,
max_num_tokens=256,
device="cuda",
precision="fp32",
debug="False",
@@ -37,6 +73,14 @@ class SharkStableLM(SharkLLMBase):
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_len = 256
self.device = device
if precision != "int4" and args.hf_auth_token == None:
raise ValueError(
""" HF auth token required for StableLM-3B. Pass it using
--hf_auth_token flag. You can ask for the access to the model
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
)
self.hf_auth_token = args.hf_auth_token
self.precision = precision
self.debug = debug
self.tokenizer = self.get_tokenizer()
@@ -50,9 +94,23 @@ class SharkStableLM(SharkLLMBase):
return False
def get_src_model(self):
kwargs = {}
if self.precision == "int4":
self.hf_model_path = "TheBloke/stablelm-zephyr-3b-GPTQ"
from transformers import GPTQConfig
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
kwargs["quantization_config"] = quantization_config
kwargs["device_map"] = "cpu"
print("[DEBUG] Loading Model")
model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, torch_dtype=torch.float32
self.hf_model_path,
trust_remote_code=True,
torch_dtype=torch.float32,
use_auth_token=self.hf_auth_token,
**kwargs,
)
print("[DEBUG] Model loaded successfully")
return model
def get_model_inputs(self):
@@ -61,9 +119,7 @@ class SharkStableLM(SharkLLMBase):
return input_ids, attention_mask
def compile(self):
tmp_model_name = (
f"stableLM_linalg_{self.precision}_seqLen{self.max_sequence_len}"
)
tmp_model_name = f"{self.model_name}_linalg_{self.precision}_seqLen{self.max_sequence_len}"
# device = "cuda" # "cpu"
# TODO: vmfb and mlir name should include precision and device
@@ -83,13 +139,19 @@ class SharkStableLM(SharkLLMBase):
print(
f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
)
if mlir_path.exists():
with open(mlir_path, "rb") as f:
bytecode = f.read()
else:
if not mlir_path.exists():
model = StableLMModel(self.get_src_model())
model_inputs = self.get_model_inputs()
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
from shark.shark_importer import import_with_fx
ts_graph = import_with_fx(
model,
model_inputs,
is_f16=True if self.precision in ["fp16"] else False,
precision=self.precision,
f16_input_mask=[False, False],
mlir_type="torchscript",
)
module = torch_mlir.compile(
ts_graph,
[*model_inputs],
@@ -100,15 +162,16 @@ class SharkStableLM(SharkLLMBase):
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(tmp_model_name + ".mlir", "wb")
f_.write(bytecode)
print("Saved mlir")
f_.close()
f_ = open(mlir_path, "wb")
f_.write(bytecode)
print("Saved mlir at: ", mlir_path)
f_.close()
del bytecode
from shark.shark_inference import SharkInference
shark_module = SharkInference(
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
mlir_module=mlir_path, device=self.device, mlir_dialect="tm_tensor"
)
shark_module.compile()
@@ -120,14 +183,22 @@ class SharkStableLM(SharkLLMBase):
return shark_module
def get_tokenizer(self):
tok = AutoTokenizer.from_pretrained(self.hf_model_path)
tok = AutoTokenizer.from_pretrained(
self.hf_model_path,
use_auth_token=self.hf_auth_token,
)
tok.add_special_tokens({"pad_token": "<PAD>"})
# print("[DEBUG] Sucessfully loaded the tokenizer to the memory")
return tok
def generate(self, prompt):
words_list = []
import time
start = time.time()
count = 0
for i in range(self.max_num_tokens):
count = count + 1
params = {
"new_text": prompt,
}
@@ -145,6 +216,12 @@ class SharkStableLM(SharkLLMBase):
if detok == "":
break
prompt = prompt + detok
end = time.time()
print(
"\n\nTime taken is {:.2f} tokens/second\n".format(
count / (end - start)
)
)
return words_list
def generate_new_token(self, params):
@@ -178,10 +255,46 @@ class SharkStableLM(SharkLLMBase):
return ret_dict
# Initialize a StopOnTokens object
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""
if __name__ == "__main__":
args = parser.parse_args()
stable_lm = SharkStableLM(
model_name="stablelm_zephyr_3b",
hf_model_path="stabilityai/stablelm-zephyr-3b",
device=args.device,
precision=args.precision,
)
default_prompt_text = "The weather is always wonderful"
continue_execution = True
print("\n-----\nScript executing for the following config: \n")
print("StableLM Model: ", stable_lm.hf_model_path)
print("Precision: ", args.precision)
print("Device: ", args.device)
while continue_execution:
use_default_prompt = input(
"\nDo you wish to use the default prompt text? Y/N ?: "
)
if use_default_prompt in ["Y", "y"]:
prompt = default_prompt_text
else:
prompt = input("Please enter the prompt text: ")
print("\nPrompt Text: ", prompt)
res_str = stable_lm.generate(prompt)
torch.cuda.empty_cache()
import gc
gc.collect()
print(
"\n\n-----\nHere's the complete formatted result: \n\n",
prompt + "".join(res_str),
)
continue_execution = input(
"\nDo you wish to run script one more time? Y/N ?: "
)
continue_execution = (
True if continue_execution in ["Y", "y"] else False
)