Add smoothquant OPT to examples. (#1922)

Ean Garvey
2023-10-27 12:32:12 -05:00
committed by GitHub
parent 679a452139
commit 98244232dd
2 changed files with 52 additions and 10 deletions

File 1 of 2:

@@ -10,7 +10,9 @@ from typing import Iterable
 
 
 def create_module(model_name, tokenizer, device, args):
-    opt_base_model = OPTForCausalLM.from_pretrained(model_name)
+    opt_base_model = OPTForCausalLM.from_pretrained(
+        model_name, ignore_mismatched_sizes=True
+    )
     opt_base_model.eval()
     opt_model = OPTForCausalLMModel(opt_base_model)
     encoded_inputs = tokenizer(
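For orientation, a minimal standalone sketch (not part of the diff) of the loading pattern this hunk introduces. The repo ID is taken from the choices list added below; the SmoothQuant checkpoints apparently contain tensors whose shapes differ from the stock fp32 OPT architecture, and ignore_mismatched_sizes=True lets from_pretrained proceed instead of raising a size-mismatch error:

# Sketch only: assumes the transformers package and network access to the
# mit-han-lab checkpoint on the Hugging Face Hub.
from transformers import OPTForCausalLM

opt_base_model = OPTForCausalLM.from_pretrained(
    "mit-han-lab/opt-125m-smoothquant",
    ignore_mismatched_sizes=True,  # tolerate checkpoint/architecture shape gaps
)
opt_base_model.eval()  # inference mode (disables dropout)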
@@ -112,6 +114,11 @@ def parse_args():
             "facebook/opt-350m",
             "facebook/opt-1.3b",
             "facebook/opt-6.7b",
+            "mit-han-lab/opt-125m-smoothquant",
+            "mit-han-lab/opt-1.3b-smoothquant",
+            "mit-han-lab/opt-2.7b-smoothquant",
+            "mit-han-lab/opt-6.7b-smoothquant",
+            "mit-han-lab/opt-13b-smoothquant",
         ],
         default="facebook/opt-1.3b",
     )
@@ -162,7 +169,11 @@ def generate_tokens(
 
 if __name__ == "__main__":
     args = parse_args()
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False)
+    if "smoothquant" in args.model_name:
+        token_model_name = f"facebook/opt-{args.model_name.split('-')[3]}"
+    else:
+        token_model_name = args.model_name
+    tokenizer = AutoTokenizer.from_pretrained(token_model_name, use_fast=False)
     opt_fs_name = "-".join(
         "_".join(args.model_name.split("/")[1].split("-")).split(".")
     )
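The index-3 trick above is worth unpacking. A worked example (illustrative, not part of the diff) of how the tokenizer repo is derived, since the quantized mit-han-lab repos reuse the tokenizer of the matching facebook/opt model:

# "mit-han-lab/opt-1.3b-smoothquant".split("-") yields
# ["mit", "han", "lab/opt", "1.3b", "smoothquant"], so index 3 is the size.
model_name = "mit-han-lab/opt-1.3b-smoothquant"
token_model_name = f"facebook/opt-{model_name.split('-')[3]}"
print(token_model_name)  # -> facebook/opt-1.3b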

File 2 of 2:

@@ -60,7 +60,9 @@ def import_mlir_module(
     device: str,
     max_seq_len: int,
 ):
-    opt_base_model = OPTForCausalLM.from_pretrained(model_name)
+    opt_base_model = OPTForCausalLM.from_pretrained(
+        model_name, ignore_mismatched_sizes=True
+    )
     opt_base_model.eval()
     opt_model = OPTForCausalLMModel(opt_base_model)
     encoded_inputs = tokenizer(
@@ -130,13 +132,14 @@ def create_vmfb_module(
 
 def load_shark_model(
     model_name: str,
+    token_model_name: str,
     max_seq_len: int,
     recompile_shark: bool,
     plugin_path: str = [],
 ) -> ModelWrapper:
     opt_fs_name = get_opt_fs_name(model_name)
     vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}.vmfb"
-    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
+    tokenizer = AutoTokenizer.from_pretrained(token_model_name, use_fast=False)
     if recompile_shark or not os.path.isfile(vmfb_name):
         print(f"vmfb not found. compiling and saving to {vmfb_name}")
         create_vmfb_module(
@@ -158,10 +161,12 @@ def run_shark_model(model_wrapper: ModelWrapper, tokens):
     return model_wrapper.model("forward", tokens)
 
 
-def load_huggingface_model(model_name: str) -> ModelWrapper:
+def load_huggingface_model(
+    model_name: str, token_model_name: str
+) -> ModelWrapper:
     return ModelWrapper(
         model=OPTForCausalLM.from_pretrained(model_name),
-        tokenizer=AutoTokenizer.from_pretrained(model_name),
+        tokenizer=AutoTokenizer.from_pretrained(token_model_name),
     )
 
 
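A hypothetical call to the updated helper (values illustrative; the function is the one defined just above), showing that weights and tokenizer now come from different repos:

wrapper = load_huggingface_model(
    model_name="mit-han-lab/opt-1.3b-smoothquant",  # quantized weights
    token_model_name="facebook/opt-1.3b",           # tokenizer source
)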
@@ -177,11 +182,14 @@ def save_json(data, filename):
 
 
 def collect_huggingface_logits(
-    model_name: str, max_seq_len: int, to_save_json: bool
+    model_name: str,
+    token_model_name: str,
+    max_seq_len: int,
+    to_save_json: bool,
 ) -> Tuple[float, float]:
     # Load
     t0 = time.time()
-    model_wrapper = load_huggingface_model(model_name)
+    model_wrapper = load_huggingface_model(model_name, token_model_name)
     load_time = time.time() - t0
     print("--- Took {} seconds to load Huggingface.".format(load_time))
     load_memory_info = get_memory_info()
@@ -225,6 +233,7 @@ def collect_huggingface_logits(
 
 def collect_shark_logits(
     model_name: str,
+    token_model_name: str,
     max_seq_len: int,
     recompile_shark: bool,
     to_save_json: bool,
@@ -233,7 +242,7 @@ def collect_shark_logits(
     # Load
     t0 = time.time()
     model_wrapper = load_shark_model(
-        model_name, max_seq_len, recompile_shark, plugin_path
+        model_name, token_model_name, max_seq_len, recompile_shark, plugin_path
     )
     load_time = time.time() - t0
     print("--- Took {} seconds to load Shark.".format(load_time))
@@ -315,6 +324,11 @@ def parse_args():
             "facebook/opt-350m",
             "facebook/opt-1.3b",
             "facebook/opt-6.7b",
+            "mit-han-lab/opt-125m-smoothquant",
+            "mit-han-lab/opt-1.3b-smoothquant",
+            "mit-han-lab/opt-2.7b-smoothquant",
+            "mit-han-lab/opt-6.7b-smoothquant",
+            "mit-han-lab/opt-13b-smoothquant",
         ],
         default="facebook/opt-1.3b",
     )
@@ -337,6 +351,12 @@ def parse_args():
         type=str,
         default=None,
     )
+    parser.add_argument(
+        "--token-model-name",
+        help="HF ID to create tokenizer.",
+        type=str,
+        default=None,
+    )
     args = parser.parse_args()
     print("args={}".format(args))
     return args
@@ -344,9 +364,17 @@ def parse_args():
 
 if __name__ == "__main__":
     args = parse_args()
+    if args.token_model_name is None:
+        if "smoothquant" in args.model_name:
+            args.token_model_name = (
+                f"facebook/opt-{args.model_name.split('-')[3]}"
+            )
+        else:
+            args.token_model_name = args.model_name
     if args.platform == PLATFORM_SHARK:
         shark_report = collect_shark_logits(
             args.model_name,
+            args.token_model_name,
             args.max_seq_len,
             args.recompile_shark,
             args.save_json,
@@ -355,6 +383,9 @@ if __name__ == "__main__":
         print("# Summary: {}".format(json.dumps(shark_report)))
     else:
         huggingface_report = collect_huggingface_logits(
-            args.model_name, args.max_seq_len, args.save_json
+            args.model_name,
+            args.token_model_name,
+            args.max_seq_len,
+            args.save_json,
         )
         print("# Summary: {}".format(json.dumps(huggingface_report)))