SharkImporter for tflite without forward and compile (#159)

This commit is contained in:
Chi_Liu
2022-06-23 22:49:35 -07:00
committed by GitHub
parent 4ae9331a77
commit 44dce561e9
4 changed files with 246 additions and 158 deletions

View File

@@ -16,7 +16,7 @@ import urllib.request
import csv
import argparse
import iree.compiler.tflite as ireec_tflite
from shark.iree_utils import IREE_TARGET_MAP
from shark.iree_utils._common import IREE_TARGET_MAP
class SharkTank:
@@ -46,7 +46,7 @@ class SharkTank:
if self.tflite_model_list is not None:
print("Setting up for tflite TMP_DIR")
self.tflite_workdir = os.path.join(
os.path.dirname(__file__), "./gen_shark_tank/tflite"
os.path.dirname(__file__), "./gen_shark_tank"
)
print(f"tflite TMP_shark_tank_DIR = {self.tflite_workdir}")
os.makedirs(self.tflite_workdir, exist_ok=True)
@@ -72,7 +72,7 @@ class SharkTank:
tflite_tosa_file = "/".join(
[
tflite_model_name_dir,
str(tflite_model_name) + "_tosa.mlir",
str(tflite_model_name) + "_tflite.mlir",
]
)
self.binary = "/".join(

View File

@@ -1,48 +1,45 @@
# Lint as: python3
"""SHARK Importer"""
import iree.compiler.tflite as iree_tflite_compile
import iree.runtime as iree_rt
import numpy as np
import os
import sys
import csv
import tensorflow as tf
import urllib.request
from shark.shark_inference import SharkInference
import iree.compiler.tflite as ireec_tflite
from shark.iree_utils._common import IREE_TARGET_MAP
import json
from tank.model_utils_tflite import TFLiteModelUtil
class SharkImporter:
def __init__(
self,
model_name: str = None,
model_path: str = None,
model_type: str = "tflite",
model_source_hub: str = "tfhub",
device: str = None,
dynamic: bool = False,
jit_trace: bool = False,
benchmark_mode: bool = False,
model_name,
model_type: str = "torch",
input_details=None,
output_details=None,
tank_url: str = None,
model_path=None,
):
self.model_name = model_name
self.model_path = model_path
self.model_type = model_type
self.model_source_hub = model_source_hub
self.device = device
self.dynamic = dynamic
self.jit_trace = jit_trace
self.benchmark_mode = benchmark_mode
self.inputs = None
self.input_details = input_details
self.output_details = output_details
self.tflite_saving_file = None
self.tflite_tosa_file = None
self.tank_url = tank_url
self.input_details = (
input_details # used for tflite, optional for tf/pytorch
)
self.output_details = (
output_details # used for tflite, optional for tf/pytorch
)
self.inputs = []
self.model_path = model_path # url to download the model
self.raw_model_file = (
None # local address for raw tf/tflite/pytorch model
)
self.mlir_file = (
None # local address for .mlir file of tf/tflite/pytorch model
)
self.mlir_model = None # read of .mlir file
self.output_tensor = (
None # the raw tf/pytorch/tflite_output_tensor, not mlir_tensor
)
self.interpreter = None # could be tflite/tf/torch_interpreter in utils
# create tmp model file directory
if self.model_path is None and self.model_name is None:
@@ -51,60 +48,52 @@ class SharkImporter:
)
return
if self.model_source_hub == "tfhub":
# compile and run tfhub tflite
if self.model_type == "tflite":
load_model_success = self.load_tflite_model()
if load_model_success == False:
print("Error, load tflite model fail")
return
print("Setting up for TMP_WORK_DIR")
self.workdir = os.path.join(
os.path.dirname(__file__), "./../gen_shark_tank"
)
os.makedirs(self.workdir, exist_ok=True)
print(f"TMP_WORK_DIR = {self.workdir}")
if (self.input_details == None) or (
self.output_details == None
):
print(
"Setting up tflite interpreter to get model input details"
)
self.tflite_interpreter = tf.lite.Interpreter(
model_path=self.tflite_saving_file
)
self.tflite_interpreter.allocate_tensors()
# default input initialization
(
self.input_details,
self.output_details,
) = self.get_model_details()
inputs = self.generate_inputs(
self.input_details
) # device_inputs
self.setup_inputs(inputs)
# compile and run tfhub tflite
if self.model_type == "tflite":
load_model_success = self.load_tflite_model()
if not load_model_success:
print("Error, load tflite model fail")
return
if (self.input_details is None) or (self.output_details is None):
print(
"Setting up tflite interpreter to get model input details"
)
self.setup_interpreter()
inputs = self.generate_inputs(
self.input_details
) # device_inputs
self.setup_inputs(inputs)
elif self.model_type in ["tensorflow, tf, torch, pytorch"]:
print(self.model_type, " Not Implemented yet")
def load_tflite_model(self):
print("Setting up for TMP_DIR")
tflite_workdir = os.path.join(
os.path.dirname(__file__), "./../gen_shark_tank/tflite"
)
os.makedirs(tflite_workdir, exist_ok=True)
print(f"TMP_TFLITE_DIR = {tflite_workdir}")
# use model name get dir.
tflite_model_name_dir = os.path.join(
tflite_workdir, str(self.model_name)
)
# TODO Download model from google bucket to tflite_model_name_dir by tank_url
tflite_model_name_dir = os.path.join(self.workdir, str(self.model_name))
os.makedirs(tflite_model_name_dir, exist_ok=True)
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
self.tflite_saving_file = "/".join(
self.raw_model_file = "/".join(
[tflite_model_name_dir, str(self.model_name) + "_tflite.tflite"]
)
self.tflite_tosa_file = "/".join(
[tflite_model_name_dir, str(self.model_name) + "_tosa.mlir"]
self.mlir_file = "/".join(
[tflite_model_name_dir, str(self.model_name) + "_tflite.mlir"]
)
if os.path.exists(self.tflite_saving_file):
if os.path.exists(self.raw_model_file):
print(
"Local address for tflite model file Exists: ",
self.tflite_saving_file,
"Local address for .tflite model file Exists: ",
self.raw_model_file,
)
else:
print("No local tflite file, Download tflite model")
@@ -117,92 +106,109 @@ class SharkImporter:
)
tflite_model_list = csv.reader(open(tflite_model_list_path))
for row in tflite_model_list:
if str(row[0]) == self.model_name:
if str(row[0]) == str(self.model_name):
self.model_path = row[1]
print("tflite_model_name", str(row[0]))
print("tflite_model_link", self.model_path)
if self.model_path is None:
print("Error, No model path find in tflite_model_list.csv")
return False
urllib.request.urlretrieve(self.model_path, self.tflite_saving_file)
if os.path.exists(self.tflite_tosa_file):
print("Exists", self.tflite_tosa_file)
urllib.request.urlretrieve(self.model_path, self.raw_model_file)
if os.path.exists(self.mlir_file):
print("Exists MLIR model ", self.mlir_file)
else:
print(
"No tflite tosa.mlir, please use python generate_sharktank.py to download tosa model"
)
print("Convert tflite to tosa.mlir")
import iree.compiler.tflite as ireec_tflite
ireec_tflite.compile_file(
self.raw_model_file,
input_type="tosa",
save_temp_iree_input=self.mlir_file,
target_backends=[IREE_TARGET_MAP["cpu"]],
import_only=False,
)
with open(self.mlir_file) as f:
self.mlir_model = f.read()
return True
def generate_inputs(self, input_details):
args = []
for input in input_details:
print(str(input["shape"]), input["dtype"].__name__)
args.append(np.zeros(shape=input["shape"], dtype=input["dtype"]))
return args
def get_model_details(self):
def setup_interpreter(self):
if self.model_type == "tflite":
print("Get tflite input output details")
self.input_details = self.tflite_interpreter.get_input_details()
self.output_details = self.tflite_interpreter.get_output_details()
return self.input_details, self.output_details
self.interpreter = TFLiteModelUtil(self.raw_model_file)
(
self.input_details,
self.output_details,
) = self.interpreter.setup_tflite_interpreter()
def generate_inputs(self, input_details):
self.inputs = []
for tmp_input in input_details:
print(str(tmp_input["shape"]), tmp_input["dtype"].__name__)
self.inputs.append(
np.ones(shape=tmp_input["shape"], dtype=tmp_input["dtype"])
)
# save inputs into json file
tmp_json = []
for tmp_input in input_details:
print(str(tmp_input["shape"]), tmp_input["dtype"].__name__)
tmp_json.append(
np.ones(
shape=tmp_input["shape"], dtype=tmp_input["dtype"]
).tolist()
)
with open("input1.json", "w") as f:
json.dump(tmp_json, f)
return self.inputs
# def get_model_details(self):
# if self.model_type == "tflite":
# print("Get tflite input output details")
# self.input_details = self.tflite_interpreter.get_input_details()
# self.output_details = self.tflite_interpreter.get_output_details()
# return self.input_details, self.output_details
def setup_inputs(self, inputs):
print("Setting up inputs")
self.inputs = inputs
def compile(self, inputs=None):
if inputs is not None:
self.setup_inputs(inputs)
# preprocess model_path to get model_type and Model Source Hub
print("Shark Importer Intialize SharkInference and Do Compile")
if self.model_source_hub == "tfhub":
if os.path.exists(self.tflite_tosa_file):
print("Use", self.tflite_tosa_file, "as TOSA compile input")
# compile and run tfhub tflite
print("Inference tflite tosa model")
tosa_model = []
with open(self.tflite_tosa_file) as f:
tosa_model = f.read()
self.shark_module = SharkInference(
tosa_model,
self.inputs,
device=self.device,
dynamic=self.dynamic,
jit_trace=self.jit_trace,
)
self.shark_module.set_frontend("tflite-tosa")
self.shark_module.compile()
else:
# compile and run tfhub tflite
print("Inference tfhub tflite model")
self.shark_module = SharkInference(
self.tflite_saving_file,
self.inputs,
device=self.device,
dynamic=self.dynamic,
jit_trace=self.jit_trace,
)
self.shark_module.set_frontend("tflite")
self.shark_module.compile()
elif self.model_source_hub == "huggingface":
print("Inference", self.model_source_hub, " not implemented yet")
elif self.model_source_hub == "jaxhub":
print("Inference", self.model_source_hub, " not implemented yet")
def get_mlir_model(self):
return self.mlir_model
def forward(self, inputs=None):
if inputs is not None:
self.setup_inputs(inputs)
# preprocess model_path to get model_type and Model Source Hub
print("Shark Importer forward Model")
if self.model_source_hub == "tfhub":
shark_results = self.shark_module.forward(self.inputs)
# Fix type information for unsigned cases.
# for test compare result
shark_results = list(shark_results)
for i in range(len(self.output_details)):
dtype = self.output_details[i]["dtype"]
shark_results[i] = shark_results[i].astype(dtype)
return shark_results
elif self.model_source_hub == "huggingface":
print("Inference", self.model_source_hub, " not implemented yet")
elif self.model_source_hub == "jaxhub":
print("Inference", self.model_source_hub, " not implemented yet")
def get_inputs(self):
return self.inputs
def get_raw_model_output(self):
if self.model_type == "tflite":
self.output_tensor = self.interpreter.invoke_tflite(self.inputs)
return self.output_tensor
def get_model_details(self):
return self.input_details, self.output_details
def get_raw_model_file(self):
return self.raw_model_file
# def invoke_tflite(self, inputs):
# print("invoke_tflite")
# for i, input in enumerate(self.inputs):
# self.tflite_interpreter.set_tensor(
# self.input_details[i]["index"], input
# )
# self.tflite_interpreter.invoke()
#
# # post process tflite_result for compare with mlir_result,
# # for tflite the output is a list of numpy.tensor
# tflite_results = []
# for output_detail in self.output_details:
# tflite_results.append(
# np.array(
# self.tflite_interpreter.get_tensor(output_detail["index"])
# )
# )
#
# for i in range(len(self.output_details)):
# out_dtype = self.output_details[i]["dtype"]
# tflite_results[i] = tflite_results[i].astype(out_dtype)
# return tflite_results

View File

@@ -0,0 +1,49 @@
import tensorflow as tf
import numpy as np
class TFLiteModelUtil:
    """Thin wrapper around ``tf.lite.Interpreter``.

    Owns the interpreter lifecycle for a single .tflite file: setup,
    input/output detail discovery, and invocation. Outputs are cast back
    to the dtypes declared in ``output_details`` so they can be compared
    against MLIR results.
    """

    def __init__(self, raw_model_file):
        # Interpreter requires a string path; coerce defensively.
        self.raw_model_file = str(raw_model_file)
        self.tflite_interpreter = None
        self.input_details = None
        self.output_details = None
        self.inputs = []

    def setup_tflite_interpreter(self):
        """Build the interpreter, allocate tensors, and return model details."""
        interpreter = tf.lite.Interpreter(model_path=self.raw_model_file)
        interpreter.allocate_tensors()
        self.tflite_interpreter = interpreter
        # default input initialization
        return self.get_model_details()

    def get_model_details(self):
        """Fetch and cache the interpreter's input/output detail dicts."""
        print("Get tflite input output details")
        self.input_details = self.tflite_interpreter.get_input_details()
        self.output_details = self.tflite_interpreter.get_output_details()
        return self.input_details, self.output_details

    def invoke_tflite(self, inputs):
        """Run one inference pass; return outputs cast to their declared dtypes."""
        self.inputs = inputs
        print("invoke_tflite")
        for idx, tensor in enumerate(self.inputs):
            self.tflite_interpreter.set_tensor(
                self.input_details[idx]["index"], tensor
            )
        self.tflite_interpreter.invoke()
        # post process tflite_result for compare with mlir_result,
        # for tflite the output is a list of numpy.tensor
        raw_outputs = [
            np.array(self.tflite_interpreter.get_tensor(detail["index"]))
            for detail in self.output_details
        ]
        # Cast each output to the dtype the model declares for that tensor.
        return [
            out.astype(detail["dtype"])
            for out, detail in zip(raw_outputs, self.output_details)
        ]

View File

@@ -1,9 +1,11 @@
import numpy as np
from shark.shark_importer import SharkImporter
from shark.shark_inference import SharkInference
import pytest
import unittest
from shark.parser import shark_args
# model_path = "https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1?lite-format=tflite"
# model_path = model_path
@@ -34,6 +36,21 @@ def generate_inputs(input_details):
return args
def compare_results(mlir_results, tflite_results, details):
    """Compare MLIR-compiled outputs against raw TFLite outputs.

    Args:
        mlir_results: sequence of numpy arrays from the compiled module.
        tflite_results: sequence of numpy arrays from the TFLite interpreter.
        details: per-output detail dicts; only its length is used here.

    Raises:
        AssertionError: if the result counts or any output shape differ.
    """
    print("Compare mlir_results VS tflite_results: ")
    assert len(mlir_results) == len(
        tflite_results
    ), "Number of results do not match"
    for i in range(len(details)):
        # Compare in float32 so integer/quantized outputs diff numerically.
        mlir_result = mlir_results[i].astype(np.single)
        tflite_result = tflite_results[i].astype(np.single)
        assert mlir_result.shape == tflite_result.shape, "shape doesnot match"
        max_error = np.max(np.abs(mlir_result - tflite_result))
        # Fix: print() does not apply %-formatting to extra args; the
        # original printed the raw format string plus the values.
        print("Max error (%d): %f" % (i, max_error))
class AlbertTfliteModuleTester:
def __init__(
self,
@@ -51,24 +68,40 @@ class AlbertTfliteModuleTester:
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
my_shark_importer = SharkImporter(
model_name="albert_lite_base",
# model_path=model_path,
model_type="tflite",
model_source_hub="tfhub",
device=self.device,
dynamic=self.dynamic,
jit_trace=True,
tank_url=None,
model_name="albert_lite_base", model_type="tflite"
)
# Case1: Use default inputs
my_shark_importer.compile()
shark_results = my_shark_importer.forward()
mlir_model = my_shark_importer.get_mlir_model()
inputs = my_shark_importer.get_inputs()
shark_module = SharkInference(
mlir_model, inputs, device=self.device, dynamic=self.dynamic
)
shark_module.set_frontend("tflite-tosa")
# Case1: Use shark_importer default generate inputs
shark_module.compile()
mlir_results = shark_module.forward(inputs)
## post process results for compare
input_details, output_details = my_shark_importer.get_model_details()
mlir_results = list(mlir_results)
for i in range(len(output_details)):
dtype = output_details[i]["dtype"]
mlir_results[i] = mlir_results[i].astype(dtype)
tflite_results = my_shark_importer.get_raw_model_output()
compare_results(mlir_results, tflite_results, output_details)
# Case2: Use manually set inputs
input_details, output_details = my_shark_importer.get_model_details()
inputs = generate_inputs(input_details) # device_inputs
my_shark_importer.compile(inputs)
shark_results = my_shark_importer.forward(inputs)
# print(shark_results)
shark_module = SharkInference(
mlir_model, inputs, device=self.device, dynamic=self.dynamic
)
shark_module.set_frontend("tflite-tosa")
shark_module.compile()
mlir_results = shark_module.forward(inputs)
tflite_results = my_shark_importer.get_raw_model_output()
compare_results(mlir_results, tflite_results, output_details)
# print(mlir_results)
class AlbertTfliteModuleTest(unittest.TestCase):