Files
AMD-SHARK-Studio/tank/examples/opt/shark_hf_base_opt.py
Ean Garvey caf6cc5d8f Switch most compile flows to use ireec.compile_file. (#1863)
* Switch most compile flows to use ireec.compile_file.

* re-add input type to compile_str path.

* Check if mlir_module exists before checking if it's a path or pyobject.

* Fix some save_dir cases
2023-10-06 23:04:43 -05:00

50 lines
1.2 KiB
Python

import os
import torch
from transformers import AutoTokenizer, OPTForCausalLM
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx, save_mlir
from shark_opt_wrapper import OPTForCausalLMModel
# End-to-end smoke test: trace OPT-1.3b through torch-fx, lower it to a
# linalg MLIR module, run it with SHARK on CPU, and print the resulting
# logits side by side with the plain HuggingFace baseline.
MODEL_ID = "facebook/opt-1.3b"
hf_model = OPTForCausalLM.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
wrapped_model = OPTForCausalLMModel(hf_model)

prompt = "What is the meaning of life?"
encoded = tokenizer(prompt, return_tensors="pt")
# The wrapper's forward takes (input_ids, attention_mask) positionally.
example_inputs = (
    encoded["input_ids"],
    encoded["attention_mask"],
)

# Trace the wrapped model into an MLIR module (fp32 path).
mlir_module, func_name = import_with_fx(
    model=wrapped_model,
    inputs=example_inputs,
    is_f16=False,
)

# Persist the imported module to disk; save_mlir hands back the module
# (possibly as a file-backed reference) for the compile step below.
mlir_module = save_mlir(
    mlir_module,
    model_name=MODEL_ID.split("/")[1],
    frontend="torch",
    mlir_dialect="linalg",
)

# Compile and load the module with SHARK on a synchronous CPU device.
shark_module = SharkInference(
    mlir_module,
    device="cpu-sync",
    mlir_dialect="tm_tensor",
)
shark_module.compile()

# Generated logits.
logits = shark_module("forward", inputs=example_inputs)
print("SHARK module returns logits:")
print(logits[0])

# Reference logits straight from the unmodified HuggingFace model.
hf_logits = hf_model.forward(
    example_inputs[0], example_inputs[1], return_dict=False
)[0]
print("PyTorch baseline returns logits:")
print(hf_logits)