Refactor shark_runner's shark_inference to support only MLIR modules.

1. The shark_inference functionality is split into shark_importer and
   shark_inference.
2. All the tank/pytorch tests have been updated.
This commit is contained in:
Prashant Kumar
2022-06-23 20:22:31 +05:30
parent 44dce561e9
commit b07377cbfd
19 changed files with 378 additions and 479 deletions

View File

@@ -1,49 +0,0 @@
import tensorflow as tf
import numpy as np
class TFLiteModelUtil:
    """Thin wrapper around ``tf.lite.Interpreter`` for running a .tflite model.

    Holds the interpreter plus its cached input/output detail dicts so a
    caller can set inputs, invoke the model, and collect the results as
    numpy arrays for comparison against MLIR results.
    """

    def __init__(self, raw_model_file):
        # Path to the .tflite flatbuffer on disk; coerced to str because
        # tf.lite.Interpreter expects a string model_path.
        self.raw_model_file = str(raw_model_file)
        self.tflite_interpreter = None
        self.input_details = None
        self.output_details = None
        self.inputs = []

    def setup_tflite_interpreter(self):
        """Create the interpreter, allocate its tensors, and return details.

        Returns:
            Tuple ``(input_details, output_details)`` as produced by
            :meth:`get_model_details`.
        """
        self.tflite_interpreter = tf.lite.Interpreter(
            model_path=self.raw_model_file
        )
        self.tflite_interpreter.allocate_tensors()
        # default input initialization
        return self.get_model_details()

    def get_model_details(self):
        """Fetch and cache the interpreter's input/output tensor details."""
        print("Get tflite input output details")
        self.input_details = self.tflite_interpreter.get_input_details()
        self.output_details = self.tflite_interpreter.get_output_details()
        return self.input_details, self.output_details

    def invoke_tflite(self, inputs):
        """Run one inference and return outputs as a list of numpy arrays.

        Args:
            inputs: Sequence of arrays, one per model input, ordered to
                match ``input_details``.

        Returns:
            List of numpy arrays, one per model output, each cast to the
            dtype reported by ``output_details`` so it compares cleanly
            with the corresponding MLIR result.
        """
        self.inputs = inputs
        print("invoke_tflite")
        # Named `tensor` rather than `input` to avoid shadowing the builtin.
        for i, tensor in enumerate(self.inputs):
            self.tflite_interpreter.set_tensor(
                self.input_details[i]["index"], tensor
            )
        self.tflite_interpreter.invoke()
        # post process tflite_result for compare with mlir_result,
        # for tflite the output is a list of numpy.tensor; collect and cast
        # to the declared output dtype in a single pass (the original used
        # two separate loops over output_details).
        tflite_results = []
        for output_detail in self.output_details:
            out = np.array(
                self.tflite_interpreter.get_tensor(output_detail["index"])
            )
            tflite_results.append(out.astype(output_detail["dtype"]))
        return tflite_results

