Add smoothquant OPT to examples. (#1922)
@@ -10,7 +10,9 @@ from typing import Iterable
 
 
 def create_module(model_name, tokenizer, device, args):
-    opt_base_model = OPTForCausalLM.from_pretrained(model_name)
+    opt_base_model = OPTForCausalLM.from_pretrained(
+        model_name, ignore_mismatched_sizes=True
+    )
     opt_base_model.eval()
     opt_model = OPTForCausalLMModel(opt_base_model)
     encoded_inputs = tokenizer(
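
The new ignore_mismatched_sizes=True argument is what lets the mit-han-lab smoothquant checkpoints load into the stock OPT architecture: presumably not all of their quantized weight tensors match the shapes the config declares, and transformers' strict shape check would otherwise raise. A minimal sketch of the loading pattern, assuming only that transformers is installed and the checkpoint is reachable on the Hugging Face Hub:

    from transformers import OPTForCausalLM

    # ignore_mismatched_sizes=True skips the strict shape check; any
    # mismatched tensors are left newly initialized (with a warning)
    # instead of raising an error during loading.
    model = OPTForCausalLM.from_pretrained(
        "mit-han-lab/opt-125m-smoothquant",  # any smoothquant ID added below
        ignore_mismatched_sizes=True,
    )
    model.eval()  # inference mode, matching the diff
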
@@ -112,6 +114,11 @@ def parse_args():
             "facebook/opt-350m",
             "facebook/opt-1.3b",
             "facebook/opt-6.7b",
+            "mit-han-lab/opt-125m-smoothquant",
+            "mit-han-lab/opt-1.3b-smoothquant",
+            "mit-han-lab/opt-2.7b-smoothquant",
+            "mit-han-lab/opt-6.7b-smoothquant",
+            "mit-han-lab/opt-13b-smoothquant",
         ],
         default="facebook/opt-1.3b",
     )
@@ -162,7 +169,11 @@ def generate_tokens(
 
 if __name__ == "__main__":
     args = parse_args()
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False)
+    if "smoothquant" in args.model_name:
+        token_model_name = f"facebook/opt-{args.model_name.split('-')[3]}"
+    else:
+        token_model_name = args.model_name
+    tokenizer = AutoTokenizer.from_pretrained(token_model_name, use_fast=False)
     opt_fs_name = "-".join(
         "_".join(args.model_name.split("/")[1].split("-")).split(".")
     )
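
The split('-')[3] trick above derives a tokenizer repo from a smoothquant repo ID, since the quantized repos apparently reuse the matching facebook/opt vocabulary. A hedged sketch of the mapping (the helper name is mine, not the repo's):

    def token_model_for(model_name: str) -> str:
        """Map a smoothquant repo ID to the facebook/opt repo it borrows a tokenizer from."""
        if "smoothquant" in model_name:
            # "mit-han-lab/opt-1.3b-smoothquant".split("-") yields
            # ["mit", "han", "lab/opt", "1.3b", "smoothquant"]; index 3 is the size.
            return f"facebook/opt-{model_name.split('-')[3]}"
        return model_name

    assert token_model_for("mit-han-lab/opt-1.3b-smoothquant") == "facebook/opt-1.3b"
    assert token_model_for("facebook/opt-350m") == "facebook/opt-350m"
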
@@ -60,7 +60,9 @@ def import_mlir_module(
     device: str,
     max_seq_len: int,
 ):
-    opt_base_model = OPTForCausalLM.from_pretrained(model_name)
+    opt_base_model = OPTForCausalLM.from_pretrained(
+        model_name, ignore_mismatched_sizes=True
+    )
     opt_base_model.eval()
     opt_model = OPTForCausalLMModel(opt_base_model)
     encoded_inputs = tokenizer(
@@ -130,13 +132,14 @@ def create_vmfb_module(
 
 def load_shark_model(
     model_name: str,
+    token_model_name: str,
     max_seq_len: int,
     recompile_shark: bool,
     plugin_path: str = [],
 ) -> ModelWrapper:
     opt_fs_name = get_opt_fs_name(model_name)
     vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}.vmfb"
-    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
+    tokenizer = AutoTokenizer.from_pretrained(token_model_name, use_fast=False)
     if recompile_shark or not os.path.isfile(vmfb_name):
         print(f"vmfb not found. compiling and saving to {vmfb_name}")
         create_vmfb_module(
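
load_shark_model caches the compiled artifact on disk under a name keyed by model, sequence length, and device, so repeated runs skip compilation unless recompilation is forced. A generic sketch of that compile-or-reuse pattern (get_or_compile_vmfb and compile_fn are illustrative stand-ins, not the repo's API; compile_fn plays the role of create_vmfb_module above):

    import os

    def get_or_compile_vmfb(vmfb_name: str, recompile: bool, compile_fn) -> str:
        # Rebuild only on an explicit request or a cache miss; otherwise the
        # previously compiled module on disk is reused as-is.
        if recompile or not os.path.isfile(vmfb_name):
            print(f"vmfb not found. compiling and saving to {vmfb_name}")
            compile_fn(vmfb_name)
        return vmfb_name
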
@@ -158,10 +161,12 @@ def run_shark_model(model_wrapper: ModelWrapper, tokens):
     return model_wrapper.model("forward", tokens)
 
 
-def load_huggingface_model(model_name: str) -> ModelWrapper:
+def load_huggingface_model(
+    model_name: str, token_model_name: str
+) -> ModelWrapper:
     return ModelWrapper(
         model=OPTForCausalLM.from_pretrained(model_name),
-        tokenizer=AutoTokenizer.from_pretrained(model_name),
+        tokenizer=AutoTokenizer.from_pretrained(token_model_name),
     )
 
 
@@ -177,11 +182,14 @@ def save_json(data, filename):
 
 
 def collect_huggingface_logits(
-    model_name: str, max_seq_len: int, to_save_json: bool
+    model_name: str,
+    token_model_name: str,
+    max_seq_len: int,
+    to_save_json: bool,
 ) -> Tuple[float, float]:
     # Load
     t0 = time.time()
-    model_wrapper = load_huggingface_model(model_name)
+    model_wrapper = load_huggingface_model(model_name, token_model_name)
     load_time = time.time() - t0
     print("--- Took {} seconds to load Huggingface.".format(load_time))
     load_memory_info = get_memory_info()
@@ -225,6 +233,7 @@ def collect_huggingface_logits(
 
 def collect_shark_logits(
     model_name: str,
+    token_model_name: str,
     max_seq_len: int,
     recompile_shark: bool,
     to_save_json: bool,
@@ -233,7 +242,7 @@ def collect_shark_logits(
     # Load
     t0 = time.time()
     model_wrapper = load_shark_model(
-        model_name, max_seq_len, recompile_shark, plugin_path
+        model_name, token_model_name, max_seq_len, recompile_shark, plugin_path
     )
     load_time = time.time() - t0
     print("--- Took {} seconds to load Shark.".format(load_time))
@@ -315,6 +324,11 @@ def parse_args():
             "facebook/opt-350m",
             "facebook/opt-1.3b",
             "facebook/opt-6.7b",
+            "mit-han-lab/opt-125m-smoothquant",
+            "mit-han-lab/opt-1.3b-smoothquant",
+            "mit-han-lab/opt-2.7b-smoothquant",
+            "mit-han-lab/opt-6.7b-smoothquant",
+            "mit-han-lab/opt-13b-smoothquant",
         ],
         default="facebook/opt-1.3b",
     )
@@ -337,6 +351,12 @@ def parse_args():
         type=str,
         default=None,
     )
+    parser.add_argument(
+        "--token-model-name",
+        help="HF ID to create tokenizer.",
+        type=str,
+        default=None,
+    )
     args = parser.parse_args()
     print("args={}".format(args))
     return args
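
With the new flag in place, the tokenizer source can either be derived or overridden. A hypothetical invocation (the script name and the --model-name flag are assumptions; only --token-model-name is defined in this diff); note that an explicit --token-model-name wins, because the __main__ block below only fills it in while it is still None:

    # Derived: the tokenizer falls back to facebook/opt-1.3b automatically.
    python opt_perf_comparison.py --model-name mit-han-lab/opt-1.3b-smoothquant

    # Explicit: the override skips the smoothquant name parsing entirely.
    python opt_perf_comparison.py --model-name mit-han-lab/opt-1.3b-smoothquant \
        --token-model-name facebook/opt-1.3b
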
@@ -344,9 +364,17 @@ def parse_args():
 
 if __name__ == "__main__":
     args = parse_args()
+    if args.token_model_name == None:
+        if "smoothquant" in args.model_name:
+            args.token_model_name = (
+                f"facebook/opt-{args.model_name.split('-')[3]}"
+            )
+        else:
+            args.token_model_name = args.model_name
     if args.platform == PLATFORM_SHARK:
         shark_report = collect_shark_logits(
             args.model_name,
+            args.token_model_name,
             args.max_seq_len,
             args.recompile_shark,
             args.save_json,
@@ -355,6 +383,9 @@ if __name__ == "__main__":
         print("# Summary: {}".format(json.dumps(shark_report)))
     else:
         huggingface_report = collect_huggingface_logits(
-            args.model_name, args.max_seq_len, args.save_json
+            args.model_name,
+            args.token_model_name,
+            args.max_seq_len,
+            args.save_json,
         )
         print("# Summary: {}".format(json.dumps(huggingface_report)))