AMD-SHARK-Studio/tank/examples/opt/opt_perf_comparison.py

"""
Script for comparing OPT model performance between SHARK and Huggingface
PyTorch.
Usage Example:
python opt_perf_comparison.py --max-seq-len=32 --model-name=facebook/opt-125m \
--platform=shark
python opt_perf_comparison.py --max-seq-len=512 --model-name=facebook/opt-1.3b \
--platform=shark
See parse_args() below for command line argument usage.
"""
import argparse
import collections
import json
import os
import psutil
import time
import numpy as np
from typing import Optional
from opt_util import PROMPTS
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from transformers import AutoTokenizer, OPTForCausalLM
from shark_opt_wrapper import OPTForCausalLMModel
from shark.parser import shark_args
import iree.compiler as ireec
DEVICE = "cpu"
PLATFORM_SHARK = "shark"
PLATFORM_HUGGINGFACE = "huggingface"
# Dict keys for reports.
REPORT_PLATFORM = "platform"
REPORT_MODEL_NAME = "model"
REPORT_MAX_SEQ_LEN = "max_seq_len"
REPORT_LOAD_TIME = "load_time_sec"
REPORT_RUN_TIME = "run_time_sec"
REPORT_LOAD_PHYSICAL_MEMORY_MB = "load_physical_MB"
REPORT_LOAD_VIRTUAL_MEMORY_MB = "load_virtual_MB"
REPORT_RUN_PHYSICAL_MEMORY_MB = "run_physical_MB"
REPORT_RUN_VIRTUAL_MEMORY_MB = "run_virtual_MB"
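# A model backend (SHARK module or Huggingface PyTorch model) paired with its tokenizer.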
ModelWrapper = collections.namedtuple("ModelWrapper", ["model", "tokenizer"])
def get_memory_info():
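    """Returns the current process's memory usage via psutil."""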
pid = os.getpid()
process = psutil.Process(pid)
return process.memory_info()
def import_mlir_module(
model_name: str,
tokenizer,
device: str,
max_seq_len: int,
):
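    """Imports the OPT model through torch FX and writes the MLIR to disk.

    The output file is ./{opt_fs_name}_causallm_{max_seq_len}_torch.mlir.
    """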
opt_base_model = OPTForCausalLM.from_pretrained(
model_name, ignore_mismatched_sizes=True
)
opt_base_model.eval()
opt_model = OPTForCausalLMModel(opt_base_model)
encoded_inputs = tokenizer(
PROMPTS[0],
padding="max_length",
truncation=True,
max_length=max_seq_len,
return_tensors="pt",
)
inputs = (
encoded_inputs["input_ids"],
encoded_inputs["attention_mask"],
)
# np.save("model_inputs_0.npy", inputs[0])
# np.save("model_inputs_1.npy", inputs[1])
opt_fs_name = get_opt_fs_name(model_name)
mlir_path = f"./{opt_fs_name}_causallm_{max_seq_len}_torch.mlir"
(model_mlir, func_name) = import_with_fx(
model=opt_model,
inputs=inputs,
is_f16=False,
model_name=opt_fs_name,
return_str=True,
)
with open(mlir_path, "w") as f:
f.write(model_mlir)
print(f"Saved mlir at {mlir_path}")
def create_vmfb_module(
model_name: str,
tokenizer,
device: str,
max_seq_len: int,
recompile_shark: bool,
):
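    """Compiles the saved MLIR into a vmfb module with SharkInference.

    Imports the MLIR first if it is not already on disk, then returns the
    path of the saved vmfb.
    """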
opt_fs_name = get_opt_fs_name(model_name)
mlir_path = f"./{opt_fs_name}_causallm_{max_seq_len}_torch.mlir"
    # If the MLIR file already exists and recompilation is not requested,
    # reuse the existing file.
    has_mlir = os.path.isfile(mlir_path)
    # recompile_shark exists to measure compilation time; that measurement is
    # only meaningful when the MLIR file has already been generated.
    assert not recompile_shark or has_mlir
if not has_mlir:
import_mlir_module(
model_name,
tokenizer,
device,
max_seq_len,
)
shark_module = SharkInference(
mlir_path,
device=device,
mlir_dialect="tm_tensor",
is_benchmark=False,
rt_flags=[],
)
vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}"
shark_module.save_module(module_name=vmfb_name)
vmfb_path = vmfb_name + ".vmfb"
return vmfb_path
def load_shark_model(
model_name: str,
token_model_name: str,
max_seq_len: int,
recompile_shark: bool,
    plugin_path: Optional[str] = None,
) -> ModelWrapper:
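    """Loads the SHARK vmfb module and tokenizer, compiling the vmfb first if
    it is missing or recompilation is requested."""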
opt_fs_name = get_opt_fs_name(model_name)
vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}.vmfb"
tokenizer = AutoTokenizer.from_pretrained(token_model_name, use_fast=False)
if recompile_shark or not os.path.isfile(vmfb_name):
print(f"vmfb not found. compiling and saving to {vmfb_name}")
create_vmfb_module(
model_name, tokenizer, DEVICE, max_seq_len, recompile_shark
)
if plugin_path is not None:
rt_flags = [f"--executable_plugin={plugin_path}"]
else:
rt_flags = []
shark_module = SharkInference(
mlir_module=None, device="cpu-task", rt_flags=rt_flags
)
shark_module.load_module(vmfb_name)
return ModelWrapper(model=shark_module, tokenizer=tokenizer)
def run_shark_model(model_wrapper: ModelWrapper, tokens):
# Generate logits output of OPT model.
return model_wrapper.model("forward", tokens)
def load_huggingface_model(
model_name: str, token_model_name: str
) -> ModelWrapper:
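    """Loads the Huggingface PyTorch OPT model and its tokenizer."""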
return ModelWrapper(
model=OPTForCausalLM.from_pretrained(model_name),
tokenizer=AutoTokenizer.from_pretrained(token_model_name),
)
def run_huggingface_model(model_wrapper: ModelWrapper, tokens):
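    """Runs a forward pass and returns the raw output tuple (logits first)."""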
return model_wrapper.model.forward(
tokens.input_ids, tokens.attention_mask, return_dict=False
)
def save_json(data, filename):
with open(filename, "w") as file:
json.dump(data, file)
def collect_huggingface_logits(
model_name: str,
token_model_name: str,
max_seq_len: int,
to_save_json: bool,
) -> dict:
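    """Benchmarks the Huggingface model over PROMPTS and returns a report dict.

    Measures load time, average per-prompt run time, and memory usage;
    optionally saves logits to /tmp/huggingface.json.
    """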
# Load
t0 = time.time()
model_wrapper = load_huggingface_model(model_name, token_model_name)
load_time = time.time() - t0
print("--- Took {} seconds to load Huggingface.".format(load_time))
load_memory_info = get_memory_info()
results = []
tokenized_prompts = []
for prompt in PROMPTS:
tokens = model_wrapper.tokenizer(
prompt,
padding="max_length",
max_length=max_seq_len,
truncation=True,
return_tensors="pt",
)
tokenized_prompts.append(tokens)
# Run
t0 = time.time()
for idx, tokens in enumerate(tokenized_prompts):
print("prompt: {}".format(PROMPTS[idx]))
logits = run_huggingface_model(model_wrapper, tokens)
if to_save_json:
results.append([PROMPTS[idx], logits[0].tolist()])
run_time = time.time() - t0
print("--- Took {} seconds to run Huggingface.".format(run_time))
if to_save_json:
save_json(results, "/tmp/huggingface.json")
run_memory_info = get_memory_info()
return {
REPORT_PLATFORM: PLATFORM_HUGGINGFACE,
REPORT_MODEL_NAME: model_name,
REPORT_MAX_SEQ_LEN: max_seq_len,
REPORT_LOAD_TIME: load_time,
REPORT_RUN_TIME: run_time / len(PROMPTS),
REPORT_LOAD_PHYSICAL_MEMORY_MB: load_memory_info.rss >> 20,
REPORT_LOAD_VIRTUAL_MEMORY_MB: load_memory_info.vms >> 20,
REPORT_RUN_PHYSICAL_MEMORY_MB: run_memory_info.rss >> 20,
REPORT_RUN_VIRTUAL_MEMORY_MB: run_memory_info.vms >> 20,
}
def collect_shark_logits(
model_name: str,
token_model_name: str,
max_seq_len: int,
recompile_shark: bool,
to_save_json: bool,
    plugin_path: Optional[str],
) -> dict:
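    """Benchmarks the SHARK module over PROMPTS and returns a report dict.

    Measures load time (including compilation when needed), average
    per-prompt run time, and memory usage; optionally saves logits to
    /tmp/shark.json.
    """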
# Load
t0 = time.time()
model_wrapper = load_shark_model(
model_name, token_model_name, max_seq_len, recompile_shark, plugin_path
)
load_time = time.time() - t0
print("--- Took {} seconds to load Shark.".format(load_time))
load_memory_info = get_memory_info()
results = []
tokenized_prompts = []
for prompt in PROMPTS:
tokens = model_wrapper.tokenizer(
prompt,
padding="max_length",
truncation=True,
max_length=max_seq_len,
return_tensors="pt",
)
inputs = (
tokens["input_ids"],
tokens["attention_mask"],
)
tokenized_prompts.append(inputs)
# Run
t0 = time.time()
for idx, tokens in enumerate(tokenized_prompts):
print("prompt: {}".format(PROMPTS[idx]))
logits = run_shark_model(model_wrapper, tokens)
lst = [e.tolist() for e in logits]
if to_save_json:
results.append([PROMPTS[idx], lst])
run_time = time.time() - t0
print("--- Took {} seconds to run Shark.".format(run_time))
if to_save_json:
save_json(results, "/tmp/shark.json")
platform_postfix = "-compile" if recompile_shark else "-precompiled"
run_memory_info = get_memory_info()
return {
REPORT_PLATFORM: PLATFORM_SHARK + platform_postfix,
REPORT_MODEL_NAME: model_name,
REPORT_MAX_SEQ_LEN: max_seq_len,
REPORT_LOAD_TIME: load_time,
REPORT_RUN_TIME: run_time / len(PROMPTS),
REPORT_LOAD_PHYSICAL_MEMORY_MB: load_memory_info.rss >> 20,
REPORT_LOAD_VIRTUAL_MEMORY_MB: load_memory_info.vms >> 20,
REPORT_RUN_PHYSICAL_MEMORY_MB: run_memory_info.rss >> 20,
REPORT_RUN_VIRTUAL_MEMORY_MB: run_memory_info.vms >> 20,
}
def get_opt_fs_name(model_name: str) -> str:
"""Cleanses the model name ino a file system-friendly name.
Example: get_opt_fs_name('facebook/opt-1.3b') == 'opt_1-3b'
"""
slash_split = model_name.split("/")
assert 1 <= len(slash_split) <= 2, "There should be at most one slash."
model_name = slash_split[-1]
for src_pattern, dest_pattern in (("-", "_"), (".", "-")):
model_name = model_name.replace(src_pattern, dest_pattern)
return model_name
def parse_args():
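    """Parses command line arguments; see the module docstring for usage."""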
parser = argparse.ArgumentParser()
parser.add_argument(
"--save-json",
help="If set, saves output JSON.",
action=argparse.BooleanOptionalAction,
default=False,
)
parser.add_argument(
"--max-seq-len", help="Max sequence length", type=int, default=32
)
parser.add_argument(
"--model-name",
help="Model name",
type=str,
choices=[
"facebook/opt-125m",
"facebook/opt-350m",
"facebook/opt-1.3b",
"facebook/opt-6.7b",
"mit-han-lab/opt-125m-smoothquant",
"mit-han-lab/opt-1.3b-smoothquant",
"mit-han-lab/opt-2.7b-smoothquant",
"mit-han-lab/opt-6.7b-smoothquant",
"mit-han-lab/opt-13b-smoothquant",
],
default="facebook/opt-1.3b",
)
parser.add_argument(
"--recompile-shark",
help="If set, recompiles MLIR",
action=argparse.BooleanOptionalAction,
default=False,
)
parser.add_argument(
"--platform",
help="Either shark or huggingface",
type=str,
choices=[PLATFORM_SHARK, PLATFORM_HUGGINGFACE],
default=PLATFORM_SHARK,
)
parser.add_argument(
"--plugin-path",
help="path to executable plugin",
type=str,
default=None,
)
parser.add_argument(
"--token-model-name",
help="HF ID to create tokenizer.",
type=str,
default=None,
)
args = parser.parse_args()
print("args={}".format(args))
return args
if __name__ == "__main__":
args = parse_args()
    if args.token_model_name is None:
if "smoothquant" in args.model_name:
args.token_model_name = (
f"facebook/opt-{args.model_name.split('-')[3]}"
)
else:
args.token_model_name = args.model_name
if args.platform == PLATFORM_SHARK:
shark_report = collect_shark_logits(
args.model_name,
args.token_model_name,
args.max_seq_len,
args.recompile_shark,
args.save_json,
args.plugin_path,
)
print("# Summary: {}".format(json.dumps(shark_report)))
else:
huggingface_report = collect_huggingface_logits(
args.model_name,
args.token_model_name,
args.max_seq_len,
args.save_json,
)
print("# Summary: {}".format(json.dumps(huggingface_report)))