SharkInference: Fix various examples and README.md (#1903)

Following https://github.com/nod-ai/SHARK/pull/708, remove the 'func_name'
parameter from SharkInference.
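For downstream code the migration is mechanical: stop passing `func_name` (or `function_name`, as some examples spelled it) to the constructor and leave every other argument unchanged. A minimal before/after sketch based on the README example below; `torch_mlir` and `inputs` are placeholders for whatever the importer returned:

```
from shark.shark_inference import SharkInference

# torch_mlir: MLIR module produced by SharkImporter.import_mlir (see the
# README example below); inputs: the matching input tuple. Both are
# placeholders here.

# Before (the constructor still took the entry-point name):
#   shark_module = SharkInference(
#       torch_mlir, func_name, device="cpu", mlir_dialect="linalg"
#   )

# After: drop func_name; all other arguments stay the same.
shark_module = SharkInference(torch_mlir, device="cpu", mlir_dialect="linalg")
shark_module.compile()
result = shark_module.forward(inputs)
```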
Huang Qi
2023-10-19 22:28:36 +08:00
committed by GitHub
parent 4797bb89f5
commit 66abee8e5b
12 changed files with 15 additions and 16 deletions

View File

@@ -254,7 +254,6 @@ if you want to instead incorporate this into a python script, you can pass the `
 ```
 shark_module = SharkInference(
     mlir_model,
-    func_name,
     device=args.device,
     mlir_dialect="tm_tensor",
     dispatch_benchmarks="all",
@@ -297,7 +296,7 @@ torch_mlir, func_name = mlir_importer.import_mlir(tracing_required=True)
 # SharkInference accepts mlir in linalg, mhlo, and tosa dialect.
 from shark.shark_inference import SharkInference
-shark_module = SharkInference(torch_mlir, func_name, device="cpu", mlir_dialect="linalg")
+shark_module = SharkInference(torch_mlir, device="cpu", mlir_dialect="linalg")
 shark_module.compile()
 result = shark_module.forward((input))
@@ -320,7 +319,7 @@ mhlo_ir = r"""builtin.module {
 arg0 = np.ones((1, 4)).astype(np.float32)
 arg1 = np.ones((4, 1)).astype(np.float32)
-shark_module = SharkInference(mhlo_ir, func_name="forward", device="cpu", mlir_dialect="mhlo")
+shark_module = SharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo")
 shark_module.compile()
 result = shark_module.forward((arg0, arg1))
 ```

View File

@@ -177,7 +177,7 @@ def compile_through_fx(model, inputs, mlir_loc=None):
     mlir_model = str(module)
     func_name = "forward"
     shark_module = SharkInference(
-        mlir_model, func_name, device=args.device, mlir_dialect="linalg"
+        mlir_model, device=args.device, mlir_dialect="linalg"
     )
     shark_module.compile()

View File

@@ -54,7 +54,7 @@ if __name__ == "__main__":
     minilm_mlir, func_name = mlir_importer.import_mlir(
         is_dynamic=False, tracing_required=False
     )
-    shark_module = SharkInference(minilm_mlir, func_name, mlir_dialect="mhlo")
+    shark_module = SharkInference(minilm_mlir, mlir_dialect="mhlo")
     shark_module.compile()
     output_idx = 0
     data_idx = 1

View File

@@ -6,7 +6,7 @@ mlir_model, func_name, inputs, golden_out = download_model(
 )
 shark_module = SharkInference(
-    mlir_model, func_name, device="cpu", mlir_dialect="tm_tensor"
+    mlir_model, device="cpu", mlir_dialect="tm_tensor"
 )
 shark_module.compile()
 result = shark_module.forward(inputs)

View File

@@ -14,7 +14,7 @@ arg1 = np.ones((4, 1)).astype(np.float32)
 print("Running shark on cpu backend")
 shark_module = SharkInference(
-    mhlo_ir, function_name="forward", device="cpu", mlir_dialect="mhlo"
+    mhlo_ir, device="cpu", mlir_dialect="mhlo"
 )
 # Generate the random inputs and feed into the graph.
@@ -24,14 +24,14 @@ print(shark_module.forward(x))
 print("Running shark on cuda backend")
 shark_module = SharkInference(
-    mhlo_ir, function_name="forward", device="cuda", mlir_dialect="mhlo"
+    mhlo_ir, device="cuda", mlir_dialect="mhlo"
 )
 shark_module.compile()
 print(shark_module.forward(x))
 print("Running shark on vulkan backend")
 shark_module = SharkInference(
-    mhlo_ir, function_name="forward", device="vulkan", mlir_dialect="mhlo"
+    mhlo_ir, device="vulkan", mlir_dialect="mhlo"
 )
 shark_module.compile()
 print(shark_module.forward(x))

View File

@@ -9,7 +9,7 @@ mlir_model, func_name, inputs, golden_out = download_model(
 shark_module = SharkInference(
-    mlir_model, func_name, device="cpu", mlir_dialect="linalg"
+    mlir_model, device="cpu", mlir_dialect="linalg"
 )
 shark_module.compile()
 result = shark_module.forward(inputs)

View File

@@ -33,7 +33,7 @@ mlir_importer = SharkImporter(
 print(golden_out)
-shark_module = SharkInference(vision_mlir, func_name, mlir_dialect="linalg")
+shark_module = SharkInference(vision_mlir, mlir_dialect="linalg")
 shark_module.compile()
 result = shark_module.forward((input,))
 print("Obtained result", result)

View File

@@ -50,7 +50,7 @@ mlir_model = module
 func_name = "forward"
 shark_module = SharkInference(
-    mlir_model, func_name, device="cuda", mlir_dialect="linalg"
+    mlir_model, device="cuda", mlir_dialect="linalg"
 )
 shark_module.compile()

View File

@@ -360,7 +360,7 @@ mlir_importer = SharkImporter(
 )
 shark_module = SharkInference(
-    dlrm_mlir, func_name, device="vulkan", mlir_dialect="linalg"
+    dlrm_mlir, device="vulkan", mlir_dialect="linalg"
 )
 shark_module.compile()
 result = shark_module.forward(input_dlrm)

View File

@@ -294,7 +294,7 @@ def test_dlrm() -> None:
     )
     shark_module = SharkInference(
-        dlrm_mlir, func_name, device="cpu", mlir_dialect="linalg"
+        dlrm_mlir, device="cpu", mlir_dialect="linalg"
     )
     shark_module.compile()
     result = shark_module.forward(inputs)

View File

@@ -33,7 +33,7 @@ mlir_importer = SharkImporter(
     tracing_required=False
 )
-shark_module = SharkInference(vision_mlir, func_name, mlir_dialect="linalg")
+shark_module = SharkInference(vision_mlir, mlir_dialect="linalg")
 shark_module.compile()
 result = shark_module.forward((input,))
 np.testing.assert_allclose(golden_out, result, rtol=1e-02, atol=1e-03)

View File

@@ -7,7 +7,7 @@ mlir_model, func_name, inputs, golden_out = download_model(
 )
 shark_module = SharkInference(
-    mlir_model, func_name, device="vulkan", mlir_dialect="linalg"
+    mlir_model, device="vulkan", mlir_dialect="linalg"
 )
 shark_module.compile()
 result = shark_module.forward(inputs)