Enable tosa.mlir as input for SharkImporter inference (#145)

* Change shark_importer to use tosa.mlir as tflite model input from  local gen_shark_tank
This commit is contained in:
Chi_Liu
2022-06-20 23:15:14 -07:00
committed by GitHub
parent f6e9f2d571
commit af582925f2
7 changed files with 211 additions and 49 deletions

View File

@@ -1,5 +1,15 @@
# Lint as: python3
"""SHARK Tank"""
# python generate_sharktank.py, you have to give a csv file with [model_name, model_download_url]
# will generate local shark tank folder like this:
# /SHARK
# /gen_shark_tank
# /tflite
# /albert_lite_base
# /...model_name...
# /tf
# /pytorch
#
import os
import urllib.request
@@ -22,8 +32,12 @@ class SharkTank:
if self.torch_model_list is not None:
print("Process torch model")
else:
print("Torch sharktank not implemented yet")
if self.tf_model_list is not None:
print("Process torch model")
else:
print("tf sharktank not implemented yet")
print("self.tflite_model_list: ", self.tflite_model_list)
# compile and run tfhub tflite
@@ -44,7 +58,7 @@ class SharkTank:
os.makedirs(tflite_model_name_dir, exist_ok=True)
tflite_saving_file = '/'.join([tflite_model_name_dir, str(tflite_model_name)+'_tflite.tflite'])
iree_ir = '/'.join([tflite_model_name_dir, str(tflite_model_name)+'_tosa.mlir'])
tflite_tosa_file = '/'.join([tflite_model_name_dir, str(tflite_model_name)+'_tosa.mlir'])
self.binary = '/'.join([tflite_model_name_dir, str(tflite_model_name)+'_module.bytecode'])
print("Setting up local address for tflite model file: ", tflite_saving_file)
if os.path.exists(tflite_saving_file):
@@ -54,14 +68,14 @@ class SharkTank:
urllib.request.urlretrieve(str(tflite_model_link),
tflite_saving_file)
if os.path.exists(iree_ir):
print(iree_ir, "exists")
if os.path.exists(tflite_tosa_file):
print("Exists", tflite_tosa_file)
else:
print("Convert tflite to tosa.mlir")
ireec_tflite.compile_file(
tflite_saving_file,
input_type="tosa",
save_temp_iree_input=iree_ir,
save_temp_iree_input=tflite_tosa_file,
target_backends=[IREE_TARGET_MAP['cpu']],
import_only=False)

View File

@@ -166,7 +166,7 @@ def compile_module_to_flatbuffer(module, device, frontend, func_name,
input_type = "mhlo"
elif frontend in ["mhlo", "tosa"]:
input_type = frontend
elif frontend in ["tflite"]:
elif frontend in ["tflite", "tflite-tosa"]:
input_type = "tosa"
# Annotate the input module with the configs
@@ -239,7 +239,7 @@ def export_iree_module_to_vmfb(module,
def export_module_to_mlir_file(module, frontend, directory: str):
mlir_str = module
if frontend in ["tensorflow", "tf", "mhlo"]:
if frontend in ["tensorflow", "tf", "mhlo", "tflite"]:
mlir_str = module.decode('utf-8')
elif frontend in ["pytorch", "torch"]:
mlir_str = module.operation.get_asm()
@@ -255,7 +255,7 @@ def get_results(compiled_vm, input, config, frontend="torch"):
device_inputs = input
if frontend in ["torch", "pytorch"]:
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
if frontend in ["tensorflow", "tf", "tflite"]:
if frontend in ["tensorflow", "tf", "tflite", "tflite-tosa"]:
device_inputs = []
for a in input:
if (isinstance(a, list)):

View File

@@ -6,21 +6,29 @@ import iree.runtime as iree_rt
import numpy as np
import os
import sys
import tensorflow.compat.v2 as tf
import csv
import tensorflow as tf
import urllib.request
from shark.shark_inference import SharkInference
import iree.compiler.tflite as ireec_tflite
from shark.iree_utils import IREE_TARGET_MAP
class SharkImporter:
def __init__(self,
model_path,
model_name: str=None,
model_path: str=None,
model_type: str = "tflite",
model_source_hub: str = "tfhub",
device: str = None,
dynamic: bool = False,
jit_trace: bool = False,
benchmark_mode: bool = False):
benchmark_mode: bool = False,
input_details=None,
output_details=None,
tank_url: str = None):
self.model_name = model_name
self.model_path = model_path
self.model_type = model_type
self.model_source_hub = model_source_hub
@@ -29,43 +37,77 @@ class SharkImporter:
self.jit_trace = jit_trace
self.benchmark_mode = benchmark_mode
self.inputs = None
self.input_details = None
self.output_details = None
self.input_details = input_details
self.output_details = output_details
self.tflite_saving_file = None
self.tflite_tosa_file = None
self.tank_url = tank_url
# create tmp model file directory
if self.model_path is None:
print("Error. No model_path, Please input model path.")
if self.model_path is None and self.model_name is None:
print("Error. No model_path, No model name,Please input either one.")
return
if self.model_source_hub == "tfhub":
# compile and run tfhub tflite
if self.model_type == "tflite":
print("Setting up for TMP_DIR")
exe_basename = os.path.basename(sys.argv[0])
self.workdir = os.path.join(os.path.dirname(__file__), "tmp",
exe_basename)
print(f"TMP_DIR = {self.workdir}")
os.makedirs(self.workdir, exist_ok=True)
self.tflite_file = '/'.join([self.workdir, 'model.tflite'])
print("Setting up local address for tflite model file: ",
self.tflite_file)
if os.path.exists(self.model_path):
self.tflite_file = self.model_path
else:
print("Download tflite model")
urllib.request.urlretrieve(self.model_path,
self.tflite_file)
print("Setting up tflite interpreter")
self.tflite_interpreter = tf.lite.Interpreter(
model_path=self.tflite_file)
self.tflite_interpreter.allocate_tensors()
# default input initialization
self.input_details, self.output_details = self.get_model_details(
)
inputs = self.generate_inputs(
self.input_details) # device_inputs
load_model_success = self.load_tflite_model()
if load_model_success == False:
print("Error, load tflite model fail")
return
if (self.input_details == None) or \
(self.output_details == None):
print("Setting up tflite interpreter to get model input details")
self.tflite_interpreter = tf.lite.Interpreter(
model_path=self.tflite_saving_file)
self.tflite_interpreter.allocate_tensors()
# default input initialization
self.input_details, self.output_details = self.get_model_details(
)
inputs = self.generate_inputs(
self.input_details) # device_inputs
self.setup_inputs(inputs)
def load_tflite_model(self):
print("Setting up for TMP_DIR")
tflite_workdir = os.path.join(os.path.dirname(__file__), "./../gen_shark_tank/tflite")
os.makedirs(tflite_workdir, exist_ok=True)
print(f"TMP_TFLITE_DIR = {tflite_workdir}")
# use model name get dir.
tflite_model_name_dir = os.path.join(tflite_workdir, str(self.model_name))
# TODO Download model from google bucket to tflite_model_name_dir by tank_url
os.makedirs(tflite_model_name_dir, exist_ok=True)
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
self.tflite_saving_file = '/'.join(
[tflite_model_name_dir, str(self.model_name) + '_tflite.tflite'])
self.tflite_tosa_file = '/'.join(
[tflite_model_name_dir, str(self.model_name) + '_tosa.mlir'])
if os.path.exists(self.tflite_saving_file):
print("Local address for tflite model file Exists: ", self.tflite_saving_file)
else:
print("No local tflite file, Download tflite model")
if self.model_path is None:
# get model file from tflite_model_list.csv or download from gs://bucket
print("No model_path, get from tflite_model_list.csv")
tflite_model_list_path = os.path.join(os.path.dirname(__file__), "../tank/tflite/tflite_model_list.csv")
tflite_model_list = csv.reader(open(tflite_model_list_path))
for row in tflite_model_list:
if str(row[0]) == self.model_name:
self.model_path = row[1]
if self.model_path is None:
print("Error, No model path find in tflite_model_list.csv")
return False
urllib.request.urlretrieve(self.model_path,
self.tflite_saving_file)
if os.path.exists(self.tflite_tosa_file):
print("Exists", self.tflite_tosa_file)
else:
print("No tflite tosa.mlir, please use python generate_sharktank.py to download tosa model")
return True
def generate_inputs(self, input_details):
args = []
for input in input_details:
@@ -90,15 +132,30 @@ class SharkImporter:
# preprocess model_path to get model_type and Model Source Hub
print("Shark Importer Intialize SharkInference and Do Compile")
if self.model_source_hub == "tfhub":
# compile and run tfhub tflite
print("Inference tfhub model")
self.shark_module = SharkInference(self.tflite_file,
self.inputs,
device=self.device,
dynamic=self.dynamic,
jit_trace=self.jit_trace)
self.shark_module.set_frontend("tflite")
self.shark_module.compile()
if os.path.exists(self.tflite_tosa_file):
print("Use", self.tflite_tosa_file, "as TOSA compile input")
# compile and run tfhub tflite
print("Inference tflite tosa model")
tosa_model = []
with open(self.tflite_tosa_file) as f:
tosa_model = f.read()
self.shark_module = SharkInference(tosa_model,
self.inputs,
device=self.device,
dynamic=self.dynamic,
jit_trace=self.jit_trace)
self.shark_module.set_frontend("tflite-tosa")
self.shark_module.compile()
else:
# compile and run tfhub tflite
print("Inference tfhub tflite model")
self.shark_module = SharkInference(self.tflite_saving_file,
self.inputs,
device=self.device,
dynamic=self.dynamic,
jit_trace=self.jit_trace)
self.shark_module.set_frontend("tflite")
self.shark_module.compile()
elif self.model_source_hub == "huggingface":
print("Inference", self.model_source_hub, " not implemented yet")
elif self.model_source_hub == "jaxhub":

View File

@@ -50,7 +50,7 @@ class SharkInference:
def set_frontend(self, frontend: str):
if frontend not in [
"pytorch", "torch", "tensorflow", "tf", "mhlo", "linalg",
"tosa", "tflite"
"tosa", "tflite", "tflite-tosa"
]:
print_err("frontend not supported.")
else:

View File

@@ -49,7 +49,10 @@ class SharkRunner:
self.vmfb_file = None
func_name = "forward"
self.device = device if device is not None else shark_args.device
if self.frontend in ["pytorch", "torch"]:
if self.frontend in ["tflite-tosa"]:
func_name = "main"
elif self.frontend in ["pytorch", "torch"]:
# get torch-mlir dialect
# self.model = torch.Module
# TODO assert

View File

@@ -0,0 +1,84 @@
import numpy as np
from shark.shark_importer import SharkImporter
import pytest
import unittest
from shark.parser import shark_args
# model_path = "https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1?lite-format=tflite"
# model_path = model_path
# Inputs modified to be useful albert inputs.
def generate_inputs(input_details):
    """Build albert-style inputs from tflite input details.

    Prints each input's shape/dtype, then returns three arrays matching the
    first three entries of *input_details*: random token ids in [0, 256),
    an all-ones mask, and an all-zeros segment-id array.
    """
    for detail in input_details:
        print(str(detail["shape"]), detail["dtype"].__name__)

    token_ids = np.random.randint(low=0,
                                  high=256,
                                  size=input_details[0]["shape"],
                                  dtype=input_details[0]["dtype"])
    attention_mask = np.ones(shape=input_details[1]["shape"],
                             dtype=input_details[1]["dtype"])
    segment_ids = np.zeros(shape=input_details[2]["shape"],
                           dtype=input_details[2]["dtype"])
    return [token_ids, attention_mask, segment_ids]
class AlbertTfliteModuleTester:
    """Drives SharkImporter compile/inference for the albert_lite_base tflite model."""

    def __init__(
        self,
        dynamic=False,
        device="cpu",
        save_mlir=False,
        save_vmfb=False,
    ):
        self.dynamic = dynamic
        self.device = device
        self.save_mlir = save_mlir
        self.save_vmfb = save_vmfb

    def create_and_check_module(self):
        # Propagate the save flags to the global SHARK argument object.
        shark_args.save_mlir = self.save_mlir
        shark_args.save_vmfb = self.save_vmfb

        importer_kwargs = {
            "model_name": "albert_lite_base",
            "model_type": "tflite",
            "model_source_hub": "tfhub",
            "device": self.device,
            "dynamic": self.dynamic,
            "jit_trace": True,
            "tank_url": None,
        }
        my_shark_importer = SharkImporter(**importer_kwargs)

        # Case 1: compile and run with the importer's default inputs.
        my_shark_importer.compile()
        shark_results = my_shark_importer.forward()

        # Case 2: compile and run with manually generated inputs.
        input_details, output_details = my_shark_importer.get_model_details()
        inputs = generate_inputs(input_details)  # device_inputs
        my_shark_importer.compile(inputs)
        shark_results = my_shark_importer.forward(inputs)
        # print(shark_results)
class AlbertTfliteModuleTest(unittest.TestCase):
    """Static-shape CPU smoke test for the albert_lite_base tflite import path."""

    @pytest.fixture(autouse=True)
    def configure(self, pytestconfig):
        # Pull the SHARK command-line flags (registered in conftest.py)
        # into the test case before setUp runs.
        self.save_mlir = pytestconfig.getoption("save_mlir")
        self.save_vmfb = pytestconfig.getoption("save_vmfb")

    def setUp(self):
        # Fix: the tester was previously constructed as
        # AlbertTfliteModuleTester(self), which passed the TestCase instance
        # as the positional `dynamic` argument. Construct with defaults and
        # forward both save flags (save_vmfb was silently dropped before).
        self.module_tester = AlbertTfliteModuleTester()
        self.module_tester.save_mlir = self.save_mlir
        self.module_tester.save_vmfb = self.save_vmfb

    def test_module_static_cpu(self):
        self.module_tester.dynamic = False
        self.module_tester.device = "cpu"
        self.module_tester.create_and_check_module()
if __name__ == "__main__":
    # Run the unittest test discovery/runner when invoked as a script.
    unittest.main()

4
tank/tflite/conftest.py Normal file
View File

@@ -0,0 +1,4 @@
def pytest_addoption(parser):
    """Attach SHARK command-line arguments to the pytest machinery.

    Both options default to the string "False" (they are parsed as strings,
    not booleans, by the consuming tests).
    """
    parser.addoption(
        "--save_mlir",
        default="False",
        help="Pass option to save input MLIR module to /tmp/ directory.")
    # Fix: help text was copy-pasted from --save_mlir; this flag saves the
    # compiled VMFB artifact, not the input MLIR module.
    parser.addoption(
        "--save_vmfb",
        default="False",
        help="Pass option to save compiled VMFB module to /tmp/ directory.")