mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
SharkImporter for tflite without forward and compile (#159)
This commit is contained in:
@@ -16,7 +16,7 @@ import urllib.request
|
||||
import csv
|
||||
import argparse
|
||||
import iree.compiler.tflite as ireec_tflite
|
||||
from shark.iree_utils import IREE_TARGET_MAP
|
||||
from shark.iree_utils._common import IREE_TARGET_MAP
|
||||
|
||||
|
||||
class SharkTank:
|
||||
@@ -46,7 +46,7 @@ class SharkTank:
|
||||
if self.tflite_model_list is not None:
|
||||
print("Setting up for tflite TMP_DIR")
|
||||
self.tflite_workdir = os.path.join(
|
||||
os.path.dirname(__file__), "./gen_shark_tank/tflite"
|
||||
os.path.dirname(__file__), "./gen_shark_tank"
|
||||
)
|
||||
print(f"tflite TMP_shark_tank_DIR = {self.tflite_workdir}")
|
||||
os.makedirs(self.tflite_workdir, exist_ok=True)
|
||||
@@ -72,7 +72,7 @@ class SharkTank:
|
||||
tflite_tosa_file = "/".join(
|
||||
[
|
||||
tflite_model_name_dir,
|
||||
str(tflite_model_name) + "_tosa.mlir",
|
||||
str(tflite_model_name) + "_tflite.mlir",
|
||||
]
|
||||
)
|
||||
self.binary = "/".join(
|
||||
|
||||
@@ -1,48 +1,45 @@
|
||||
# Lint as: python3
|
||||
"""SHARK Importer"""
|
||||
|
||||
import iree.compiler.tflite as iree_tflite_compile
|
||||
import iree.runtime as iree_rt
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
import csv
|
||||
import tensorflow as tf
|
||||
import urllib.request
|
||||
from shark.shark_inference import SharkInference
|
||||
import iree.compiler.tflite as ireec_tflite
|
||||
from shark.iree_utils._common import IREE_TARGET_MAP
|
||||
import json
|
||||
from tank.model_utils_tflite import TFLiteModelUtil
|
||||
|
||||
|
||||
class SharkImporter:
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = None,
|
||||
model_path: str = None,
|
||||
model_type: str = "tflite",
|
||||
model_source_hub: str = "tfhub",
|
||||
device: str = None,
|
||||
dynamic: bool = False,
|
||||
jit_trace: bool = False,
|
||||
benchmark_mode: bool = False,
|
||||
model_name,
|
||||
model_type: str = "torch",
|
||||
input_details=None,
|
||||
output_details=None,
|
||||
tank_url: str = None,
|
||||
model_path=None,
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.model_path = model_path
|
||||
self.model_type = model_type
|
||||
self.model_source_hub = model_source_hub
|
||||
self.device = device
|
||||
self.dynamic = dynamic
|
||||
self.jit_trace = jit_trace
|
||||
self.benchmark_mode = benchmark_mode
|
||||
self.inputs = None
|
||||
self.input_details = input_details
|
||||
self.output_details = output_details
|
||||
self.tflite_saving_file = None
|
||||
self.tflite_tosa_file = None
|
||||
self.tank_url = tank_url
|
||||
self.input_details = (
|
||||
input_details # used for tflite, optional for tf/pytorch
|
||||
)
|
||||
self.output_details = (
|
||||
output_details # used for tflite, optional for tf/pytorch
|
||||
)
|
||||
self.inputs = []
|
||||
self.model_path = model_path # url to download the model
|
||||
self.raw_model_file = (
|
||||
None # local address for raw tf/tflite/pytorch model
|
||||
)
|
||||
self.mlir_file = (
|
||||
None # local address for .mlir file of tf/tflite/pytorch model
|
||||
)
|
||||
self.mlir_model = None # read of .mlir file
|
||||
self.output_tensor = (
|
||||
None # the raw tf/pytorch/tflite_output_tensor, not mlir_tensor
|
||||
)
|
||||
self.interpreter = None # could be tflite/tf/torch_interpreter in utils
|
||||
|
||||
# create tmp model file directory
|
||||
if self.model_path is None and self.model_name is None:
|
||||
@@ -51,60 +48,52 @@ class SharkImporter:
|
||||
)
|
||||
return
|
||||
|
||||
if self.model_source_hub == "tfhub":
|
||||
# compile and run tfhub tflite
|
||||
if self.model_type == "tflite":
|
||||
load_model_success = self.load_tflite_model()
|
||||
if load_model_success == False:
|
||||
print("Error, load tflite model fail")
|
||||
return
|
||||
print("Setting up for TMP_WORK_DIR")
|
||||
self.workdir = os.path.join(
|
||||
os.path.dirname(__file__), "./../gen_shark_tank"
|
||||
)
|
||||
os.makedirs(self.workdir, exist_ok=True)
|
||||
print(f"TMP_WORK_DIR = {self.workdir}")
|
||||
|
||||
if (self.input_details == None) or (
|
||||
self.output_details == None
|
||||
):
|
||||
print(
|
||||
"Setting up tflite interpreter to get model input details"
|
||||
)
|
||||
self.tflite_interpreter = tf.lite.Interpreter(
|
||||
model_path=self.tflite_saving_file
|
||||
)
|
||||
self.tflite_interpreter.allocate_tensors()
|
||||
# default input initialization
|
||||
(
|
||||
self.input_details,
|
||||
self.output_details,
|
||||
) = self.get_model_details()
|
||||
inputs = self.generate_inputs(
|
||||
self.input_details
|
||||
) # device_inputs
|
||||
self.setup_inputs(inputs)
|
||||
# compile and run tfhub tflite
|
||||
if self.model_type == "tflite":
|
||||
load_model_success = self.load_tflite_model()
|
||||
if not load_model_success:
|
||||
print("Error, load tflite model fail")
|
||||
return
|
||||
|
||||
if (self.input_details is None) or (self.output_details is None):
|
||||
print(
|
||||
"Setting up tflite interpreter to get model input details"
|
||||
)
|
||||
self.setup_interpreter()
|
||||
|
||||
inputs = self.generate_inputs(
|
||||
self.input_details
|
||||
) # device_inputs
|
||||
self.setup_inputs(inputs)
|
||||
|
||||
elif self.model_type in ["tensorflow, tf, torch, pytorch"]:
|
||||
print(self.model_type, " Not Implemented yet")
|
||||
|
||||
def load_tflite_model(self):
|
||||
print("Setting up for TMP_DIR")
|
||||
tflite_workdir = os.path.join(
|
||||
os.path.dirname(__file__), "./../gen_shark_tank/tflite"
|
||||
)
|
||||
os.makedirs(tflite_workdir, exist_ok=True)
|
||||
print(f"TMP_TFLITE_DIR = {tflite_workdir}")
|
||||
# use model name get dir.
|
||||
tflite_model_name_dir = os.path.join(
|
||||
tflite_workdir, str(self.model_name)
|
||||
)
|
||||
# TODO Download model from google bucket to tflite_model_name_dir by tank_url
|
||||
tflite_model_name_dir = os.path.join(self.workdir, str(self.model_name))
|
||||
|
||||
os.makedirs(tflite_model_name_dir, exist_ok=True)
|
||||
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
|
||||
|
||||
self.tflite_saving_file = "/".join(
|
||||
self.raw_model_file = "/".join(
|
||||
[tflite_model_name_dir, str(self.model_name) + "_tflite.tflite"]
|
||||
)
|
||||
self.tflite_tosa_file = "/".join(
|
||||
[tflite_model_name_dir, str(self.model_name) + "_tosa.mlir"]
|
||||
self.mlir_file = "/".join(
|
||||
[tflite_model_name_dir, str(self.model_name) + "_tflite.mlir"]
|
||||
)
|
||||
|
||||
if os.path.exists(self.tflite_saving_file):
|
||||
if os.path.exists(self.raw_model_file):
|
||||
print(
|
||||
"Local address for tflite model file Exists: ",
|
||||
self.tflite_saving_file,
|
||||
"Local address for .tflite model file Exists: ",
|
||||
self.raw_model_file,
|
||||
)
|
||||
else:
|
||||
print("No local tflite file, Download tflite model")
|
||||
@@ -117,92 +106,109 @@ class SharkImporter:
|
||||
)
|
||||
tflite_model_list = csv.reader(open(tflite_model_list_path))
|
||||
for row in tflite_model_list:
|
||||
if str(row[0]) == self.model_name:
|
||||
if str(row[0]) == str(self.model_name):
|
||||
self.model_path = row[1]
|
||||
print("tflite_model_name", str(row[0]))
|
||||
print("tflite_model_link", self.model_path)
|
||||
if self.model_path is None:
|
||||
print("Error, No model path find in tflite_model_list.csv")
|
||||
return False
|
||||
urllib.request.urlretrieve(self.model_path, self.tflite_saving_file)
|
||||
if os.path.exists(self.tflite_tosa_file):
|
||||
print("Exists", self.tflite_tosa_file)
|
||||
urllib.request.urlretrieve(self.model_path, self.raw_model_file)
|
||||
if os.path.exists(self.mlir_file):
|
||||
print("Exists MLIR model ", self.mlir_file)
|
||||
else:
|
||||
print(
|
||||
"No tflite tosa.mlir, please use python generate_sharktank.py to download tosa model"
|
||||
)
|
||||
print("Convert tflite to tosa.mlir")
|
||||
import iree.compiler.tflite as ireec_tflite
|
||||
|
||||
ireec_tflite.compile_file(
|
||||
self.raw_model_file,
|
||||
input_type="tosa",
|
||||
save_temp_iree_input=self.mlir_file,
|
||||
target_backends=[IREE_TARGET_MAP["cpu"]],
|
||||
import_only=False,
|
||||
)
|
||||
with open(self.mlir_file) as f:
|
||||
self.mlir_model = f.read()
|
||||
return True
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
args = []
|
||||
for input in input_details:
|
||||
print(str(input["shape"]), input["dtype"].__name__)
|
||||
args.append(np.zeros(shape=input["shape"], dtype=input["dtype"]))
|
||||
return args
|
||||
|
||||
def get_model_details(self):
|
||||
def setup_interpreter(self):
|
||||
if self.model_type == "tflite":
|
||||
print("Get tflite input output details")
|
||||
self.input_details = self.tflite_interpreter.get_input_details()
|
||||
self.output_details = self.tflite_interpreter.get_output_details()
|
||||
return self.input_details, self.output_details
|
||||
self.interpreter = TFLiteModelUtil(self.raw_model_file)
|
||||
(
|
||||
self.input_details,
|
||||
self.output_details,
|
||||
) = self.interpreter.setup_tflite_interpreter()
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
self.inputs = []
|
||||
for tmp_input in input_details:
|
||||
print(str(tmp_input["shape"]), tmp_input["dtype"].__name__)
|
||||
self.inputs.append(
|
||||
np.ones(shape=tmp_input["shape"], dtype=tmp_input["dtype"])
|
||||
)
|
||||
# save inputs into json file
|
||||
tmp_json = []
|
||||
for tmp_input in input_details:
|
||||
print(str(tmp_input["shape"]), tmp_input["dtype"].__name__)
|
||||
tmp_json.append(
|
||||
np.ones(
|
||||
shape=tmp_input["shape"], dtype=tmp_input["dtype"]
|
||||
).tolist()
|
||||
)
|
||||
with open("input1.json", "w") as f:
|
||||
json.dump(tmp_json, f)
|
||||
return self.inputs
|
||||
|
||||
# def get_model_details(self):
|
||||
# if self.model_type == "tflite":
|
||||
# print("Get tflite input output details")
|
||||
# self.input_details = self.tflite_interpreter.get_input_details()
|
||||
# self.output_details = self.tflite_interpreter.get_output_details()
|
||||
# return self.input_details, self.output_details
|
||||
|
||||
def setup_inputs(self, inputs):
|
||||
print("Setting up inputs")
|
||||
self.inputs = inputs
|
||||
|
||||
def compile(self, inputs=None):
|
||||
if inputs is not None:
|
||||
self.setup_inputs(inputs)
|
||||
# preprocess model_path to get model_type and Model Source Hub
|
||||
print("Shark Importer Intialize SharkInference and Do Compile")
|
||||
if self.model_source_hub == "tfhub":
|
||||
if os.path.exists(self.tflite_tosa_file):
|
||||
print("Use", self.tflite_tosa_file, "as TOSA compile input")
|
||||
# compile and run tfhub tflite
|
||||
print("Inference tflite tosa model")
|
||||
tosa_model = []
|
||||
with open(self.tflite_tosa_file) as f:
|
||||
tosa_model = f.read()
|
||||
self.shark_module = SharkInference(
|
||||
tosa_model,
|
||||
self.inputs,
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
jit_trace=self.jit_trace,
|
||||
)
|
||||
self.shark_module.set_frontend("tflite-tosa")
|
||||
self.shark_module.compile()
|
||||
else:
|
||||
# compile and run tfhub tflite
|
||||
print("Inference tfhub tflite model")
|
||||
self.shark_module = SharkInference(
|
||||
self.tflite_saving_file,
|
||||
self.inputs,
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
jit_trace=self.jit_trace,
|
||||
)
|
||||
self.shark_module.set_frontend("tflite")
|
||||
self.shark_module.compile()
|
||||
elif self.model_source_hub == "huggingface":
|
||||
print("Inference", self.model_source_hub, " not implemented yet")
|
||||
elif self.model_source_hub == "jaxhub":
|
||||
print("Inference", self.model_source_hub, " not implemented yet")
|
||||
def get_mlir_model(self):
|
||||
return self.mlir_model
|
||||
|
||||
def forward(self, inputs=None):
|
||||
if inputs is not None:
|
||||
self.setup_inputs(inputs)
|
||||
# preprocess model_path to get model_type and Model Source Hub
|
||||
print("Shark Importer forward Model")
|
||||
if self.model_source_hub == "tfhub":
|
||||
shark_results = self.shark_module.forward(self.inputs)
|
||||
# Fix type information for unsigned cases.
|
||||
# for test compare result
|
||||
shark_results = list(shark_results)
|
||||
for i in range(len(self.output_details)):
|
||||
dtype = self.output_details[i]["dtype"]
|
||||
shark_results[i] = shark_results[i].astype(dtype)
|
||||
return shark_results
|
||||
elif self.model_source_hub == "huggingface":
|
||||
print("Inference", self.model_source_hub, " not implemented yet")
|
||||
elif self.model_source_hub == "jaxhub":
|
||||
print("Inference", self.model_source_hub, " not implemented yet")
|
||||
def get_inputs(self):
|
||||
return self.inputs
|
||||
|
||||
def get_raw_model_output(self):
|
||||
if self.model_type == "tflite":
|
||||
self.output_tensor = self.interpreter.invoke_tflite(self.inputs)
|
||||
return self.output_tensor
|
||||
|
||||
def get_model_details(self):
|
||||
return self.input_details, self.output_details
|
||||
|
||||
def get_raw_model_file(self):
|
||||
return self.raw_model_file
|
||||
|
||||
# def invoke_tflite(self, inputs):
|
||||
# print("invoke_tflite")
|
||||
# for i, input in enumerate(self.inputs):
|
||||
# self.tflite_interpreter.set_tensor(
|
||||
# self.input_details[i]["index"], input
|
||||
# )
|
||||
# self.tflite_interpreter.invoke()
|
||||
#
|
||||
# # post process tflite_result for compare with mlir_result,
|
||||
# # for tflite the output is a list of numpy.tensor
|
||||
# tflite_results = []
|
||||
# for output_detail in self.output_details:
|
||||
# tflite_results.append(
|
||||
# np.array(
|
||||
# self.tflite_interpreter.get_tensor(output_detail["index"])
|
||||
# )
|
||||
# )
|
||||
#
|
||||
# for i in range(len(self.output_details)):
|
||||
# out_dtype = self.output_details[i]["dtype"]
|
||||
# tflite_results[i] = tflite_results[i].astype(out_dtype)
|
||||
# return tflite_results
|
||||
|
||||
49
tank/model_utils_tflite.py
Normal file
49
tank/model_utils_tflite.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TFLiteModelUtil:
    """Helper that runs a single ``.tflite`` file through ``tf.lite.Interpreter``.

    Typical use: construct with the path of the model file, call
    ``setup_tflite_interpreter()`` to obtain the input/output details, then
    call ``invoke_tflite(inputs)`` to get the interpreter's outputs.
    """

    def __init__(self, raw_model_file):
        """Remember the model file location; no interpreter is created yet.

        Args:
            raw_model_file: Location of the ``.tflite`` file. It is stored as
                ``str(raw_model_file)``.
        """
        self.raw_model_file = str(raw_model_file)
        # Populated by setup_tflite_interpreter() / get_model_details().
        self.tflite_interpreter = None
        self.input_details = None
        self.output_details = None
        # Last inputs handed to invoke_tflite().
        self.inputs = []

    def setup_tflite_interpreter(self):
        """Create a fresh interpreter for the model file and allocate tensors.

        Returns:
            The ``(input_details, output_details)`` pair from
            ``get_model_details()``.
        """
        self.tflite_interpreter = tf.lite.Interpreter(
            model_path=self.raw_model_file
        )
        self.tflite_interpreter.allocate_tensors()
        return self.get_model_details()

    def get_model_details(self):
        """Query the interpreter for its input and output details.

        The interpreter must already exist (see
        ``setup_tflite_interpreter()``). The results are cached on the
        instance and also returned.

        Returns:
            The ``(input_details, output_details)`` pair reported by the
            interpreter.
        """
        print("Get tflite input output details")
        self.input_details = self.tflite_interpreter.get_input_details()
        self.output_details = self.tflite_interpreter.get_output_details()
        return self.input_details, self.output_details

    def invoke_tflite(self, inputs):
        """Feed ``inputs`` to the interpreter, run it, and collect the outputs.

        If the interpreter (or its cached details) has not been set up yet it
        is set up on demand, so calling this method first no longer fails
        with an attribute/subscript error on ``None``.

        Args:
            inputs: Sequence of input tensors; the i-th one is written to the
                tensor whose index is ``input_details[i]["index"]``.

        Returns:
            A list with one numpy array per entry of ``output_details``, each
            cast to that entry's ``"dtype"``.
        """
        if self.tflite_interpreter is None:
            self.setup_tflite_interpreter()
        elif self.input_details is None or self.output_details is None:
            self.get_model_details()

        self.inputs = inputs
        print("invoke_tflite")
        for position, tensor in enumerate(self.inputs):
            self.tflite_interpreter.set_tensor(
                self.input_details[position]["index"], tensor
            )
        self.tflite_interpreter.invoke()

        # Post-process for comparison with the MLIR results: copy every
        # output into its own numpy array and give it the dtype recorded in
        # the output details.
        return [
            np.array(
                self.tflite_interpreter.get_tensor(detail["index"])
            ).astype(detail["dtype"])
            for detail in self.output_details
        ]
|
||||
@@ -1,9 +1,11 @@
|
||||
import numpy as np
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.shark_inference import SharkInference
|
||||
import pytest
|
||||
import unittest
|
||||
from shark.parser import shark_args
|
||||
|
||||
|
||||
# model_path = "https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1?lite-format=tflite"
|
||||
# model_path = model_path
|
||||
|
||||
@@ -34,6 +36,21 @@ def generate_inputs(input_details):
|
||||
return args
|
||||
|
||||
|
||||
def compare_results(mlir_results, tflite_results, details):
    """Check that MLIR and TFLite outputs agree in count and shape.

    For each entry of ``details`` the matching pair of results is cast to
    single precision, their shapes are asserted equal, and the maximum
    absolute element-wise difference is printed. The error is only reported,
    not checked against a tolerance.

    Args:
        mlir_results: Sequence of arrays produced by the compiled module.
        tflite_results: Sequence of arrays produced by the TFLite interpreter.
        details: Output details; only its length is used, to decide how many
            result pairs are compared.

    Raises:
        AssertionError: If the number of results or any pair of shapes differ.
    """
    print("Compare mlir_results VS tflite_results: ")
    assert len(mlir_results) == len(
        tflite_results
    ), "Number of results do not match"
    for i in range(len(details)):
        mlir_result = mlir_results[i].astype(np.single)
        tflite_result = tflite_results[i].astype(np.single)
        assert mlir_result.shape == tflite_result.shape, "shape doesnot match"
        max_error = np.max(np.abs(mlir_result - tflite_result))
        # print() does not interpolate %-placeholders, so format explicitly.
        print(f"Max error ({i}): {float(max_error):f}")
|
||||
|
||||
|
||||
class AlbertTfliteModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -51,24 +68,40 @@ class AlbertTfliteModuleTester:
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
my_shark_importer = SharkImporter(
|
||||
model_name="albert_lite_base",
|
||||
# model_path=model_path,
|
||||
model_type="tflite",
|
||||
model_source_hub="tfhub",
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
jit_trace=True,
|
||||
tank_url=None,
|
||||
model_name="albert_lite_base", model_type="tflite"
|
||||
)
|
||||
# Case1: Use default inputs
|
||||
my_shark_importer.compile()
|
||||
shark_results = my_shark_importer.forward()
|
||||
|
||||
mlir_model = my_shark_importer.get_mlir_model()
|
||||
inputs = my_shark_importer.get_inputs()
|
||||
shark_module = SharkInference(
|
||||
mlir_model, inputs, device=self.device, dynamic=self.dynamic
|
||||
)
|
||||
shark_module.set_frontend("tflite-tosa")
|
||||
|
||||
# Case1: Use shark_importer default generate inputs
|
||||
shark_module.compile()
|
||||
mlir_results = shark_module.forward(inputs)
|
||||
## post process results for compare
|
||||
input_details, output_details = my_shark_importer.get_model_details()
|
||||
mlir_results = list(mlir_results)
|
||||
for i in range(len(output_details)):
|
||||
dtype = output_details[i]["dtype"]
|
||||
mlir_results[i] = mlir_results[i].astype(dtype)
|
||||
tflite_results = my_shark_importer.get_raw_model_output()
|
||||
compare_results(mlir_results, tflite_results, output_details)
|
||||
|
||||
# Case2: Use manually set inputs
|
||||
input_details, output_details = my_shark_importer.get_model_details()
|
||||
inputs = generate_inputs(input_details) # device_inputs
|
||||
my_shark_importer.compile(inputs)
|
||||
shark_results = my_shark_importer.forward(inputs)
|
||||
# print(shark_results)
|
||||
shark_module = SharkInference(
|
||||
mlir_model, inputs, device=self.device, dynamic=self.dynamic
|
||||
)
|
||||
shark_module.set_frontend("tflite-tosa")
|
||||
shark_module.compile()
|
||||
mlir_results = shark_module.forward(inputs)
|
||||
tflite_results = my_shark_importer.get_raw_model_output()
|
||||
compare_results(mlir_results, tflite_results, output_details)
|
||||
# print(mlir_results)
|
||||
|
||||
|
||||
class AlbertTfliteModuleTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user