# Shark Runner
The Shark Runner provides inference and training APIs to run deep learning models on Shark Runtime.
## How to configure

### Check out the code

```shell
git clone https://github.com/NodLabs/dSHARK.git
cd dSHARK
```
### Set up your Python virtual environment and dependencies

```shell
python -m venv shark_venv
source shark_venv/bin/activate
# Some older pip versions cannot handle the recent PyTorch dependencies, so upgrade pip first.
python -m pip install --upgrade pip
# Install the latest PyTorch nightlies and the build requirements.
python -m pip install -r requirements.txt
```
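Before installing anything else, you may want to confirm that the interpreter you are using really is the one inside `shark_venv`. A minimal check (our own sketch, not part of SHARK) relies on the fact that a `venv`-created environment sets `sys.prefix` to the environment directory while `sys.base_prefix` keeps pointing at the base interpreter:

```python
import sys


def in_virtualenv() -> bool:
    """Return True when running inside an environment created with `python -m venv`.

    Inside a venv, sys.prefix points at the environment directory while
    sys.base_prefix still points at the base Python installation.
    """
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    if in_virtualenv():
        print(f"OK: using virtual environment at {sys.prefix}")
    else:
        print("Warning: not inside a virtual environment; run `source shark_venv/bin/activate` first.")
```

Run it with `python` after `source shark_venv/bin/activate`; outside the environment it prints the warning instead.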
### Install the remaining dependencies

```shell
# Install the latest torch-mlir release.
python -m pip install --find-links https://github.com/llvm/torch-mlir/releases torch-mlir

# Install the latest IREE compiler and runtime from the SHARK releases...
python -m pip install --find-links https://github.com/NodLabs/SHARK/releases iree-compiler iree-runtime
# ...or, alternatively, from the latest upstream IREE release:
# python -m pip install --find-links https://github.com/google/iree/releases iree-compiler iree-runtime

# Install ninja, then build and install functorch from source.
python -m pip install ninja
python -m pip install "git+https://github.com/pytorch/functorch.git"

# Install shark_runner from the current path.
cd shark
python -m pip install .
```
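To check that the pip commands above actually left the expected packages in the environment, a small script can query the installed distribution metadata. This is a sketch of our own using the standard-library `importlib.metadata`; the list `REQUIRED` only contains the distribution names used in the commands above (plus `torch` from the PyTorch nightlies) and deliberately leaves out the local `shark_runner` package, whose distribution name this README does not state.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from the pip commands in this README.
REQUIRED = ("torch", "torch-mlir", "iree-compiler", "iree-runtime", "ninja", "functorch")


def report_versions(names=REQUIRED):
    """Map each distribution name to its installed version, or None if it is missing."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in report_versions().items():
        print(f"{name:15} {ver or 'NOT INSTALLED'}")
```

Any line reporting `NOT INSTALLED` points at the install step to repeat.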
### Run a demo script

```shell
python -m shark.examples.resnet50_script
```
### Shark Inference API

```python
import torch
from shark_runner import SharkInference

model = torch.nn.Linear(4, 2)    # module: any torch.nn.Module instance
inputs = (torch.randn(1, 4),)    # (input,): inputs to the model, must be torch tensors

shark_module = SharkInference(
    model,
    inputs,
    dynamic=False,           # pass the input shapes as static (False) or dynamic (True)
    device="cpu",            # "cpu", "gpu" and "vulkan" are supported
    tracing_required=False,  # JIT-trace the module with the given input; useful when torch.jit.script doesn't work
)
result = shark_module.forward(inputs)
```
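The argument rules listed for `SharkInference` (inputs passed as a tuple, a boolean `dynamic` flag, one of three device strings, a boolean `tracing_required`) can be checked before any compilation starts. The helper `shark_inference_kwargs` below is hypothetical and not part of `shark_runner`; it only encodes what this README states, does not check tensor types (so it runs without torch), and assumes the keyword names match the parameter names listed above.

```python
SUPPORTED_DEVICES = ("cpu", "gpu", "vulkan")  # devices listed in this README


def shark_inference_kwargs(inputs, dynamic=False, device="cpu", tracing_required=False):
    """Hypothetical pre-flight check for SharkInference arguments (not part of shark_runner).

    Returns the keyword arguments if they match the documented rules,
    otherwise raises TypeError or ValueError.
    """
    if not isinstance(inputs, tuple) or not inputs:
        raise TypeError("inputs must be a non-empty tuple, e.g. (input,)")
    if device not in SUPPORTED_DEVICES:
        raise ValueError(f"device must be one of {SUPPORTED_DEVICES}, got {device!r}")
    for name, flag in (("dynamic", dynamic), ("tracing_required", tracing_required)):
        if not isinstance(flag, bool):
            raise TypeError(f"{name} must be a bool, got {type(flag).__name__}")
    return {"dynamic": dynamic, "device": device, "tracing_required": tracing_required}
```

Under the same keyword-name assumption, you could then write `SharkInference(model, inputs, **shark_inference_kwargs(inputs, device="vulkan"))` so that a typo such as `device="vulcan"` fails immediately instead of during compilation.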
### Shark Trainer API

Work in progress.

### Model Tracking (Shark Inference)

| Hugging Face Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-VULKAN |
|---|---|---|---|---|
| BERT | ✔️ (JIT) | | | |
| Albert | ✔️ (JIT) | | | |
| BigBird | ✔️ (AOT) | | | |
| DistilBERT | ✔️ (AOT) | | | |
| GPT2 | ❌ (AOT) | | | |

| TORCHVISION Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-VULKAN |
|---|---|---|---|---|
| AlexNet | ✔️ (Script) | | | |
| DenseNet121 | ✔️ (Script) | | | |
| MNasNet1_0 | ✔️ (Script) | | | |
| MobileNetV2 | ✔️ (Script) | | | |
| MobileNetV3 | ✔️ (Script) | | | |
| Unet | ❌ (Script) | | | |
| Resnet18 | ✔️ (Script) | ✔️ | | |
| Resnet50 | ✔️ (Script) | ✔️ | | |
| Resnext50_32x4d | ✔️ (Script) | | | |
| ShuffleNet_v2 | ❌ (Script) | | | |
| SqueezeNet | ❌ (Script) | | | |
| EfficientNet | ✔️ (Script) | | | |
| Regnet | ✔️ (Script) | | | |
| Resnest | ❌ (Script) | | | |
| Vision Transformer | ✔️ (Script) | | | |
| VGG 16 | ✔️ (Script) | | | |
| Wide Resnet | ✔️ (Script) | ✔️ | | |
| RAFT | ❌ (JIT) | | | |

For more information, refer to the MODEL TRACKING SHEET.
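The tracking tables above are plain Markdown, so they can also be queried programmatically, for example to list every model marked ✔️ for a given backend. The sketch below is our own (the names `parse_tracking_table` and `supported_on` are hypothetical, not part of SHARK); it tolerates rows with fewer cells than the header, as in some rows of these tables, and the sample rows are copied from the TORCHVISION table.

```python
def parse_tracking_table(markdown: str) -> dict:
    """Parse a Markdown model-tracking table into {model: {column: cell}}.

    Skips the |---| separator row and pads short rows with empty cells.
    """
    lines = [ln.strip() for ln in markdown.strip().splitlines() if ln.strip()]
    header = [c.strip() for c in lines[0].strip("|").split("|")]
    table = {}
    for line in lines[1:]:
        cells = [c.strip() for c in line.strip("|").split("|")]
        if all(set(c) <= set("-: ") for c in cells):
            continue  # separator row such as |---|---|
        cells += [""] * (len(header) - len(cells))
        table[cells[0]] = dict(zip(header[1:], cells[1:]))
    return table


def supported_on(table: dict, column: str) -> list:
    """Return the models whose cell in `column` starts with a check mark."""
    return [model for model, cols in table.items() if cols.get(column, "").startswith("✔")]


SAMPLE = """
| TORCHVISION Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-VULKAN |
|---|---|---|---|---|
| Resnet18 | ✔️ (Script) | ✔️ | ||
| Unet | ❌ (Script) | |||
| RAFT | ❌ (JIT) |
"""

table = parse_tracking_table(SAMPLE)
print(supported_on(table, "SHARK-CPU"))  # ['Resnet18']
```

Splitting on `|` is a simplification that works for these tables because no cell contains a literal pipe; a general Markdown parser would be needed otherwise.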