mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
142 lines
4.2 KiB
Python
142 lines
4.2 KiB
Python
import numpy as np
|
|
from shark.shark_downloader import download_tflite_model
|
|
from shark.shark_inference import SharkInference
|
|
import pytest
|
|
import unittest
|
|
from shark.parser import shark_args
|
|
|
|
|
|
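# Reference TFLite artifact for "mobilebert-edgetpu-s-float" (the test loads the
# model by that name via download_tflite_model):
# https://storage.googleapis.com/iree-model-artifacts/mobilebert-edgetpu-s-float.tflite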
|
|
|
|
|
|
def generate_inputs(input_details):
    """Build three input arrays from the first three entries of ``input_details``.

    Every entry of ``input_details`` is printed (shape and dtype name). Then three
    NumPy arrays are returned, in this order:

    * from ``input_details[0]``: random integers drawn from ``[0, 256)``;
    * from ``input_details[1]``: an array of ones;
    * from ``input_details[2]``: an array of zeros.

    Args:
        input_details: Sequence of dicts, each with a ``"shape"`` (anything NumPy
            accepts as an array shape) and a ``"dtype"`` (a NumPy scalar type such
            as ``np.int32``, which must expose ``__name__``). At least three
            entries are required; entries beyond the third are only printed.

    Returns:
        A list of the three NumPy arrays described above.

    Raises:
        IndexError: If ``input_details`` has fewer than three entries.
    """
    # Use a name that does not shadow the ``input`` builtin.
    for detail in input_details:
        print(str(detail["shape"]), detail["dtype"].__name__)

    # List elements are evaluated left to right, so the random draw happens
    # before the ones/zeros arrays are allocated.
    return [
        np.random.randint(
            low=0,
            high=256,
            size=input_details[0]["shape"],
            dtype=input_details[0]["dtype"],
        ),
        np.ones(shape=input_details[1]["shape"], dtype=input_details[1]["dtype"]),
        np.zeros(shape=input_details[2]["shape"], dtype=input_details[2]["dtype"]),
    ]
|
|
|
|
|
|
def compare_results(mlir_results, tflite_results):
    """Print a per-output comparison of MLIR results against TFLite results.

    Each pair of corresponding results is converted to float32 (``np.single``).
    Both shapes are printed, then the maximum absolute element-wise difference.
    No numeric tolerance is enforced: only count and shape mismatches fail.

    Args:
        mlir_results: Sequence of arrays (anything with ``.astype``) produced by
            the compiled module.
        tflite_results: Sequence of reference arrays, in the same order.

    Raises:
        AssertionError: If the number of results differs, or if any pair of
            results has different shapes.
    """
    print("Compare mlir_results VS tflite_results: ")
    assert len(mlir_results) == len(
        tflite_results
    ), "Number of results do not match"

    pairs = zip(mlir_results, tflite_results)
    for i, (mlir_result, tflite_result) in enumerate(pairs):
        mlir_result = mlir_result.astype(np.single)
        tflite_result = tflite_result.astype(np.single)
        print("mlir_result.shape", mlir_result.shape)
        print("tflite_result.shape", tflite_result.shape)
        assert mlir_result.shape == tflite_result.shape, "shape doesnot match"

        abs_diff = np.abs(mlir_result - tflite_result)
        # np.max raises ValueError on an empty array; an empty output pair has
        # nothing to disagree on, so report zero error.
        max_error = np.max(abs_diff) if abs_diff.size else 0.0
        # print() does not apply %-formatting to extra positional arguments,
        # so format the message explicitly.
        print("Max error (%d): %f" % (i, max_error))
|
|
|
|
|
|
class MobilebertTfliteModuleTester:
    """Compile the MobileBERT TFLite model with SHARK and report output diffs.

    Attributes:
        dynamic: Stored flag; ``create_and_check_module`` does not read it.
        device: Device string forwarded to ``SharkInference``.
        save_mlir: Copied into ``shark_args.save_mlir`` before compiling.
        save_vmfb: Copied into ``shark_args.save_vmfb`` before compiling.
    """

    def __init__(
        self,
        dynamic=False,
        device="cpu",
        save_mlir=False,
        save_vmfb=False,
    ):
        self.dynamic, self.device = dynamic, device
        self.save_mlir, self.save_vmfb = save_mlir, save_vmfb

    def _compiled_module(self, mlir_model, func_name):
        """Create a tflite-dialect ``SharkInference`` on ``self.device`` and compile it."""
        module = SharkInference(
            mlir_module=mlir_model,
            function_name=func_name,
            device=self.device,
            mlir_dialect="tflite",
        )
        module.compile()
        return module

    def create_and_check_module(self):
        """Run the model on two input sets and print comparisons via ``compare_results``.

        Case 1 feeds the inputs returned by ``download_tflite_model``.
        Case 2 feeds inputs built by ``generate_inputs``.
        """
        # Forward the artifact-saving flags to the shared SHARK options.
        shark_args.save_mlir = self.save_mlir
        shark_args.save_vmfb = self.save_vmfb

        # Fetch the model, its entry-point name, default inputs and the
        # TFLite reference results.
        mlir_model, func_name, default_inputs, tflite_results = download_tflite_model(
            model_name="mobilebert-edgetpu-s-float"
        )

        # Case 1: inputs supplied by the downloader.
        shark_module = self._compiled_module(mlir_model, func_name)
        mlir_results = shark_module.forward(default_inputs)
        compare_results(mlir_results, tflite_results)

        # Case 2: three int32 inputs of shape [1, 384] from generate_inputs.
        # NOTE(review): these inputs differ from Case 1's, yet the comparison
        # still uses Case 1's tflite_results. The printed max errors are
        # therefore not a like-for-like check (compare_results only asserts
        # result count and shapes).
        input_details = [{"shape": [1, 384], "dtype": np.int32} for _ in range(3)]
        custom_inputs = generate_inputs(input_details)

        shark_module = self._compiled_module(mlir_model, func_name)
        mlir_results = shark_module.forward(custom_inputs)
        compare_results(mlir_results, tflite_results)
|
|
|
|
|
|
class MobilebertTfliteModuleTest(unittest.TestCase):
    """Test entry point for the MobileBERT TFLite SHARK check (pytest or unittest)."""

    @pytest.fixture(autouse=True)
    def configure(self, pytestconfig):
        """Read the ``save_mlir`` / ``save_vmfb`` command-line options from pytest."""
        self.save_mlir = pytestconfig.getoption("save_mlir")
        self.save_vmfb = pytestconfig.getoption("save_vmfb")

    def setUp(self):
        # The pytest fixture above does not run under plain ``unittest.main()``.
        # In that case, fall back to not saving artifacts instead of failing with
        # AttributeError. Pass the flags by keyword: passing ``self``
        # positionally would bind the TestCase to the tester's ``dynamic`` slot.
        self.module_tester = MobilebertTfliteModuleTester(
            save_mlir=getattr(self, "save_mlir", False),
            save_vmfb=getattr(self, "save_vmfb", False),
        )

    def test_module_static_cpu(self):
        self.module_tester.dynamic = False
        self.module_tester.device = "cpu"
        self.module_tester.create_and_check_module()
|
|
|
|
|
|
if __name__ == "__main__":
    # Runs the TestCase suite above. For a one-off manual run, instantiate
    # MobilebertTfliteModuleTester, optionally set save_mlir / save_vmfb to
    # True, and call create_and_check_module().
    unittest.main()
|