AMD-SHARK-Studio/shark/examples/shark_inference/mhlo_example.py

from shark.shark_inference import SharkInference
import numpy as np
mhlo_ir = r"""builtin.module {
func.func @forward(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<4x4xf32> {
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<4x4xf32>
%1 = "mhlo.abs"(%0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
return %1 : tensor<4x4xf32>
}
}"""

# Sample inputs: all-ones arrays whose shapes broadcast to (4, 4).
arg0 = np.ones((1, 4)).astype(np.float32)
arg1 = np.ones((4, 1)).astype(np.float32)
print("Running shark on cpu backend")
shark_module = SharkInference(
mhlo_ir, function_name="forward", device="cpu", mlir_dialect="mhlo"
)
shark_module.compile()
print(shark_module.forward((arg0, arg1)))
print("Running shark on cuda backend")
shark_module = SharkInference(
mhlo_ir, function_name="forward", device="cuda", mlir_dialect="mhlo"
)
shark_module.compile()
print(shark_module.forward((arg0, arg1)))
print("Running shark on vulkan backend")
shark_module = SharkInference(
mhlo_ir, function_name="forward", device="vulkan", mlir_dialect="mhlo"
)
shark_module.compile()
print(shark_module.forward((arg0, arg1)))
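
# A minimal sanity check, not part of the original example: the module
# computes abs(arg0 + arg1) with NumPy-style broadcasting, so every backend
# above should print a 4x4 array filled with 2.0. The reference result here
# is computed directly in NumPy, independently of SHARK.
expected = np.abs(arg0 + arg1)
assert expected.shape == (4, 4)
assert np.all(expected == 2.0)
print("Expected result:\n", expected)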