Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git, synced 2026-04-03 03:00:17 -04:00 (commit 2117c1ab39a0aad9b700b7b148d0504e28c32664).
Shark Runner
The Shark Runner provides inference and training APIs to run deep learning models on Shark Runtime.
How to configure
Check out the code
git clone https://github.com/NodLabs/dSHARK.git
Set up your Python virtual environment and dependencies
# Set up a venv and install the necessary packages (torch-mlir, nodLabs/Shark, ...).
./setup_venv.sh
# Please activate the venv after installation.
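The README does not say where `setup_venv.sh` creates the environment, so rather than guess an activation path, here is a small stdlib-only check (a sketch, not part of SHARK) that confirms the interpreter you are about to use is running inside a virtual environment. The helper name `in_virtualenv` is hypothetical.

```python
import sys


def in_virtualenv() -> bool:
    """Return True when the current interpreter runs inside a venv/virtualenv."""
    # The stdlib venv module makes sys.prefix point at the environment directory
    # while sys.base_prefix keeps the base interpreter's prefix.
    # Older virtualenv releases instead set sys.real_prefix.
    return hasattr(sys, "real_prefix") or sys.prefix != sys.base_prefix


if __name__ == "__main__":
    print("venv active:", in_virtualenv())
```

If this prints `venv active: False` after running `./setup_venv.sh`, the venv has most likely not been activated in the current shell yet.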
Run a demo script
python -m shark.examples.resnet50_script --device="cpu"  # or --device="gpu" / --device="vulkan"
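How the demo script reads `--device` is not documented here. A minimal sketch of how such a flag could be parsed with `argparse`, restricted to the three device names this README lists, might look like the following. `parse_args` and `SUPPORTED_DEVICES` are hypothetical names, and the real script may accept different values.

```python
import argparse

# Device names taken from this README; the actual script may accept others.
SUPPORTED_DEVICES = ("cpu", "gpu", "vulkan")


def parse_args(argv=None):
    """Parse a --device flag limited to the supported backends (illustrative only)."""
    parser = argparse.ArgumentParser(description="Run a SHARK demo model (sketch).")
    parser.add_argument(
        "--device",
        default="cpu",
        choices=SUPPORTED_DEVICES,
        help="Backend to compile and run the model on.",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    print(parse_args().device)
```

Because the shell strips the quotes, `--device="cpu"` reaches the script as `--device=cpu`, which `argparse` accepts; an unknown device makes `argparse` print an error and exit.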
Shark Inference API
from shark_runner import SharkInference
shark_module = SharkInference(
    module,                  # a torch.nn.Module instance
    (input_tensor,),         # tuple of model inputs; each must be a torch.Tensor
    dynamic=False,           # True to compile with dynamic input shapes, False for static shapes
    device="cpu",            # "cpu", "gpu" or "vulkan"
    tracing_required=False,  # True to torch.jit.trace the module with the given input (useful when torch.jit.script fails)
)
result = shark_module.forward((input_tensor,))
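The two less obvious parameters, `dynamic` and `tracing_required`, can be illustrated with a stdlib-only sketch. It is not SHARK's implementation: `mlir_tensor_type` and `to_torchscript` are hypothetical helpers. The first shows what a static versus dynamic input shape looks like as an MLIR tensor type (here every dimension becomes `?` when dynamic; SHARK may mark only some dimensions). The second shows the scripting-versus-tracing choice, defaulting to `torch.jit.script` / `torch.jit.trace` but accepting injected functions so the control flow runs without PyTorch installed.

```python
from typing import Any, Callable, Optional, Sequence, Tuple


def mlir_tensor_type(shape: Sequence[int], dtype: str = "f32", dynamic: bool = False) -> str:
    """Render a shape as an MLIR ranked tensor type, e.g. tensor<1x3x224x224xf32>.

    With dynamic=True every dimension is written as '?', meaning its size is
    only known at run time; with dynamic=False the concrete sizes are baked in.
    """
    dims = ["?" if dynamic else str(d) for d in shape]
    return "tensor<" + "x".join(dims + [dtype]) + ">"


def to_torchscript(
    module: Any,
    example_inputs: Tuple[Any, ...],
    tracing_required: bool = False,
    script_fn: Optional[Callable[[Any], Any]] = None,
    trace_fn: Optional[Callable[[Any, Tuple[Any, ...]], Any]] = None,
) -> Any:
    """Hypothetical helper: convert a module by scripting or, if requested, tracing."""
    if script_fn is None or trace_fn is None:
        import torch  # only needed when the real TorchScript functions are used

        script_fn = script_fn or torch.jit.script
        trace_fn = trace_fn or torch.jit.trace
    if tracing_required:
        # Tracing records the ops executed for these example inputs, so any
        # data-dependent control flow is frozen to the path taken here.
        return trace_fn(module, example_inputs)
    # Scripting compiles the module's Python source and keeps its control flow.
    return script_fn(module)
```

For example, `mlir_tensor_type((1, 3, 224, 224), dynamic=True)` gives `tensor<?x?x?x?xf32>`, while the static form keeps `1x3x224x224`.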
Model Tracking (Shark Inference)
| Hugging Face Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---|---|---|---|---|
| BERT | ✔️ (JIT) | ✔️ | ||
| Albert | ✔️ (JIT) | ✔️ | ||
| BigBird | ✔️ (AOT) | |||
| DistilBERT | ✔️ (AOT) | |||
| GPT2 | ❌ (AOT) | | | |
| TORCHVISION Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---|---|---|---|---|
| AlexNet | ✔️ (Script) | |||
| DenseNet121 | ✔️ (Script) | |||
| MNasNet1_0 | ✔️ (Script) | |||
| MobileNetV2 | ✔️ (Script) | |||
| MobileNetV3 | ✔️ (Script) | |||
| Unet | ❌ (Script) | |||
| Resnet18 | ✔️ (Script) | ✔️ | ||
| Resnet50 | ✔️ (Script) | ✔️ | ||
| Resnext50_32x4d | ✔️ (Script) | |||
| ShuffleNet_v2 | ❌ (Script) | |||
| SqueezeNet | ✔️ (Script) | ✔️ | ||
| EfficientNet | ✔️ (Script) | |||
| Regnet | ✔️ (Script) | |||
| Resnest | ❌ (Script) | |||
| Vision Transformer | ✔️ (Script) | |||
| VGG 16 | ✔️ (Script) | |||
| Wide Resnet | ✔️ (Script) | ✔️ | ||
| RAFT | ❌ (JIT) | | | |

For more information, refer to the MODEL TRACKING SHEET.
Shark Trainer API
| Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---|---|---|---|---|
| BERT | ❌ | ❌ | ||
| FullyConnected | ✔️ | ✔️ | | |