Added ability to select GPU (#891)

Co-authored-by: Elias Joseph <elias@nod-labs.com>
Authored by: Eliasj42
Date: 2023-01-30 13:39:12 -08:00
Committed by: GitHub
parent fcd62513cf
commit 8111f8bf35
3 changed files with 23 additions and 5 deletions

View File

@@ -276,9 +276,19 @@ def compile_module_to_flatbuffer(
return flatbuffer_blob
def get_iree_module(flatbuffer_blob, device):
def get_iree_module(flatbuffer_blob, device, device_idx=None):
# Returns the compiled module and the configs.
config = get_iree_runtime_config(device)
if device_idx is not None:
print("registering device id: ", device_idx)
haldriver = ireert.get_driver(device)
haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"]
)
# haldevice = haldriver.create_default_device()
config = ireert.Config(device=haldevice)
else:
config = get_iree_runtime_config(device)
vm_module = ireert.VmModule.from_flatbuffer(
config.vm_instance, flatbuffer_blob
)
@@ -294,20 +304,21 @@ def get_iree_compiled_module(
frontend: str = "torch",
model_config_path: str = None,
extra_args: list = [],
device_idx: int = None,
):
"""Given a module returns the compiled .vmfb and configs"""
flatbuffer_blob = compile_module_to_flatbuffer(
module, device, frontend, model_config_path, extra_args
)
return get_iree_module(flatbuffer_blob, device)
return get_iree_module(flatbuffer_blob, device, device_idx=device_idx)
def load_flatbuffer(flatbuffer_path: str, device: str):
def load_flatbuffer(flatbuffer_path: str, device: str, device_idx: int = None):
with open(os.path.join(flatbuffer_path), "rb") as f:
flatbuffer_blob = f.read()
return get_iree_module(flatbuffer_blob, device)
return get_iree_module(flatbuffer_blob, device, device_idx=device_idx)
def export_iree_module_to_vmfb(

View File

@@ -69,11 +69,13 @@ class SharkInference:
is_benchmark: bool = False,
dispatch_benchmark: str = None,
dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
device_idx: int = None,
):
self.mlir_module = mlir_module
self.device = shark_args.device if device == "none" else device
self.mlir_dialect = mlir_dialect
self.is_benchmark = is_benchmark
self.device_idx = device_idx
self.dispatch_benchmarks = (
shark_args.dispatch_benchmarks
if dispatch_benchmark is None
@@ -120,6 +122,7 @@ class SharkInference:
self.device,
self.mlir_dialect,
extra_args=extra_args,
device_idx=self.device_idx,
)
if self.dispatch_benchmarks is not None:
@@ -205,5 +208,6 @@ class SharkInference:
) = load_flatbuffer(
path,
self.device,
self.device_idx,
)
return

View File

@@ -64,11 +64,13 @@ class SharkRunner:
mlir_dialect: str = "linalg",
extra_args: list = [],
compile_vmfb: bool = True,
device_idx: int = None,
):
self.mlir_module = mlir_module
self.device = shark_args.device if device == "none" else device
self.mlir_dialect = mlir_dialect
self.extra_args = extra_args
self.device_idx = device_idx
if check_device_drivers(self.device):
print(device_driver_info(self.device))
@@ -84,6 +86,7 @@ class SharkRunner:
self.device,
self.mlir_dialect,
extra_args=self.extra_args,
device_idx=self.device_idx,
)
def run(self, function_name, inputs: tuple, send_to_host=False):