mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Merge pull request #7 from pashu123/mulinps
Add support to pass multiple inputs.
This commit is contained in:
@@ -26,20 +26,23 @@ def get_iree_compiled_module(module, device: str):
|
||||
str(module), target_backends=[IREE_DEVICE_MAP[device]]
|
||||
)
|
||||
vm_module = ireert.VmModule.from_flatbuffer(flatbuffer_blob)
|
||||
tracer = ireert.Tracer(os.getcwd())
|
||||
config = ireert.Config(IREE_DEVICE_MAP[device], tracer)
|
||||
config = ireert.Config(IREE_DEVICE_MAP[device])
|
||||
ctx = ireert.SystemContext(config=config)
|
||||
# TODO add optimisation args.
|
||||
ctx.add_vm_module(vm_module)
|
||||
ModuleCompiled = ctx.modules.module["forward"]
|
||||
return ModuleCompiled
|
||||
return ModuleCompiled, config
|
||||
|
||||
|
||||
def get_results(compiled_vm, input):
|
||||
def get_results(compiled_vm, input, config):
|
||||
"""TODO: Documentation"""
|
||||
|
||||
# TODO: Support returning multiple outputs.
|
||||
result = compiled_vm(*input)
|
||||
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
|
||||
result = compiled_vm(*device_inputs)
|
||||
result_numpy = np.asarray(result, dtype=result.dtype)
|
||||
# TODO: Segfault if the copy of numpy array is not returned.
|
||||
result_copy = np.copy(result_numpy)
|
||||
# {k:v.to_host() for k, v in device_outputs.items(
|
||||
return result_copy
|
||||
|
||||
|
||||
@@ -35,12 +35,12 @@ class SharkRunner:
|
||||
self.torch_mlir_module = get_torch_mlir_module(
|
||||
model, input, dynamic, tracing_required, from_aot
|
||||
)
|
||||
self.iree_compilation_module = get_iree_compiled_module(
|
||||
self.iree_compilation_module, self.iree_config = get_iree_compiled_module(
|
||||
self.torch_mlir_module, device
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
return get_results(self.iree_compilation_module, input)
|
||||
return get_results(self.iree_compilation_module, input, self.iree_config)
|
||||
|
||||
|
||||
class SharkInference:
|
||||
|
||||
@@ -116,7 +116,6 @@ def get_torch_mlir_module(
|
||||
)
|
||||
mb.import_module(module._c, class_annotator)
|
||||
|
||||
mb.module.dump()
|
||||
with mb.module.context:
|
||||
pm = PassManager.parse(
|
||||
"torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline"
|
||||
|
||||
Reference in New Issue
Block a user