mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Enable tosa.mlir as input for SharkImporter inference (#145)
* Change shark_importer to use tosa.mlir as tflite model input from local gen_shark_tank
This commit is contained in:
@@ -1,5 +1,15 @@
|
||||
# Lint as: python3
|
||||
"""SHARK Tank"""
|
||||
# python generate_sharktank.py, you have to give a csv file with [model_name, model_download_url]
|
||||
# will generate local shark tank folder like this:
|
||||
# /SHARK
|
||||
# /gen_shark_tank
|
||||
# /tflite
|
||||
# /albert_lite_base
|
||||
# /...model_name...
|
||||
# /tf
|
||||
# /pytorch
|
||||
#
|
||||
|
||||
import os
|
||||
import urllib.request
|
||||
@@ -22,8 +32,12 @@ class SharkTank:
|
||||
|
||||
if self.torch_model_list is not None:
|
||||
print("Process torch model")
|
||||
else:
|
||||
print("Torch sharktank not implemented yet")
|
||||
if self.tf_model_list is not None:
|
||||
print("Process torch model")
|
||||
else:
|
||||
print("tf sharktank not implemented yet")
|
||||
|
||||
print("self.tflite_model_list: ", self.tflite_model_list)
|
||||
# compile and run tfhub tflite
|
||||
@@ -44,7 +58,7 @@ class SharkTank:
|
||||
os.makedirs(tflite_model_name_dir, exist_ok=True)
|
||||
|
||||
tflite_saving_file = '/'.join([tflite_model_name_dir, str(tflite_model_name)+'_tflite.tflite'])
|
||||
iree_ir = '/'.join([tflite_model_name_dir, str(tflite_model_name)+'_tosa.mlir'])
|
||||
tflite_tosa_file = '/'.join([tflite_model_name_dir, str(tflite_model_name)+'_tosa.mlir'])
|
||||
self.binary = '/'.join([tflite_model_name_dir, str(tflite_model_name)+'_module.bytecode'])
|
||||
print("Setting up local address for tflite model file: ", tflite_saving_file)
|
||||
if os.path.exists(tflite_saving_file):
|
||||
@@ -54,14 +68,14 @@ class SharkTank:
|
||||
urllib.request.urlretrieve(str(tflite_model_link),
|
||||
tflite_saving_file)
|
||||
|
||||
if os.path.exists(iree_ir):
|
||||
print(iree_ir, "exists")
|
||||
if os.path.exists(tflite_tosa_file):
|
||||
print("Exists", tflite_tosa_file)
|
||||
else:
|
||||
print("Convert tflite to tosa.mlir")
|
||||
ireec_tflite.compile_file(
|
||||
tflite_saving_file,
|
||||
input_type="tosa",
|
||||
save_temp_iree_input=iree_ir,
|
||||
save_temp_iree_input=tflite_tosa_file,
|
||||
target_backends=[IREE_TARGET_MAP['cpu']],
|
||||
import_only=False)
|
||||
|
||||
|
||||
@@ -166,7 +166,7 @@ def compile_module_to_flatbuffer(module, device, frontend, func_name,
|
||||
input_type = "mhlo"
|
||||
elif frontend in ["mhlo", "tosa"]:
|
||||
input_type = frontend
|
||||
elif frontend in ["tflite"]:
|
||||
elif frontend in ["tflite", "tflite-tosa"]:
|
||||
input_type = "tosa"
|
||||
|
||||
# Annotate the input module with the configs
|
||||
@@ -239,7 +239,7 @@ def export_iree_module_to_vmfb(module,
|
||||
|
||||
def export_module_to_mlir_file(module, frontend, directory: str):
|
||||
mlir_str = module
|
||||
if frontend in ["tensorflow", "tf", "mhlo"]:
|
||||
if frontend in ["tensorflow", "tf", "mhlo", "tflite"]:
|
||||
mlir_str = module.decode('utf-8')
|
||||
elif frontend in ["pytorch", "torch"]:
|
||||
mlir_str = module.operation.get_asm()
|
||||
@@ -255,7 +255,7 @@ def get_results(compiled_vm, input, config, frontend="torch"):
|
||||
device_inputs = input
|
||||
if frontend in ["torch", "pytorch"]:
|
||||
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
|
||||
if frontend in ["tensorflow", "tf", "tflite"]:
|
||||
if frontend in ["tensorflow", "tf", "tflite", "tflite-tosa"]:
|
||||
device_inputs = []
|
||||
for a in input:
|
||||
if (isinstance(a, list)):
|
||||
|
||||
@@ -6,21 +6,29 @@ import iree.runtime as iree_rt
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
import tensorflow.compat.v2 as tf
|
||||
import csv
|
||||
import tensorflow as tf
|
||||
import urllib.request
|
||||
from shark.shark_inference import SharkInference
|
||||
import iree.compiler.tflite as ireec_tflite
|
||||
from shark.iree_utils import IREE_TARGET_MAP
|
||||
|
||||
|
||||
class SharkImporter:
|
||||
|
||||
def __init__(self,
|
||||
model_path,
|
||||
model_name: str=None,
|
||||
model_path: str=None,
|
||||
model_type: str = "tflite",
|
||||
model_source_hub: str = "tfhub",
|
||||
device: str = None,
|
||||
dynamic: bool = False,
|
||||
jit_trace: bool = False,
|
||||
benchmark_mode: bool = False):
|
||||
benchmark_mode: bool = False,
|
||||
input_details=None,
|
||||
output_details=None,
|
||||
tank_url: str = None):
|
||||
self.model_name = model_name
|
||||
self.model_path = model_path
|
||||
self.model_type = model_type
|
||||
self.model_source_hub = model_source_hub
|
||||
@@ -29,43 +37,77 @@ class SharkImporter:
|
||||
self.jit_trace = jit_trace
|
||||
self.benchmark_mode = benchmark_mode
|
||||
self.inputs = None
|
||||
self.input_details = None
|
||||
self.output_details = None
|
||||
self.input_details = input_details
|
||||
self.output_details = output_details
|
||||
self.tflite_saving_file = None
|
||||
self.tflite_tosa_file = None
|
||||
self.tank_url = tank_url
|
||||
|
||||
# create tmp model file directory
|
||||
if self.model_path is None:
|
||||
print("Error. No model_path, Please input model path.")
|
||||
if self.model_path is None and self.model_name is None:
|
||||
print("Error. No model_path, No model name,Please input either one.")
|
||||
return
|
||||
|
||||
if self.model_source_hub == "tfhub":
|
||||
# compile and run tfhub tflite
|
||||
if self.model_type == "tflite":
|
||||
print("Setting up for TMP_DIR")
|
||||
exe_basename = os.path.basename(sys.argv[0])
|
||||
self.workdir = os.path.join(os.path.dirname(__file__), "tmp",
|
||||
exe_basename)
|
||||
print(f"TMP_DIR = {self.workdir}")
|
||||
os.makedirs(self.workdir, exist_ok=True)
|
||||
self.tflite_file = '/'.join([self.workdir, 'model.tflite'])
|
||||
print("Setting up local address for tflite model file: ",
|
||||
self.tflite_file)
|
||||
if os.path.exists(self.model_path):
|
||||
self.tflite_file = self.model_path
|
||||
else:
|
||||
print("Download tflite model")
|
||||
urllib.request.urlretrieve(self.model_path,
|
||||
self.tflite_file)
|
||||
print("Setting up tflite interpreter")
|
||||
self.tflite_interpreter = tf.lite.Interpreter(
|
||||
model_path=self.tflite_file)
|
||||
self.tflite_interpreter.allocate_tensors()
|
||||
# default input initialization
|
||||
self.input_details, self.output_details = self.get_model_details(
|
||||
)
|
||||
inputs = self.generate_inputs(
|
||||
self.input_details) # device_inputs
|
||||
load_model_success = self.load_tflite_model()
|
||||
if load_model_success == False:
|
||||
print("Error, load tflite model fail")
|
||||
return
|
||||
|
||||
if (self.input_details == None) or \
|
||||
(self.output_details == None):
|
||||
print("Setting up tflite interpreter to get model input details")
|
||||
self.tflite_interpreter = tf.lite.Interpreter(
|
||||
model_path=self.tflite_saving_file)
|
||||
self.tflite_interpreter.allocate_tensors()
|
||||
# default input initialization
|
||||
self.input_details, self.output_details = self.get_model_details(
|
||||
)
|
||||
inputs = self.generate_inputs(
|
||||
self.input_details) # device_inputs
|
||||
self.setup_inputs(inputs)
|
||||
|
||||
def load_tflite_model(self):
    """Locate (or download) the .tflite file for ``self.model_name``.

    Creates ``<this module's dir>/../gen_shark_tank/tflite/<model_name>/``
    and records two paths on the instance:

    * ``self.tflite_saving_file`` -- ``<model_name>_tflite.tflite``
    * ``self.tflite_tosa_file``   -- ``<model_name>_tosa.mlir``

    If the .tflite file is not present locally it is fetched with
    ``urllib.request.urlretrieve`` from ``self.model_path``.  When
    ``self.model_path`` is None, the URL is looked up by model name in
    ``../tank/tflite/tflite_model_list.csv`` (column 0 = name,
    column 1 = URL).

    Returns:
        bool: False when no download URL could be determined, True
        otherwise.  A missing tosa.mlir file is only reported, it does
        not make the call fail.
    """
    print("Setting up for TMP_DIR")
    module_dir = os.path.dirname(__file__)
    tflite_workdir = os.path.join(module_dir, "./../gen_shark_tank/tflite")
    os.makedirs(tflite_workdir, exist_ok=True)
    print(f"TMP_TFLITE_DIR = {tflite_workdir}")

    # Each model gets its own sub-directory named after the model.
    tflite_model_name_dir = os.path.join(tflite_workdir, str(self.model_name))
    # TODO Download model from google bucket to tflite_model_name_dir by tank_url
    os.makedirs(tflite_model_name_dir, exist_ok=True)
    print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")

    self.tflite_saving_file = '/'.join(
        [tflite_model_name_dir, str(self.model_name) + '_tflite.tflite'])
    self.tflite_tosa_file = '/'.join(
        [tflite_model_name_dir, str(self.model_name) + '_tosa.mlir'])

    if os.path.exists(self.tflite_saving_file):
        print("Local address for tflite model file Exists: ",
              self.tflite_saving_file)
    else:
        print("No local tflite file, Download tflite model")
        if self.model_path is None:
            # get model file from tflite_model_list.csv or download from gs://bucket
            print("No model_path, get from tflite_model_list.csv")
            tflite_model_list_path = os.path.join(
                module_dir, "../tank/tflite/tflite_model_list.csv")
            # Context manager so the CSV handle is always closed.
            with open(tflite_model_list_path, newline="") as csv_file:
                for row in csv.reader(csv_file):
                    # Blank or malformed rows have no URL column; skip them.
                    if len(row) < 2:
                        continue
                    if str(row[0]) == self.model_name:
                        self.model_path = row[1]
            if self.model_path is None:
                print("Error, No model path find in tflite_model_list.csv")
                return False
        urllib.request.urlretrieve(self.model_path, self.tflite_saving_file)

    if os.path.exists(self.tflite_tosa_file):
        print("Exists", self.tflite_tosa_file)
    else:
        print("No tflite tosa.mlir, please use python generate_sharktank.py to download tosa model")
    return True
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
args = []
|
||||
for input in input_details:
|
||||
@@ -90,15 +132,30 @@ class SharkImporter:
|
||||
# preprocess model_path to get model_type and Model Source Hub
|
||||
print("Shark Importer Intialize SharkInference and Do Compile")
|
||||
if self.model_source_hub == "tfhub":
|
||||
# compile and run tfhub tflite
|
||||
print("Inference tfhub model")
|
||||
self.shark_module = SharkInference(self.tflite_file,
|
||||
self.inputs,
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
jit_trace=self.jit_trace)
|
||||
self.shark_module.set_frontend("tflite")
|
||||
self.shark_module.compile()
|
||||
if os.path.exists(self.tflite_tosa_file):
|
||||
print("Use", self.tflite_tosa_file, "as TOSA compile input")
|
||||
# compile and run tfhub tflite
|
||||
print("Inference tflite tosa model")
|
||||
tosa_model = []
|
||||
with open(self.tflite_tosa_file) as f:
|
||||
tosa_model = f.read()
|
||||
self.shark_module = SharkInference(tosa_model,
|
||||
self.inputs,
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
jit_trace=self.jit_trace)
|
||||
self.shark_module.set_frontend("tflite-tosa")
|
||||
self.shark_module.compile()
|
||||
else:
|
||||
# compile and run tfhub tflite
|
||||
print("Inference tfhub tflite model")
|
||||
self.shark_module = SharkInference(self.tflite_saving_file,
|
||||
self.inputs,
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
jit_trace=self.jit_trace)
|
||||
self.shark_module.set_frontend("tflite")
|
||||
self.shark_module.compile()
|
||||
elif self.model_source_hub == "huggingface":
|
||||
print("Inference", self.model_source_hub, " not implemented yet")
|
||||
elif self.model_source_hub == "jaxhub":
|
||||
|
||||
@@ -50,7 +50,7 @@ class SharkInference:
|
||||
def set_frontend(self, frontend: str):
|
||||
if frontend not in [
|
||||
"pytorch", "torch", "tensorflow", "tf", "mhlo", "linalg",
|
||||
"tosa", "tflite"
|
||||
"tosa", "tflite", "tflite-tosa"
|
||||
]:
|
||||
print_err("frontend not supported.")
|
||||
else:
|
||||
|
||||
@@ -49,7 +49,10 @@ class SharkRunner:
|
||||
self.vmfb_file = None
|
||||
func_name = "forward"
|
||||
self.device = device if device is not None else shark_args.device
|
||||
if self.frontend in ["pytorch", "torch"]:
|
||||
|
||||
if self.frontend in ["tflite-tosa"]:
|
||||
func_name = "main"
|
||||
elif self.frontend in ["pytorch", "torch"]:
|
||||
# get torch-mlir dialect
|
||||
# self.model = torch.Module
|
||||
# TODO assert
|
||||
|
||||
84
tank/tflite/albert_lite_base/albert_lite_base_tflite_test.py
Normal file
84
tank/tflite/albert_lite_base/albert_lite_base_tflite_test.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import numpy as np
|
||||
from shark.shark_importer import SharkImporter
|
||||
import pytest
|
||||
import unittest
|
||||
from shark.parser import shark_args
|
||||
|
||||
# model_path = "https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1?lite-format=tflite"
|
||||
# model_path = model_path
|
||||
|
||||
# Inputs modified to be useful albert inputs.
|
||||
def generate_inputs(input_details):
    """Build three input arrays for the albert_lite_base model.

    Prints the shape and dtype name of every entry in ``input_details``,
    then creates one array for each of the first three entries:

    * entry 0 -- random integers in ``[0, 256)``
    * entry 1 -- all ones
    * entry 2 -- all zeros

    Args:
        input_details: sequence of mappings, each with a ``"shape"`` and a
            ``"dtype"`` key (the dtype must expose ``__name__``).  At least
            three entries are required.

    Returns:
        list: the three numpy arrays, in input order.

    Raises:
        IndexError: if fewer than three entries are supplied.
    """
    # `detail`, not `input`, so the builtin is not shadowed.
    for detail in input_details:
        print(str(detail["shape"]), detail["dtype"].__name__)

    first, second, third = input_details[0], input_details[1], input_details[2]
    return [
        np.random.randint(low=0,
                          high=256,
                          size=first["shape"],
                          dtype=first["dtype"]),
        np.ones(shape=second["shape"], dtype=second["dtype"]),
        np.zeros(shape=third["shape"], dtype=third["dtype"]),
    ]
|
||||
|
||||
class AlbertTfliteModuleTester:
    """Runs the albert_lite_base TFLite model through SharkImporter.

    The instance only stores run configuration; the actual work happens in
    ``create_and_check_module``.
    """

    def __init__(
        self,
        dynamic=False,
        device="cpu",
        save_mlir=False,
        save_vmfb=False,
    ):
        # Configuration that is later forwarded to SharkImporter / shark_args.
        self.dynamic, self.device = dynamic, device
        self.save_mlir, self.save_vmfb = save_mlir, save_vmfb

    def create_and_check_module(self):
        """Compile and run the model twice: default inputs, then custom ones.

        The save flags are first copied onto the global ``shark_args``.
        Results of ``forward`` are not compared against any reference here.
        """
        shark_args.save_mlir = self.save_mlir
        shark_args.save_vmfb = self.save_vmfb

        importer_options = dict(
            model_name="albert_lite_base",
            model_type="tflite",
            model_source_hub="tfhub",
            device=self.device,
            dynamic=self.dynamic,
            jit_trace=True,
            tank_url=None,
        )
        importer = SharkImporter(**importer_options)

        # Case1: Use default inputs
        importer.compile()
        importer.forward()

        # Case2: Use manually set inputs
        model_input_details, _ = importer.get_model_details()
        custom_inputs = generate_inputs(model_input_details)  # device_inputs
        importer.compile(custom_inputs)
        importer.forward(custom_inputs)
|
||||
|
||||
|
||||
class AlbertTfliteModuleTest(unittest.TestCase):
    """unittest/pytest entry point for the albert_lite_base TFLite model."""

    # Class-level defaults so setUp also works when the pytest fixture below
    # never runs (e.g. when the file is executed via unittest.main()).
    save_mlir = False
    save_vmfb = False

    @pytest.fixture(autouse=True)
    def configure(self, pytestconfig):
        """Pick up the --save_mlir / --save_vmfb pytest options."""
        self.save_mlir = pytestconfig.getoption("save_mlir")
        self.save_vmfb = pytestconfig.getoption("save_vmfb")

    def setUp(self):
        # Pass the flags by keyword: the tester's first positional parameter
        # is `dynamic`, so passing `self` positionally would set it to the
        # TestCase instance.  save_vmfb is forwarded as well as save_mlir.
        self.module_tester = AlbertTfliteModuleTester(
            save_mlir=self.save_mlir,
            save_vmfb=self.save_vmfb,
        )

    def test_module_static_cpu(self):
        """Static-shape run on the CPU device."""
        self.module_tester.dynamic = False
        self.module_tester.device = "cpu"
        self.module_tester.create_and_check_module()


if __name__ == '__main__':
    unittest.main()
|
||||
4
tank/tflite/conftest.py
Normal file
4
tank/tflite/conftest.py
Normal file
@@ -0,0 +1,4 @@
|
||||
def _str_to_bool(value):
    """Convert a command-line value such as "True", "false" or "1" to a bool.

    Raises:
        ValueError: if the text is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "yes", "on"):
        return True
    if text in ("0", "false", "no", "off", ""):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


def pytest_addoption(parser):
    """Attach the SHARK command-line arguments to the pytest machinery.

    Both options still take a value (``--save_mlir=True``), but the value is
    parsed into a real bool.  The default used to be the *string* "False",
    which is truthy, so the flags were effectively always on.
    """
    parser.addoption(
        "--save_mlir",
        type=_str_to_bool,
        default=False,
        help="Pass option to save input MLIR module to /tmp/ directory.")
    parser.addoption(
        "--save_vmfb",
        type=_str_to_bool,
        default=False,
        help="Pass option to save compiled VMFB module to /tmp/ directory.")
|
||||
Reference in New Issue
Block a user