View File

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_hf_model, compare_tensors
from shark.parser import shark_args
@@ -30,12 +31,16 @@ class MiniLMModuleTester:
)
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
shark_module = SharkInference(
mlir_importer = SharkImporter(
model,
(input,),
device=self.device,
dynamic=self.dynamic,
jit_trace=True,
frontend="torch",
)
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=self.dynamic, tracing_required=True
)
shark_module = SharkInference(
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
results = shark_module.forward((input,))

View File

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_hf_model, compare_tensors
from shark.parser import shark_args
@@ -28,12 +29,16 @@ class AlbertModuleTester:
model, input, act_out = get_hf_model("albert-base-v2")
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
shark_module = SharkInference(
mlir_importer = SharkImporter(
model,
(input,),
device=self.device,
dynamic=self.dynamic,
jit_trace=True,
frontend="torch",
)
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=self.dynamic, tracing_required=True
)
shark_module = SharkInference(
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
results = shark_module.forward((input,))
@@ -55,9 +60,6 @@ class AlbertModuleTest(unittest.TestCase):
self.module_tester.device = "cpu"
self.module_tester.create_and_check_module()
@pytest.mark.xfail(
reason="Language models currently failing for dynamic case"
)
def test_module_dynamic_cpu(self):
self.module_tester.dynamic = True
self.module_tester.device = "cpu"
@@ -79,7 +81,6 @@ class AlbertModuleTest(unittest.TestCase):
self.module_tester.device = "gpu"
self.module_tester.create_and_check_module()
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
@pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
@@ -89,7 +90,6 @@ class AlbertModuleTest(unittest.TestCase):
self.module_tester.device = "vulkan"
self.module_tester.create_and_check_module()
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
@pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",

View File

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
@@ -31,11 +32,16 @@ class AlexnetModuleTester:
)
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
shark_module = SharkInference(
mlir_importer = SharkImporter(
model,
(input,),
device=self.device,
dynamic=self.dynamic,
frontend="torch",
)
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=self.dynamic, tracing_required=False
)
shark_module = SharkInference(
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
results = shark_module.forward((input,))

View File

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_hf_model, compare_tensors
from shark.parser import shark_args
@@ -28,12 +29,16 @@ class BertModuleTester:
model, input, act_out = get_hf_model("bert-base-uncased")
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
shark_module = SharkInference(
mlir_importer = SharkImporter(
model,
(input,),
device=self.device,
dynamic=self.dynamic,
jit_trace=True,
frontend="torch",
)
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=self.dynamic, tracing_required=True
)
shark_module = SharkInference(
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
results = shark_module.forward((input,))

View File

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_hf_model, compare_tensors
from shark.parser import shark_args
@@ -28,12 +29,16 @@ class DistilBertModuleTester:
model, input, act_out = get_hf_model("distilbert-base-uncased")
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
shark_module = SharkInference(
mlir_importer = SharkImporter(
model,
(input,),
device=self.device,
dynamic=self.dynamic,
jit_trace=True,
frontend="torch",
)
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=self.dynamic, tracing_required=True
)
shark_module = SharkInference(
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
results = shark_module.forward((input,))

View File

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
@@ -31,11 +32,16 @@ class Resnet101ModuleTester:
)
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
shark_module = SharkInference(
mlir_importer = SharkImporter(
model,
(input,),
device=self.device,
dynamic=self.dynamic,
frontend="torch",
)
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=self.dynamic, tracing_required=False
)
shark_module = SharkInference(
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
results = shark_module.forward((input,))

View File

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
@@ -31,11 +32,16 @@ class Resnet18ModuleTester:
)
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
shark_module = SharkInference(
mlir_importer = SharkImporter(
model,
(input,),
device=self.device,
dynamic=self.dynamic,
frontend="torch",
)
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=self.dynamic, tracing_required=False
)
shark_module = SharkInference(
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
results = shark_module.forward((input,))

View File

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
@@ -31,11 +32,16 @@ class Resnet50ModuleTester:
)
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
shark_module = SharkInference(
mlir_importer = SharkImporter(
model,
(input,),
device=self.device,
dynamic=self.dynamic,
frontend="torch",
)
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=self.dynamic, tracing_required=False
)
shark_module = SharkInference(
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
results = shark_module.forward((input,))

View File

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
@@ -31,11 +32,16 @@ class SqueezenetModuleTester:
)
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
shark_module = SharkInference(
mlir_importer = SharkImporter(
model,
(input,),
device=self.device,
dynamic=self.dynamic,
frontend="torch",
)
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=self.dynamic, tracing_required=False
)
shark_module = SharkInference(
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
results = shark_module.forward((input,))

View File

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
@@ -31,11 +32,16 @@ class WideResnet50ModuleTester:
)
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
shark_module = SharkInference(
mlir_importer = SharkImporter(
model,
(input,),
device=self.device,
dynamic=self.dynamic,
frontend="torch",
)
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=self.dynamic
)
shark_module = SharkInference(
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
results = shark_module.forward((input,))