from shark.shark_inference import SharkInference

import numpy as np

mhlo_ir = r"""builtin.module {
  func.func @forward(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<4x4xf32> {
    %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<4x4xf32>
    %1 = "mhlo.abs"(%0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
    return %1 : tensor<4x4xf32>
  }
}"""

# Example inputs matching the IR's @forward signature (kept for reference;
# the runs below use randomly generated inputs instead).
arg0 = np.ones((1, 4)).astype(np.float32)
arg1 = np.ones((4, 1)).astype(np.float32)

print("Running shark on cpu backend")
shark_module = SharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo")

# Generate random inputs and feed them into the graph.
x = shark_module.generate_random_inputs()
shark_module.compile()
print(shark_module.forward(x))

print("Running shark on cuda backend")
shark_module = SharkInference(mhlo_ir, device="cuda", mlir_dialect="mhlo")
shark_module.compile()
print(shark_module.forward(x))

print("Running shark on vulkan backend")
shark_module = SharkInference(mhlo_ir, device="vulkan", mlir_dialect="mhlo")
shark_module.compile()
print(shark_module.forward(x))
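
# A minimal variant sketch, assuming forward() accepts an explicit tuple of
# ndarrays in the same layout that generate_random_inputs() returns; here the
# ones-filled arg0/arg1 defined above are fed to the last-compiled (Vulkan)
# module instead of random inputs. Not verified against every SHARK version.
print(shark_module.forward((arg0, arg1)))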