Shark Runner
The Shark Runner provides inference and training APIs to run deep learning models on Shark Runtime.
How to configure.
Check out the code
git clone https://github.com/NodLabs/dSHARK.git
Set up your Python virtual environment and dependencies
# Set up the venv and install the necessary packages (torch-mlir, nodLabs/Shark, ...).
./setup_venv.sh
# Please activate the venv after installation.
Run a demo script
python -m shark.examples.resnet50_script
Shark Inference API
from shark_runner import SharkInference
shark_module = SharkInference(
module = torch.nn.Module class.
(input,) = inputs to the model (must be torch tensors).
dynamic = (boolean) whether the input shapes are passed as dynamic rather than static.
device = `cpu`, `gpu`, or `vulkan`.
tracing_required = (boolean) JIT-trace the module with the given input; useful when torch.jit.script doesn't work.
)
result = shark_module.forward(inputs)
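The `tracing_required` option refers to the two ways PyTorch can capture a module as TorchScript. The sketch below illustrates that difference using only `torch.jit.script` and `torch.jit.trace`; it does not call Shark. The helper name `capture` is hypothetical and introduced here for illustration, and how `SharkInference` uses the flag internally is an assumption, since this README does not show it.

```python
import importlib.util

# torch is treated as optional so the snippet can be loaded without it.
TORCH_AVAILABLE = importlib.util.find_spec("torch") is not None


def capture(module, example_input, tracing_required=False):
    """Return a TorchScript version of `module`.

    Hypothetical helper, not part of Shark: it only mirrors the idea
    behind the `tracing_required` flag described above.
    """
    import torch

    if tracing_required:
        # Tracing runs the module once and records the operations
        # executed for this specific example input.
        return torch.jit.trace(module, example_input)
    # Scripting compiles the module's forward() source directly,
    # so it needs no example input.
    return torch.jit.script(module)


if TORCH_AVAILABLE:
    import torch

    # A small stand-in model built from stock layers.
    model = torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.ReLU()).eval()
    example = torch.randn(1, 4)

    scripted = capture(model, example)
    traced = capture(model, example, tracing_required=True)

    # Both captured modules should agree with eager execution.
    print(torch.allclose(scripted(example), model(example)))
    print(torch.allclose(traced(example), model(example)))
```

Tracing only records the path taken for the example input, so it is usually the fallback when scripting rejects a model's Python code, which matches the description of `tracing_required` above.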
Shark Trainer API
Work in Progress
Model Tracking (Shark Inference)
| Hugging Face Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---|---|---|---|---|
| BERT | ✔️ (JIT) | | | |
| Albert | ✔️ (JIT) | | | |
| BigBird | ✔️ (AOT) | | | |
| DistilBERT | ✔️ (AOT) | | | |
| GPT2 | ❌ (AOT) | | | |

| TORCHVISION Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---|---|---|---|---|
| AlexNet | ✔️ (Script) | | | |
| DenseNet121 | ✔️ (Script) | | | |
| MNasNet1_0 | ✔️ (Script) | | | |
| MobileNetV2 | ✔️ (Script) | | | |
| MobileNetV3 | ✔️ (Script) | | | |
| Unet | ❌ (Script) | | | |
| Resnet18 | ✔️ (Script) | ✔️ | | |
| Resnet50 | ✔️ (Script) | ✔️ | | |
| Resnext50_32x4d | ✔️ (Script) | | | |
| ShuffleNet_v2 | ❌ (Script) | | | |
| SqueezeNet | ❌ (Script) | | | |
| EfficientNet | ✔️ (Script) | | | |
| Regnet | ✔️ (Script) | | | |
| Resnest | ❌ (Script) | | | |
| Vision Transformer | ✔️ (Script) | | | |
| VGG 16 | ✔️ (Script) | | | |
| Wide Resnet | ✔️ (Script) | ✔️ | | |
| RAFT | ❌ (JIT) | | | |
For more information, refer to the Model Tracking Sheet.