mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Updates to opt_causallm example (#1905)
* Updates to opt_causallm example * Fixup opt_perf_comparison.py * Use same filenames across opt examples.
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
@@ -10,21 +11,16 @@ from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from transformers import AutoTokenizer, OPTForCausalLM
|
||||
|
||||
OPT_MODEL = "opt-1.3b"
|
||||
OPT_FS_NAME = "opt-1_3b"
|
||||
MAX_SEQUENCE_LENGTH = 128
|
||||
MAX_NEW_TOKENS = 60
|
||||
|
||||
|
||||
def create_module(model_name, tokenizer, device):
|
||||
opt_base_model = OPTForCausalLM.from_pretrained("facebook/" + model_name)
|
||||
def create_module(model_name, tokenizer, device, args):
|
||||
opt_base_model = OPTForCausalLM.from_pretrained(model_name)
|
||||
opt_base_model.eval()
|
||||
opt_model = OPTForCausalLMModel(opt_base_model)
|
||||
encoded_inputs = tokenizer(
|
||||
"What is the meaning of life?",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
max_length=args.max_seq_len,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = (
|
||||
@@ -33,8 +29,11 @@ def create_module(model_name, tokenizer, device):
|
||||
)
|
||||
# np.save("model_inputs_0.npy", inputs[0])
|
||||
# np.save("model_inputs_1.npy", inputs[1])
|
||||
opt_fs_name = "-".join(
|
||||
"_".join(args.model_name.split("/")[1].split("-")).split(".")
|
||||
)
|
||||
|
||||
mlir_path = f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
|
||||
mlir_path = f"./{opt_fs_name}_causallm_{args.max_seq_len}_torch.mlir"
|
||||
if os.path.isfile(mlir_path):
|
||||
print(f"Found .mlir from {mlir_path}")
|
||||
else:
|
||||
@@ -42,7 +41,7 @@ def create_module(model_name, tokenizer, device):
|
||||
model=opt_model,
|
||||
inputs=inputs,
|
||||
is_f16=False,
|
||||
model_name=OPT_FS_NAME,
|
||||
model_name=opt_fs_name,
|
||||
return_str=True,
|
||||
)
|
||||
with open(mlir_path, "w") as f:
|
||||
@@ -57,7 +56,7 @@ def create_module(model_name, tokenizer, device):
|
||||
is_benchmark=False,
|
||||
)
|
||||
|
||||
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
|
||||
vmfb_name = f"{opt_fs_name}_causallm_{args.max_seq_len}_torch_cpu"
|
||||
shark_module.save_module(module_name=vmfb_name, debug=False)
|
||||
vmfb_path = vmfb_name + ".vmfb"
|
||||
return vmfb_path
|
||||
@@ -71,11 +70,11 @@ def shouldStop(tokens):
|
||||
return False
|
||||
|
||||
|
||||
def generate_new_token(shark_model, tokenizer, new_text):
|
||||
def generate_new_token(shark_model, tokenizer, new_text, args):
|
||||
model_inputs = tokenizer(
|
||||
new_text,
|
||||
padding="max_length",
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
max_length=args.max_seq_len,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
@@ -104,18 +103,56 @@ def generate_new_token(shark_model, tokenizer, new_text):
|
||||
return ret_dict
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--max-seq-len", type=int, default=32)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
help="Model name",
|
||||
type=str,
|
||||
choices=[
|
||||
"facebook/opt-125m",
|
||||
"facebook/opt-350m",
|
||||
"facebook/opt-1.3b",
|
||||
"facebook/opt-6.7b",
|
||||
],
|
||||
default="facebook/opt-1.3b",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recompile",
|
||||
help="If set, recompiles MLIR -> .vmfb",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plugin-path",
|
||||
help="path to executable plugin",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print("args={}".format(args))
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"facebook/" + OPT_MODEL, use_fast=False
|
||||
args = parse_args()
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False)
|
||||
opt_fs_name = "-".join(
|
||||
"_".join(args.model_name.split("/")[1].split("-")).split(".")
|
||||
)
|
||||
vmfb_path = (
|
||||
f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_cpu-task.vmfb"
|
||||
vmfb_path = f"./{opt_fs_name}_causallm_{args.max_seq_len}_torch_cpu.vmfb"
|
||||
if args.plugin_path is not None:
|
||||
rt_flags = [f"--executable_plugin={args.plugin_path}"]
|
||||
else:
|
||||
rt_flags = []
|
||||
opt_shark_module = SharkInference(
|
||||
mlir_module=None, device="cpu-task", rt_flags=rt_flags
|
||||
)
|
||||
opt_shark_module = SharkInference(mlir_module=None, device="cpu-task")
|
||||
if os.path.isfile(vmfb_path):
|
||||
opt_shark_module.load_module(vmfb_path)
|
||||
else:
|
||||
vmfb_path = create_module(OPT_MODEL, tokenizer, "cpu-task")
|
||||
vmfb_path = create_module(args.model_name, tokenizer, "cpu-task", args)
|
||||
opt_shark_module.load_module(vmfb_path)
|
||||
while True:
|
||||
try:
|
||||
@@ -123,9 +160,9 @@ if __name__ == "__main__":
|
||||
new_text_init = new_text
|
||||
words_list = []
|
||||
|
||||
for i in range(MAX_NEW_TOKENS):
|
||||
for i in range(args.max_seq_len):
|
||||
generated_token_op = generate_new_token(
|
||||
opt_shark_module, tokenizer, new_text
|
||||
opt_shark_module, tokenizer, new_text, args
|
||||
)
|
||||
detok = generated_token_op["detok"]
|
||||
stop_generation = generated_token_op["stop_generation"]
|
||||
|
||||
@@ -147,7 +147,7 @@ def load_shark_model(
|
||||
plugin_path: str = [],
|
||||
) -> ModelWrapper:
|
||||
opt_fs_name = get_opt_fs_name(model_name)
|
||||
vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}_tiled_ukernels.vmfb"
|
||||
vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}.vmfb"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
||||
if recompile_shark or not os.path.isfile(vmfb_name):
|
||||
print(f"vmfb not found. compiling and saving to {vmfb_name}")
|
||||
@@ -344,7 +344,7 @@ def parse_args():
|
||||
default=PLATFORM_SHARK,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plugin_path",
|
||||
"--plugin-path",
|
||||
help="path to executable plugin",
|
||||
type=str,
|
||||
default=None,
|
||||
|
||||
Reference in New Issue
Block a user