Merge pull request #7 from pashu123/mulinps

Add support to pass multiple inputs.
This commit is contained in:
powderluv
2022-03-22 10:10:50 -07:00
committed by GitHub
3 changed files with 10 additions and 8 deletions

View File

@@ -26,20 +26,23 @@ def get_iree_compiled_module(module, device: str):
str(module), target_backends=[IREE_DEVICE_MAP[device]]
)
vm_module = ireert.VmModule.from_flatbuffer(flatbuffer_blob)
tracer = ireert.Tracer(os.getcwd())
config = ireert.Config(IREE_DEVICE_MAP[device], tracer)
config = ireert.Config(IREE_DEVICE_MAP[device])
ctx = ireert.SystemContext(config=config)
# TODO add optimisation args.
ctx.add_vm_module(vm_module)
ModuleCompiled = ctx.modules.module["forward"]
return ModuleCompiled
return ModuleCompiled, config
def get_results(compiled_vm, input):
def get_results(compiled_vm, input, config):
"""TODO: Documentation"""
# TODO: Support returning multiple outputs.
result = compiled_vm(*input)
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
result = compiled_vm(*device_inputs)
result_numpy = np.asarray(result, dtype=result.dtype)
# TODO: Segfault if the copy of numpy array is not returned.
result_copy = np.copy(result_numpy)
# {k:v.to_host() for k, v in device_outputs.items(
return result_copy

View File

@@ -35,12 +35,12 @@ class SharkRunner:
self.torch_mlir_module = get_torch_mlir_module(
model, input, dynamic, tracing_required, from_aot
)
self.iree_compilation_module = get_iree_compiled_module(
self.iree_compilation_module, self.iree_config = get_iree_compiled_module(
self.torch_mlir_module, device
)
def forward(self, input):
return get_results(self.iree_compilation_module, input)
return get_results(self.iree_compilation_module, input, self.iree_config)
class SharkInference:

View File

@@ -116,7 +116,6 @@ def get_torch_mlir_module(
)
mb.import_module(module._c, class_annotator)
mb.module.dump()
with mb.module.context:
pm = PassManager.parse(
"torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline"