mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-25 03:00:12 -04:00
Refactor shark_runner's shark_inference to support only MLIR modules.
1. shark_inference is split into shark_importer and shark_inference. 2. All tank/pytorch tests have been updated.
This commit is contained in:
@@ -75,3 +75,17 @@ def check_device_drivers(device):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# Installation info for the missing device drivers.
|
||||
def device_driver_info(device):
    """Print installation guidance for the drivers backing *device*.

    Known device groups get a pointer to the vendor's driver/SDK download
    page; any unrecognised device name is reported as unsupported. The
    function only prints — it returns nothing.
    """
    # Map each backend family to the hint shown when its probe tool is missing.
    driver_hints = {
        ("gpu", "cuda"): "nvidia-smi not found, please install the required drivers from https://www.nvidia.in/Download/index.aspx?lang=en-in",
        ("metal", "vulkan"): "vulkaninfo not found, Install from https://vulkan.lunarg.com/sdk/home or your distribution",
    }
    for group, hint in driver_hints.items():
        if device in group:
            print(hint)
            return
    # Fall-through: device name matched no known backend group.
    print(f"{device} is not supported.")
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
import iree.runtime as ireert
|
||||
import iree.compiler as ireec
|
||||
from shark.iree_utils._common import IREE_DEVICE_MAP, IREE_TARGET_MAP
|
||||
from shark.model_annotation import *
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
@@ -157,23 +156,7 @@ def export_module_to_mlir_file(module, frontend, directory: str):
|
||||
|
||||
def get_results(compiled_vm, input, config, frontend="torch"):
|
||||
"""Runs a .vmfb file given inputs and config and returns output."""
|
||||
device_inputs = input
|
||||
if frontend in ["torch", "pytorch"]:
|
||||
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
|
||||
if frontend in ["tensorflow", "tf", "tflite", "tflite-tosa"]:
|
||||
device_inputs = []
|
||||
for a in input:
|
||||
if isinstance(a, list):
|
||||
device_inputs.append(
|
||||
[
|
||||
ireert.asdevicearray(
|
||||
config.device, val, dtype=val.dtype
|
||||
)
|
||||
for val in a
|
||||
]
|
||||
)
|
||||
else:
|
||||
device_inputs.append(ireert.asdevicearray(config.device, a))
|
||||
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
|
||||
result = compiled_vm(*device_inputs)
|
||||
result_tensors = []
|
||||
if isinstance(result, tuple):
|
||||
|
||||
Reference in New Issue
Block a user