mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Modify the directory structure to remove conflicts.
The directory structure is modified, also disabled the aot_module unless the shark_trainer is deployed.
This commit is contained in:
committed by
Prashant Kumar
parent
59485f571e
commit
f9864b4ce1
@@ -29,7 +29,7 @@ python -m pip install --find-links https://github.com/llvm/torch-mlir/releases t
|
||||
# Install latest IREE release.
|
||||
python -m pip install --find-links https://github.com/google/iree/releases iree-compiler iree-runtime
|
||||
|
||||
# Install functorch
|
||||
# Install functorch
|
||||
python -m pip install ninja
|
||||
python -m pip install "git+https://github.com/pytorch/functorch.git"
|
||||
|
||||
@@ -40,8 +40,7 @@ python -m pip install .
|
||||
|
||||
### Run a demo script
|
||||
```shell
|
||||
cd shark_runner/examples/
|
||||
python resnet50_script.py
|
||||
python -m shark.examples.resnet50_script
|
||||
```
|
||||
|
||||
### Shark Inference API
|
||||
BIN
shark/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
shark/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
shark/__pycache__/iree_utils.cpython-39.pyc
Normal file
BIN
shark/__pycache__/iree_utils.cpython-39.pyc
Normal file
Binary file not shown.
BIN
shark/__pycache__/shark_runner.cpython-39.pyc
Normal file
BIN
shark/__pycache__/shark_runner.cpython-39.pyc
Normal file
Binary file not shown.
BIN
shark/__pycache__/torch_mlir_utils.cpython-39.pyc
Normal file
BIN
shark/__pycache__/torch_mlir_utils.cpython-39.pyc
Normal file
Binary file not shown.
BIN
shark/examples/__pycache__/resnet50_script.cpython-39.pyc
Normal file
BIN
shark/examples/__pycache__/resnet50_script.cpython-39.pyc
Normal file
Binary file not shown.
@@ -2,7 +2,7 @@ from transformers import AutoModelForMaskedLM, BertConfig
|
||||
import transformers
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from shark_runner import SharkInference
|
||||
from shark.shark_runner import SharkInference
|
||||
|
||||
pytree._register_pytree_node(
|
||||
transformers.modeling_outputs.MaskedLMOutput,
|
||||
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from shark_runner import SharkInference, SharkTrainer
|
||||
from shark.shark_runner import SharkInference, SharkTrainer
|
||||
|
||||
|
||||
class NeuralNet(nn.Module):
|
||||
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from shark_runner import SharkInference
|
||||
from shark.shark_runner import SharkInference
|
||||
|
||||
torch.manual_seed(0)
|
||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
|
||||
@@ -4,7 +4,7 @@ import torch
|
||||
import torchvision.models as models
|
||||
from torchvision import transforms
|
||||
import sys
|
||||
from shark_runner import SharkInference
|
||||
from shark.shark_runner import SharkInference
|
||||
|
||||
|
||||
################################## Preprocessing inputs and model ############
|
||||
@@ -12,9 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from torch_mlir_utils import get_torch_mlir_module
|
||||
from iree_utils import get_results, get_iree_compiled_module
|
||||
from functorch_utils import AOTModule
|
||||
from shark.torch_mlir_utils import get_torch_mlir_module
|
||||
from shark.iree_utils import get_results, get_iree_compiled_module
|
||||
# from functorch_utils import AOTModule
|
||||
|
||||
|
||||
class SharkRunner:
|
||||
@@ -62,13 +62,13 @@ class SharkInference:
|
||||
self.input = input
|
||||
self.from_aot = from_aot
|
||||
|
||||
if from_aot:
|
||||
aot_module = AOTModule(
|
||||
model, input, custom_inference_fn=custom_inference_fn
|
||||
)
|
||||
aot_module.generate_inference_graph()
|
||||
self.model = aot_module.forward_graph
|
||||
self.input = aot_module.forward_inputs
|
||||
# if from_aot:
|
||||
# aot_module = AOTModule(
|
||||
# model, input, custom_inference_fn=custom_inference_fn
|
||||
# )
|
||||
# aot_module.generate_inference_graph()
|
||||
# self.model = aot_module.forward_graph
|
||||
# self.input = aot_module.forward_inputs
|
||||
|
||||
self.shark_runner = SharkRunner(
|
||||
self.model, self.input, dynamic, device, jit_trace, from_aot
|
||||
Reference in New Issue
Block a user