Modify the directory structure to remove conflicts.

The directory structure is modified; the aot_module is also disabled
until the shark_trainer is deployed.
This commit is contained in:
Prashant Kumar
2022-04-04 21:19:17 +05:30
committed by Prashant Kumar
parent 59485f571e
commit f9864b4ce1
20 changed files with 16 additions and 17 deletions

View File

@@ -29,7 +29,7 @@ python -m pip install --find-links https://github.com/llvm/torch-mlir/releases t
# Install latest IREE release.
python -m pip install --find-links https://github.com/google/iree/releases iree-compiler iree-runtime
# Install functorch
# Install functorch
python -m pip install ninja
python -m pip install "git+https://github.com/pytorch/functorch.git"
@@ -40,8 +40,7 @@ python -m pip install .
### Run a demo script
```shell
cd shark_runner/examples/
python resnet50_script.py
python -m shark.examples.resnet50_script
```
### Shark Inference API

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -2,7 +2,7 @@ from transformers import AutoModelForMaskedLM, BertConfig
import transformers
import torch
import torch.utils._pytree as pytree
from shark_runner import SharkInference
from shark.shark_runner import SharkInference
pytree._register_pytree_node(
transformers.modeling_outputs.MaskedLMOutput,

View File

@@ -1,6 +1,6 @@
import torch
import torch.nn as nn
from shark_runner import SharkInference, SharkTrainer
from shark.shark_runner import SharkInference, SharkTrainer
class NeuralNet(nn.Module):

View File

@@ -1,6 +1,6 @@
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from shark_runner import SharkInference
from shark.shark_runner import SharkInference
torch.manual_seed(0)
tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")

View File

@@ -4,7 +4,7 @@ import torch
import torchvision.models as models
from torchvision import transforms
import sys
from shark_runner import SharkInference
from shark.shark_runner import SharkInference
################################## Preprocessing inputs and model ############

View File

@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torch_mlir_utils import get_torch_mlir_module
from iree_utils import get_results, get_iree_compiled_module
from functorch_utils import AOTModule
from shark.torch_mlir_utils import get_torch_mlir_module
from shark.iree_utils import get_results, get_iree_compiled_module
# from functorch_utils import AOTModule
class SharkRunner:
@@ -62,13 +62,13 @@ class SharkInference:
self.input = input
self.from_aot = from_aot
if from_aot:
aot_module = AOTModule(
model, input, custom_inference_fn=custom_inference_fn
)
aot_module.generate_inference_graph()
self.model = aot_module.forward_graph
self.input = aot_module.forward_inputs
# if from_aot:
# aot_module = AOTModule(
# model, input, custom_inference_fn=custom_inference_fn
# )
# aot_module.generate_inference_graph()
# self.model = aot_module.forward_graph
# self.input = aot_module.forward_inputs
self.shark_runner = SharkRunner(
self.model, self.input, dynamic, device, jit_trace, from_aot