# Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git (synced 2026-04-03).
# Upstream changelog for this revision:
#   * Switch most compile flows to use ireec.compile_file.
#   * Re-add input type to compile_str path.
#   * Check if mlir_module exists before checking if it's a path or pyobject.
#   * Fix some save_dir cases.
import unittest
|
|
import os
|
|
import pytest
|
|
import torch
|
|
import numpy as np
|
|
from shark_opt_wrapper import OPTForCausalLMModel
|
|
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
|
from shark.shark_inference import SharkInference
|
|
from shark.shark_importer import import_with_fx, save_mlir
|
|
from transformers import AutoTokenizer, OPTForCausalLM
|
|
|
|
# HuggingFace hub id of the model exercised by the tests below.
OPT_MODEL = "facebook/opt-1.3b"

# Filesystem-safe variant of the model name, used in saved-artifact filenames.
OPT_FS_NAME = "opt-1_3b"

# Larger 66B variant; its tests are currently disabled (see note in
# OPTModuleTest) due to resource requirements.
OPT_MODEL_66B = "facebook/opt-66b"
|
class OPTModuleTester:
    """End-to-end checker for OPT causal-LM inference through SHARK.

    Builds the HuggingFace model, traces it to MLIR via torch FX
    (``import_with_fx``), compiles it with ``SharkInference``, and compares
    the SHARK logits against eager PyTorch logits.

    Attributes:
        benchmark: when True, run ``benchmark_all_csv`` after the
            correctness check.
    """

    def __init__(
        self,
        benchmark=False,
    ):
        self.benchmark = benchmark

    def create_and_check_module(self, dynamic, device, model_name):
        """Run one import/compile/inference check on ``device``.

        Args:
            dynamic: forwarded to the benchmark CSV report only — the trace
                itself always uses fixed-length inputs.
            device: SHARK/IREE device string ("cpu", "cuda", "vulkan", ...).
            model_name: HuggingFace hub id, e.g. ``"facebook/opt-1.3b"``.

        Raises:
            AssertionError: if SHARK and PyTorch logits diverge beyond
                ``np.testing.assert_allclose``'s default tolerance.
        """
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
        opt_model = OPTForCausalLM.from_pretrained(
            model_name, return_dict=False
        )
        opt_model.eval()

        # Pad/truncate to exactly 30 tokens so the traced graph has static
        # shapes that match the saved-artifact name below.
        model_inputs = tokenizer(
            "The meaning of life is",
            padding="max_length",
            max_length=30,
            truncation=True,
            return_tensors="pt",
        )
        inputs = (
            model_inputs.data["input_ids"],
            model_inputs.data["attention_mask"],
        )
        # Reference logits from eager PyTorch, computed before the model is
        # freed below.
        act_out = opt_model(
            inputs[0], attention_mask=inputs[1], return_dict=False
        )[0]
        (
            mlir_module,
            func_name,
        ) = import_with_fx(
            model=opt_model,
            inputs=inputs,
            is_f16=False,
            model_name=OPT_FS_NAME,
        )
        # Release the eager model before compilation to reduce peak memory.
        del opt_model
        opt_filename = f"./{OPT_FS_NAME}_causallm_30_torch_{device}"
        # BUG FIX: the previous os.path.join(opt_filename, ".mlir") produced
        # "./opt-1_3b_..._{device}/.mlir" — a hidden file inside a directory
        # that was never created, so open() below raised FileNotFoundError.
        # The intent is to append an extension to the filename.
        mlir_path = opt_filename + ".mlir"
        # NOTE(review): text mode assumes import_with_fx returns str here
        # (is_f16=False path) — confirm it never returns bytes.
        with open(mlir_path, "w") as f:
            f.write(mlir_module)
        print(f"Saved mlir at {mlir_path}")
        # Drop the (potentially large) module text before compiling.
        del mlir_module

        shark_module = SharkInference(
            mlir_path,
            device=device,
            mlir_dialect="tm_tensor",
            is_benchmark=self.benchmark,
        )

        shark_module.compile()
        results = shark_module("forward", inputs)
        print(
            "SHARK logits have shape: ",
            str(results[0].shape) + " : " + str(results[0]),
        )
        print(
            "PyTorch logits have shape: "
            + str(act_out[0].shape)
            + " : "
            + str(act_out[0])
        )
        np.testing.assert_allclose(act_out[0].detach(), results[0])

        if self.benchmark:
            shark_module.shark_runner.benchmark_all_csv(
                inputs,
                "opt",
                dynamic,
                device,
                "torch",
            )
|
class OPTModuleTest(unittest.TestCase):
    """Device-parameterized smoke tests for the 1.3B OPT model.

    Each test runs OPTModuleTester.create_and_check_module on one device;
    CUDA/Vulkan variants are skipped when the driver is unavailable.

    NOTE: 66B variants (using OPT_MODEL_66B) were previously present as
    commented-out copies of these tests; re-add them when hardware with
    sufficient memory is available.
    """

    @pytest.fixture(autouse=True)
    def configure(self, pytestconfig):
        # BUG FIX: this was OPTModuleTester(self), which passed the TestCase
        # instance as the `benchmark` flag. The real flag is read from
        # pytest's --benchmark option below, so construct with defaults.
        self.module_tester = OPTModuleTester()
        self.module_tester.save_mlir = False
        self.module_tester.save_vmfb = False
        self.module_tester.benchmark = pytestconfig.getoption("benchmark")

    def test_1_3b_static_cpu(self):
        dynamic = False
        device = "cpu"
        self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)

    def test_1_3b_dynamic_cpu(self):
        dynamic = True
        device = "cpu"
        self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)

    @pytest.mark.skipif(
        check_device_drivers("cuda"), reason=device_driver_info("cuda")
    )
    def test_1_3b_static_cuda(self):
        dynamic = False
        device = "cuda"
        self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)

    @pytest.mark.skipif(
        check_device_drivers("cuda"), reason=device_driver_info("cuda")
    )
    def test_1_3b_dynamic_cuda(self):
        dynamic = True
        device = "cuda"
        self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)

    @pytest.mark.skipif(
        check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
    )
    def test_1_3b_static_vulkan(self):
        dynamic = False
        device = "vulkan"
        self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)

    @pytest.mark.skipif(
        check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
    )
    def test_1_3b_dynamic_vulkan(self):
        dynamic = True
        device = "vulkan"
        self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)
|
# Allow running this file directly (python this_file.py) without pytest;
# note the pytest fixture `configure` will not run in that mode.
if __name__ == "__main__":
    unittest.main()