mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Refactor shark_runner's shark_inference to support only MLIR modules.
1. shark_inference has been split into shark_importer and shark_inference. 2. All tank/pytorch tests have been updated accordingly.
This commit is contained in:
@@ -1,49 +0,0 @@
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TFLiteModelUtil:
    """Thin wrapper around ``tf.lite.Interpreter`` for running a .tflite model.

    Typical usage:
        util = TFLiteModelUtil(path)
        input_details, output_details = util.setup_tflite_interpreter()
        results = util.invoke_tflite(inputs)
    """

    def __init__(self, raw_model_file):
        """Remember the model path; the interpreter itself is created lazily
        by :meth:`setup_tflite_interpreter`.

        Args:
            raw_model_file: path to the .tflite file (coerced to ``str``).
        """
        self.raw_model_file = str(raw_model_file)
        # Populated by setup_tflite_interpreter() / get_model_details().
        self.tflite_interpreter = None
        self.input_details = None
        self.output_details = None
        self.inputs = []

    def setup_tflite_interpreter(self):
        """Create the TFLite interpreter, allocate its tensors, and return
        the (input_details, output_details) pair."""
        self.tflite_interpreter = tf.lite.Interpreter(
            model_path=self.raw_model_file
        )
        self.tflite_interpreter.allocate_tensors()
        # default input initialization
        return self.get_model_details()

    def get_model_details(self):
        """Fetch and cache the interpreter's input/output tensor details.

        Returns:
            Tuple of (input_details, output_details) as reported by TFLite.
        """
        print("Get tflite input output details")
        self.input_details = self.tflite_interpreter.get_input_details()
        self.output_details = self.tflite_interpreter.get_output_details()
        return self.input_details, self.output_details

    def invoke_tflite(self, inputs):
        """Run one inference pass.

        Args:
            inputs: sequence of numpy arrays, one per model input, in the
                order of ``self.input_details``.

        Returns:
            List of numpy arrays, one per model output, each cast to the
            dtype declared in ``self.output_details`` so the result can be
            compared directly against an MLIR-path result.
        """
        self.inputs = inputs
        print("invoke_tflite")
        # `tensor` rather than `input` — avoid shadowing the builtin.
        for detail, tensor in zip(self.input_details, self.inputs):
            self.tflite_interpreter.set_tensor(detail["index"], tensor)
        self.tflite_interpreter.invoke()

        # Post-process for comparison with the MLIR result: TFLite returns
        # one tensor per output; copy it out and cast to the declared dtype
        # in a single pass (the original looped twice).
        tflite_results = [
            np.array(
                self.tflite_interpreter.get_tensor(detail["index"])
            ).astype(detail["dtype"])
            for detail in self.output_details
        ]
        return tflite_results
|
||||
@@ -1,4 +1,5 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from tank.model_utils import get_hf_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
@@ -30,12 +31,16 @@ class MiniLMModuleTester:
|
||||
)
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
jit_trace=True,
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=self.dynamic, tracing_required=True
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from tank.model_utils import get_hf_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
@@ -28,12 +29,16 @@ class AlbertModuleTester:
|
||||
model, input, act_out = get_hf_model("albert-base-v2")
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
jit_trace=True,
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=self.dynamic, tracing_required=True
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
@@ -55,9 +60,6 @@ class AlbertModuleTest(unittest.TestCase):
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
@@ -79,7 +81,6 @@ class AlbertModuleTest(unittest.TestCase):
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
@@ -89,7 +90,6 @@ class AlbertModuleTest(unittest.TestCase):
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
@@ -31,11 +32,16 @@ class AlexnetModuleTester:
|
||||
)
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=self.dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from tank.model_utils import get_hf_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
@@ -28,12 +29,16 @@ class BertModuleTester:
|
||||
model, input, act_out = get_hf_model("bert-base-uncased")
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
jit_trace=True,
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=self.dynamic, tracing_required=True
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from tank.model_utils import get_hf_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
@@ -28,12 +29,16 @@ class DistilBertModuleTester:
|
||||
model, input, act_out = get_hf_model("distilbert-base-uncased")
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
jit_trace=True,
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=self.dynamic, tracing_required=True
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
@@ -31,11 +32,16 @@ class Resnet101ModuleTester:
|
||||
)
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=self.dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
@@ -31,11 +32,16 @@ class Resnet18ModuleTester:
|
||||
)
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=self.dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
@@ -31,11 +32,16 @@ class Resnet50ModuleTester:
|
||||
)
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=self.dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
@@ -31,11 +32,16 @@ class SqueezenetModuleTester:
|
||||
)
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=self.dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
@@ -31,11 +32,16 @@ class WideResnet50ModuleTester:
|
||||
)
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=self.dynamic
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
|
||||
Reference in New Issue
Block a user