mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Fix deeplab&mobilebert tflite test bug (#170)
This commit is contained in:
@@ -14,21 +14,6 @@ from shark.tflite_utils import TFLitePreprocessor
|
||||
# model_path = "https://tfhub.dev/google/lite-model/aiy/vision/classifier/birds_V1/3?lite-format=tflite"
|
||||
|
||||
|
||||
def generate_inputs(input_details):
|
||||
exe_basename = os.path.basename(sys.argv[0])
|
||||
workdir = os.path.join(os.path.dirname(__file__), "../tmp", exe_basename)
|
||||
os.makedirs(workdir, exist_ok=True)
|
||||
|
||||
img_path = "https://github.com/google-coral/test_data/raw/master/bird.bmp"
|
||||
local_path = "/".join([workdir, "bird.bmp"])
|
||||
urllib.request.urlretrieve(img_path, local_path)
|
||||
|
||||
shape = input_details[0]["shape"]
|
||||
im = np.array(Image.open(local_path).resize((shape[1], shape[2])))
|
||||
args = [im.reshape(shape)]
|
||||
return args
|
||||
|
||||
|
||||
def compare_results(mlir_results, tflite_results, details):
|
||||
print("Compare mlir_results VS tflite_results: ")
|
||||
assert len(mlir_results) == len(tflite_results), "Number of results do not match"
|
||||
@@ -97,24 +82,6 @@ class DeepLabV3TfliteModuleTester:
|
||||
tflite_results = tflite_preprocessor.get_raw_model_output()
|
||||
compare_results(mlir_results, tflite_results, output_details)
|
||||
|
||||
# Case2: Use manually set inputs
|
||||
input_details, output_details = tflite_preprocessor.get_model_details()
|
||||
inputs = generate_inputs(input_details) # new inputs
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=mlir_model,
|
||||
function_name=func_name,
|
||||
device=self.device,
|
||||
mlir_dialect="tflite",
|
||||
)
|
||||
|
||||
shark_module.compile()
|
||||
mlir_results = shark_module.forward(inputs)
|
||||
## post process results for compare
|
||||
tflite_results = tflite_preprocessor.get_raw_model_output()
|
||||
compare_results(mlir_results, tflite_results, output_details)
|
||||
# print(mlir_results)
|
||||
|
||||
|
||||
class DeepLabV3TfliteModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
|
||||
@@ -36,7 +36,6 @@ def compare_results(mlir_results, tflite_results, details):
|
||||
tflite_result = tflite_results[i]
|
||||
mlir_result = mlir_result.astype(np.single)
|
||||
tflite_result = tflite_result.astype(np.single)
|
||||
mlir_result = np.expand_dims(mlir_result, axis=0)
|
||||
print("mlir_result.shape", mlir_result.shape)
|
||||
print("tflite_result.shape", tflite_result.shape)
|
||||
assert mlir_result.shape == tflite_result.shape, "shape doesnot match"
|
||||
|
||||
Reference in New Issue
Block a user