mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add opt_causallm_samples.py. (#1916)
This commit is contained in:
@@ -10,6 +10,7 @@ from shark.iree_utils._common import (
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from transformers import AutoTokenizer, OPTForCausalLM
|
||||
from typing import Iterable
|
||||
|
||||
|
||||
def create_module(model_name, tokenizer, device, args):
|
||||
@@ -70,11 +71,11 @@ def shouldStop(tokens):
|
||||
return False
|
||||
|
||||
|
||||
def generate_new_token(shark_model, tokenizer, new_text, args):
|
||||
def generate_new_token(shark_module, tokenizer, new_text, max_seq_len: int):
|
||||
model_inputs = tokenizer(
|
||||
new_text,
|
||||
padding="max_length",
|
||||
max_length=args.max_seq_len,
|
||||
max_length=max_seq_len,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
@@ -83,7 +84,7 @@ def generate_new_token(shark_model, tokenizer, new_text, args):
|
||||
model_inputs["attention_mask"],
|
||||
)
|
||||
sum_attentionmask = torch.sum(model_inputs.attention_mask)
|
||||
output = shark_model("forward", inputs)
|
||||
output = shark_module("forward", inputs)
|
||||
output = torch.FloatTensor(output[0])
|
||||
next_toks = torch.topk(output, 1)
|
||||
stop_generation = False
|
||||
@@ -135,6 +136,34 @@ def parse_args():
|
||||
return args
|
||||
|
||||
|
||||
def generate_tokens(
|
||||
opt_shark_module: "SharkInference",
|
||||
tokenizer,
|
||||
input_text: str,
|
||||
max_output_len: int,
|
||||
print_intermediate_results: True,
|
||||
) -> Iterable[str]:
|
||||
words_list = []
|
||||
new_text = input_text
|
||||
try:
|
||||
for _ in range(max_output_len):
|
||||
generated_token_op = generate_new_token(
|
||||
opt_shark_module, tokenizer, new_text, max_output_len
|
||||
)
|
||||
detok = generated_token_op["detok"]
|
||||
if generated_token_op["stop_generation"]:
|
||||
break
|
||||
if print_intermediate_results:
|
||||
print(detok, end="", flush=True)
|
||||
words_list.append(detok)
|
||||
if detok == "":
|
||||
break
|
||||
new_text += detok
|
||||
except KeyboardInterrupt as e:
|
||||
print("Exiting token generation.")
|
||||
return words_list
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False)
|
||||
@@ -155,25 +184,7 @@ if __name__ == "__main__":
|
||||
vmfb_path = create_module(args.model_name, tokenizer, "cpu-task", args)
|
||||
opt_shark_module.load_module(vmfb_path)
|
||||
while True:
|
||||
try:
|
||||
new_text = input("Give me a sentence to complete:")
|
||||
new_text_init = new_text
|
||||
words_list = []
|
||||
|
||||
for i in range(args.max_seq_len):
|
||||
generated_token_op = generate_new_token(
|
||||
opt_shark_module, tokenizer, new_text, args
|
||||
)
|
||||
detok = generated_token_op["detok"]
|
||||
stop_generation = generated_token_op["stop_generation"]
|
||||
if stop_generation:
|
||||
break
|
||||
print(detok, end="", flush=True)
|
||||
words_list.append(detok)
|
||||
if detok == "":
|
||||
break
|
||||
new_text = new_text + detok
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Exiting program.")
|
||||
break
|
||||
input_text = input("Give me a sentence to complete:")
|
||||
generate_tokens(
|
||||
opt_shark_module, tokenizer, input_text, args.max_seq_len
|
||||
)
|
||||
|
||||
74
tank/examples/opt/opt_causallm_samples.py
Normal file
74
tank/examples/opt/opt_causallm_samples.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import opt_causallm
|
||||
import opt_util
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
from transformers import AutoTokenizer, OPTForCausalLM
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--max-seq-len", type=int, default=32)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
help="Model name",
|
||||
type=str,
|
||||
choices=[
|
||||
"facebook/opt-125m",
|
||||
"facebook/opt-350m",
|
||||
"facebook/opt-1.3b",
|
||||
"facebook/opt-6.7b",
|
||||
],
|
||||
default="facebook/opt-1.3b",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recompile",
|
||||
help="If set, recompiles MLIR -> .vmfb",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plugin-path",
|
||||
help="path to executable plugin",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print("args={}".format(args))
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False)
|
||||
opt_fs_name = "-".join(
|
||||
"_".join(args.model_name.split("/")[1].split("-")).split(".")
|
||||
)
|
||||
vmfb_path = f"./{opt_fs_name}_causallm_{args.max_seq_len}_torch_cpu.vmfb"
|
||||
if args.plugin_path is not None:
|
||||
rt_flags = [f"--executable_plugin={args.plugin_path}"]
|
||||
else:
|
||||
rt_flags = []
|
||||
opt_shark_module = SharkInference(
|
||||
mlir_module=None, device="cpu-task", rt_flags=rt_flags
|
||||
)
|
||||
if os.path.isfile(vmfb_path):
|
||||
opt_shark_module.load_module(vmfb_path)
|
||||
else:
|
||||
vmfb_path = opt_causallm.create_module(
|
||||
args.model_name, tokenizer, "cpu-task", args
|
||||
)
|
||||
opt_shark_module.load_module(vmfb_path)
|
||||
|
||||
for prompt in opt_util.PROMPTS:
|
||||
print("\n\nprompt: {}".format(prompt))
|
||||
response = opt_causallm.generate_tokens(
|
||||
opt_shark_module,
|
||||
tokenizer,
|
||||
prompt,
|
||||
args.max_seq_len,
|
||||
print_intermediate_results=False,
|
||||
)
|
||||
print("reponse: {}".format("".join(response)))
|
||||
@@ -22,6 +22,7 @@ import time
|
||||
import numpy as np
|
||||
from typing import Tuple
|
||||
|
||||
from opt_util import PROMPTS
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from transformers import AutoTokenizer, OPTForCausalLM
|
||||
@@ -44,19 +45,6 @@ REPORT_LOAD_VIRTUAL_MEMORY_MB = "load_virtual_MB"
|
||||
REPORT_RUN_PHYSICAL_MEMORY_MB = "run_physical_MB"
|
||||
REPORT_RUN_VIRTUAL_MEMORY_MB = "run_virtual_MB"
|
||||
|
||||
PROMPTS = [
|
||||
"What is the meaning of life?",
|
||||
"Tell me something you don't know.",
|
||||
"What does Xilinx do?",
|
||||
"What is the mass of earth?",
|
||||
"What is a poem?",
|
||||
"What is recursion?",
|
||||
"Tell me a one line joke.",
|
||||
"Who is Gilgamesh?",
|
||||
"Tell me something about cryptocurrency.",
|
||||
"How did it all begin?",
|
||||
]
|
||||
|
||||
ModelWrapper = collections.namedtuple("ModelWrapper", ["model", "tokenizer"])
|
||||
|
||||
|
||||
|
||||
12
tank/examples/opt/opt_util.py
Normal file
12
tank/examples/opt/opt_util.py
Normal file
@@ -0,0 +1,12 @@
|
||||
PROMPTS = [
|
||||
"What is the meaning of life?",
|
||||
"Tell me something you don't know.",
|
||||
"What does Xilinx do?",
|
||||
"What is the mass of earth?",
|
||||
"What is a poem?",
|
||||
"What is recursion?",
|
||||
"Tell me a one line joke.",
|
||||
"Who is Gilgamesh?",
|
||||
"Tell me something about cryptocurrency.",
|
||||
"How did it all begin?",
|
||||
]
|
||||
Reference in New Issue
Block a user