Rename and refactoring.

This commit is contained in:
Prashant Kumar
2022-05-27 09:35:48 +00:00
parent 8294dc3f20
commit c69baa3b1e
13 changed files with 28 additions and 2 deletions

View File

@@ -24,7 +24,7 @@ git clone https://github.com/nod-ai/SHARK.git
### Run a demo script
```shell
python -m shark.examples.resnet50_script --device="cpu" # Use gpu | vulkan
python -m shark.examples.shark_inference.resnet50_script --device="cpu" # Use gpu | vulkan
```
@@ -43,16 +43,42 @@ pytest --workers auto
from shark_runner import SharkInference
shark_module = SharkInference(
module = torch.nn.module class.
module = model class.
(input,) = inputs to model (must be a torch-tensor)
dynamic (boolean) = Pass the input shapes as static or dynamic.
device = `cpu`, `gpu` or `vulkan` is supported.
tracing_required = (boolean) = Jit trace the module with the given input, useful in the case where jit.script doesn't work. )
shark_module.set_frontend("pytorch") # Use tensorflow, mhlo, linalg, tosa
shark_module.compile()
result = shark_module.forward(inputs)
```
### Example demonstrating running MHLO IR.
```
from shark.shark_inference import SharkInference
import numpy as np
mhlo_ir = r"""builtin.module {
func.func @forward(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<4x4xf32> {
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<4x4xf32>
%1 = "mhlo.abs"(%0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
return %1 : tensor<4x4xf32>
}
}"""
arg0 = np.ones((1, 4)).astype(np.float32)
arg1 = np.ones((4, 1)).astype(np.float32)
shark_module = SharkInference(mhlo_ir, (arg0, arg1))
shark_module.set_frontend("mhlo")
shark_module.compile()
print(shark_module.forward((arg0, arg1)))
```
### Model Tracking (Shark Inference)