diff --git a/tank/examples/opt/opt_causallm.py b/tank/examples/opt/opt_causallm.py
index 7983c9ef..c84fa5a9 100644
--- a/tank/examples/opt/opt_causallm.py
+++ b/tank/examples/opt/opt_causallm.py
@@ -10,7 +10,9 @@ from typing import Iterable
 
 
 def create_module(model_name, tokenizer, device, args):
-    opt_base_model = OPTForCausalLM.from_pretrained(model_name)
+    opt_base_model = OPTForCausalLM.from_pretrained(
+        model_name, ignore_mismatched_sizes=True
+    )
     opt_base_model.eval()
     opt_model = OPTForCausalLMModel(opt_base_model)
     encoded_inputs = tokenizer(
@@ -112,6 +114,11 @@ def parse_args():
             "facebook/opt-350m",
             "facebook/opt-1.3b",
             "facebook/opt-6.7b",
+            "mit-han-lab/opt-125m-smoothquant",
+            "mit-han-lab/opt-1.3b-smoothquant",
+            "mit-han-lab/opt-2.7b-smoothquant",
+            "mit-han-lab/opt-6.7b-smoothquant",
+            "mit-han-lab/opt-13b-smoothquant",
         ],
         default="facebook/opt-1.3b",
     )
@@ -162,7 +169,11 @@ def generate_tokens(
 
 if __name__ == "__main__":
     args = parse_args()
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False)
+    if "smoothquant" in args.model_name:
+        token_model_name = f"facebook/opt-{args.model_name.split('-')[3]}"
+    else:
+        token_model_name = args.model_name
+    tokenizer = AutoTokenizer.from_pretrained(token_model_name, use_fast=False)
     opt_fs_name = "-".join(
         "_".join(args.model_name.split("/")[1].split("-")).split(".")
     )
diff --git a/tank/examples/opt/opt_perf_comparison.py b/tank/examples/opt/opt_perf_comparison.py
index 595a5608..21a1ee97 100644
--- a/tank/examples/opt/opt_perf_comparison.py
+++ b/tank/examples/opt/opt_perf_comparison.py
@@ -60,7 +60,9 @@ def import_mlir_module(
     device: str,
     max_seq_len: int,
 ):
-    opt_base_model = OPTForCausalLM.from_pretrained(model_name)
+    opt_base_model = OPTForCausalLM.from_pretrained(
+        model_name, ignore_mismatched_sizes=True
+    )
     opt_base_model.eval()
     opt_model = OPTForCausalLMModel(opt_base_model)
     encoded_inputs = tokenizer(
@@ -130,13 +132,14 @@ def create_vmfb_module(
 
 def load_shark_model(
     model_name: str,
+    token_model_name: str,
     max_seq_len: int,
     recompile_shark: bool,
     plugin_path: str = [],
 ) -> ModelWrapper:
     opt_fs_name = get_opt_fs_name(model_name)
     vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}.vmfb"
-    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
+    tokenizer = AutoTokenizer.from_pretrained(token_model_name, use_fast=False)
     if recompile_shark or not os.path.isfile(vmfb_name):
         print(f"vmfb not found. compiling and saving to {vmfb_name}")
         create_vmfb_module(
@@ -158,10 +161,12 @@ def run_shark_model(model_wrapper: ModelWrapper, tokens):
     return model_wrapper.model("forward", tokens)
 
 
-def load_huggingface_model(model_name: str) -> ModelWrapper:
+def load_huggingface_model(
+    model_name: str, token_model_name: str
+) -> ModelWrapper:
     return ModelWrapper(
         model=OPTForCausalLM.from_pretrained(model_name),
-        tokenizer=AutoTokenizer.from_pretrained(model_name),
+        tokenizer=AutoTokenizer.from_pretrained(token_model_name),
     )
 
 
@@ -177,11 +182,14 @@ def save_json(data, filename):
 
 
 def collect_huggingface_logits(
-    model_name: str, max_seq_len: int, to_save_json: bool
+    model_name: str,
+    token_model_name: str,
+    max_seq_len: int,
+    to_save_json: bool,
 ) -> Tuple[float, float]:
     # Load
     t0 = time.time()
-    model_wrapper = load_huggingface_model(model_name)
+    model_wrapper = load_huggingface_model(model_name, token_model_name)
     load_time = time.time() - t0
     print("--- Took {} seconds to load Huggingface.".format(load_time))
     load_memory_info = get_memory_info()
@@ -225,6 +233,7 @@ def collect_huggingface_logits(
 
 def collect_shark_logits(
     model_name: str,
+    token_model_name: str,
     max_seq_len: int,
     recompile_shark: bool,
     to_save_json: bool,
@@ -233,7 +242,7 @@ def collect_shark_logits(
     # Load
     t0 = time.time()
     model_wrapper = load_shark_model(
-        model_name, max_seq_len, recompile_shark, plugin_path
+        model_name, token_model_name, max_seq_len, recompile_shark, plugin_path
     )
     load_time = time.time() - t0
     print("--- Took {} seconds to load Shark.".format(load_time))
@@ -315,6 +324,11 @@ def parse_args():
             "facebook/opt-350m",
             "facebook/opt-1.3b",
             "facebook/opt-6.7b",
+            "mit-han-lab/opt-125m-smoothquant",
+            "mit-han-lab/opt-1.3b-smoothquant",
+            "mit-han-lab/opt-2.7b-smoothquant",
+            "mit-han-lab/opt-6.7b-smoothquant",
+            "mit-han-lab/opt-13b-smoothquant",
         ],
         default="facebook/opt-1.3b",
     )
@@ -337,6 +351,12 @@
         type=str,
         default=None,
     )
+    parser.add_argument(
+        "--token-model-name",
+        help="Hugging Face model ID used to create the tokenizer.",
+        type=str,
+        default=None,
+    )
     args = parser.parse_args()
     print("args={}".format(args))
     return args
@@ -344,9 +364,17 @@
 
 if __name__ == "__main__":
     args = parse_args()
+    if args.token_model_name is None:
+        if "smoothquant" in args.model_name:
+            args.token_model_name = (
+                f"facebook/opt-{args.model_name.split('-')[3]}"
+            )
+        else:
+            args.token_model_name = args.model_name
     if args.platform == PLATFORM_SHARK:
         shark_report = collect_shark_logits(
             args.model_name,
+            args.token_model_name,
             args.max_seq_len,
             args.recompile_shark,
             args.save_json,
@@ -355,6 +383,9 @@ if __name__ == "__main__":
         )
         print("# Summary: {}".format(json.dumps(shark_report)))
     else:
         huggingface_report = collect_huggingface_logits(
-            args.model_name, args.max_seq_len, args.save_json
+            args.model_name,
+            args.token_model_name,
+            args.max_seq_len,
+            args.save_json,
         )
         print("# Summary: {}".format(json.dumps(huggingface_report)))
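
Note: the "smoothquant" branches above point the tokenizer at the corresponding facebook/opt-* repo, presumably because the mit-han-lab SmoothQuant checkpoints do not ship their own tokenizer files. A minimal standalone sketch of that mapping (the helper name is hypothetical and not part of the patch; it assumes the mit-han-lab/opt-<size>-smoothquant naming used in the choices lists):

    def derive_token_model_name(model_name: str) -> str:
        # "mit-han-lab/opt-1.3b-smoothquant".split("-") ->
        # ["mit", "han", "lab/opt", "1.3b", "smoothquant"], so index 3 is the size.
        if "smoothquant" in model_name:
            return f"facebook/opt-{model_name.split('-')[3]}"
        return model_name

    assert derive_token_model_name("mit-han-lab/opt-1.3b-smoothquant") == "facebook/opt-1.3b"
    assert derive_token_model_name("facebook/opt-1.3b") == "facebook/opt-1.3b"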