mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Update the README.md
Update the shark_importer and shark_inference APIs.
This commit is contained in:
31
README.md
31
README.md
@@ -139,18 +139,25 @@ pytest benchmarks
|
||||
### Shark Inference API
|
||||
|
||||
```
|
||||
from shark_runner import SharkInference
|
||||
|
||||
shark_module = SharkInference(
|
||||
module = model class.
|
||||
(input,) = inputs to model (must be a torch-tensor)
|
||||
dynamic (boolean) = Pass the input shapes as static or dynamic.
|
||||
device = `cpu`, `gpu` or `vulkan` is supported.
|
||||
tracing_required = (boolean) = Jit trace the module with the given input, useful in the case where jit.script doesn't work. )
|
||||
shark_module.set_frontend("pytorch") # Use tensorflow, mhlo, linalg, tosa
|
||||
from shark.shark_importer import SharkImporter
|
||||
|
||||
# SharkImporter imports mlir file from the torch, tensorflow or tf-lite module.
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
torch_module,
|
||||
(input),
|
||||
frontend="torch", #tf, #tf-lite
|
||||
)
|
||||
torch_mlir, func_name = mlir_importer.import_mlir(tracing_required=True)
|
||||
|
||||
# SharkInference accepts mlir in linalg, mhlo, and tosa dialect.
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
shark_module = SharkInference(torch_mlir, func_name, device="cpu", mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((input))
|
||||
|
||||
result = shark_module.forward(inputs)
|
||||
```
|
||||
|
||||
|
||||
@@ -170,11 +177,9 @@ mhlo_ir = r"""builtin.module {
|
||||
|
||||
arg0 = np.ones((1, 4)).astype(np.float32)
|
||||
arg1 = np.ones((4, 1)).astype(np.float32)
|
||||
|
||||
shark_module = SharkInference(mhlo_ir, (arg0, arg1))
|
||||
shark_module.set_frontend("mhlo")
|
||||
shark_module = SharkInference(mhlo_ir, func_name="forward", device="cpu", mlir_dialect="mhlo")
|
||||
shark_module.compile()
|
||||
print(shark_module.forward((arg0, arg1)))
|
||||
result = shark_module.forward((arg0, arg1))
|
||||
```
|
||||
</details>
|
||||
|
||||
|
||||
Reference in New Issue
Block a user