Refactor the shark_runner shark_inference to only support mlir_modules.

1. The shark_inference is divided into shark_importer and
   shark_inference.
2. All the tank/pytorch tests have been updated.
This commit is contained in:
Prashant Kumar
2022-06-23 20:22:31 +05:30
parent 44dce561e9
commit b07377cbfd
19 changed files with 378 additions and 479 deletions

View File

@@ -75,3 +75,17 @@ def check_device_drivers(device):
return True
return False
# Installation info for the missing device drivers.
def device_driver_info(device):
if device in ["gpu", "cuda"]:
print(
"nvidia-smi not found, please install the required drivers from https://www.nvidia.in/Download/index.aspx?lang=en-in"
)
elif device in ["metal", "vulkan"]:
print(
"vulkaninfo not found, Install from https://vulkan.lunarg.com/sdk/home or your distribution"
)
else:
print(f"{device} is not supported.")

View File

@@ -14,7 +14,6 @@
import iree.runtime as ireert
import iree.compiler as ireec
from shark.iree_utils._common import IREE_DEVICE_MAP, IREE_TARGET_MAP
from shark.model_annotation import *
import numpy as np
import os
@@ -157,23 +156,7 @@ def export_module_to_mlir_file(module, frontend, directory: str):
def get_results(compiled_vm, input, config, frontend="torch"):
"""Runs a .vmfb file given inputs and config and returns output."""
device_inputs = input
if frontend in ["torch", "pytorch"]:
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
if frontend in ["tensorflow", "tf", "tflite", "tflite-tosa"]:
device_inputs = []
for a in input:
if isinstance(a, list):
device_inputs.append(
[
ireert.asdevicearray(
config.device, val, dtype=val.dtype
)
for val in a
]
)
else:
device_inputs.append(ireert.asdevicearray(config.device, a))
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
result = compiled_vm(*device_inputs)
result_tensors = []
if isinstance(result, tuple):