Fix deeplab&mobilebert tflite test bug (#170)

This commit is contained in:
Chi_Liu
2022-06-30 21:42:14 -07:00
committed by GitHub
parent cc11a71ec8
commit 41a8cbb5b6
2 changed files with 0 additions and 34 deletions

View File

@@ -14,21 +14,6 @@ from shark.tflite_utils import TFLitePreprocessor
# model_path = "https://tfhub.dev/google/lite-model/aiy/vision/classifier/birds_V1/3?lite-format=tflite"
def generate_inputs(input_details):
exe_basename = os.path.basename(sys.argv[0])
workdir = os.path.join(os.path.dirname(__file__), "../tmp", exe_basename)
os.makedirs(workdir, exist_ok=True)
img_path = "https://github.com/google-coral/test_data/raw/master/bird.bmp"
local_path = "/".join([workdir, "bird.bmp"])
urllib.request.urlretrieve(img_path, local_path)
shape = input_details[0]["shape"]
im = np.array(Image.open(local_path).resize((shape[1], shape[2])))
args = [im.reshape(shape)]
return args
def compare_results(mlir_results, tflite_results, details):
print("Compare mlir_results VS tflite_results: ")
assert len(mlir_results) == len(tflite_results), "Number of results do not match"
@@ -97,24 +82,6 @@ class DeepLabV3TfliteModuleTester:
tflite_results = tflite_preprocessor.get_raw_model_output()
compare_results(mlir_results, tflite_results, output_details)
# Case2: Use manually set inputs
input_details, output_details = tflite_preprocessor.get_model_details()
inputs = generate_inputs(input_details) # new inputs
shark_module = SharkInference(
mlir_module=mlir_model,
function_name=func_name,
device=self.device,
mlir_dialect="tflite",
)
shark_module.compile()
mlir_results = shark_module.forward(inputs)
## post process results for compare
tflite_results = tflite_preprocessor.get_raw_model_output()
compare_results(mlir_results, tflite_results, output_details)
# print(mlir_results)
class DeepLabV3TfliteModuleTest(unittest.TestCase):
@pytest.fixture(autouse=True)

View File

@@ -36,7 +36,6 @@ def compare_results(mlir_results, tflite_results, details):
tflite_result = tflite_results[i]
mlir_result = mlir_result.astype(np.single)
tflite_result = tflite_result.astype(np.single)
mlir_result = np.expand_dims(mlir_result, axis=0)
print("mlir_result.shape", mlir_result.shape)
print("tflite_result.shape", tflite_result.shape)
assert mlir_result.shape == tflite_result.shape, "shape doesnot match"