Shark Runner
The Shark Runner provides inference and training APIs to run deep learning models on Shark Runtime.
How to configure
Check out the code
git clone https://github.com/NodLabs/dSHARK.git
Set up your Python virtual environment and dependencies
# Set up the venv and install the necessary packages (torch-mlir, nodLabs/Shark, ...).
./setup_venv.sh
# Please activate the venv after installation.
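Because the demo relies on packages installed into the venv, it can help to confirm that the environment is actually active before running anything. The following is a minimal sketch using only the standard library; the helper name `in_virtualenv` is illustrative and is not part of Shark.

```python
import sys


def in_virtualenv() -> bool:
    """Return True when the interpreter is running inside a venv."""
    # Inside a venv, sys.prefix points at the environment directory,
    # while sys.base_prefix still points at the base Python installation.
    return sys.prefix != getattr(sys, "base_prefix", sys.prefix)


if in_virtualenv():
    print(f"venv active: {sys.prefix}")
else:
    print("No venv active; activate the one created by ./setup_venv.sh first.")
```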
Run a demo script
python -m shark.examples.resnet50_script --device="cpu" # Use gpu | vulkan
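The `--device` flag selects the backend. How the demo script parses the flag is not shown in this README, so the following is only a hypothetical sketch of an equivalent command-line interface built with `argparse`. The three device names come from this README; `build_parser` and `SUPPORTED_DEVICES` are illustrative names, and the `cpu` default is an assumption.

```python
import argparse

# Device names listed in this README.
SUPPORTED_DEVICES = ("cpu", "gpu", "vulkan")


def build_parser() -> argparse.ArgumentParser:
    """Build a parser accepting --device=cpu|gpu|vulkan (hypothetical)."""
    parser = argparse.ArgumentParser(description="Run a Shark demo (sketch).")
    parser.add_argument(
        "--device",
        choices=SUPPORTED_DEVICES,
        default="cpu",  # assumed default, not confirmed by the README
        help="backend to run on",
    )
    return parser


# Parse an explicit argument list instead of sys.argv to keep the sketch
# deterministic; a real script would call parse_args() with no arguments.
args = build_parser().parse_args(["--device", "vulkan"])
print(f"Selected device: {args.device}")
```

Restricting the flag with `choices` makes an unsupported device name fail at argument parsing time rather than later during compilation.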
Shark Inference API
from shark_runner import SharkInference
shark_module = SharkInference(
    module,            # the model: a torch.nn.Module
    (input,),          # inputs to the model (must be torch tensors)
    dynamic,           # boolean: pass the input shapes as static or dynamic
    device,            # "cpu", "gpu", or "vulkan"
    tracing_required,  # boolean: JIT-trace the module with the given input; useful when torch.jit.script does not work
)
result = shark_module.forward(inputs)
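The construct-then-forward flow above can be wrapped in a small helper. This is a sketch under stated assumptions, not the library's own API: `run_inference` and its `inference_cls` parameter are hypothetical, the defaults for `dynamic` and `tracing_required` are chosen for illustration, and the argument order and keyword names simply follow the parameter listing above and may not match the actual `SharkInference` signature.

```python
from typing import Any, Optional, Tuple


def run_inference(
    module: Any,
    inputs: Tuple[Any, ...],
    device: str = "cpu",
    dynamic: bool = False,
    tracing_required: bool = False,
    inference_cls: Optional[type] = None,
) -> Any:
    """Construct a SharkInference-style object and run one forward pass."""
    if inference_cls is None:
        # Imported lazily so the helper can be loaded without Shark installed.
        from shark_runner import SharkInference

        inference_cls = SharkInference

    # ASSUMPTION: argument order and keyword names follow the README listing.
    shark_module = inference_cls(
        module,
        inputs,
        dynamic=dynamic,
        device=device,
        tracing_required=tracing_required,
    )
    return shark_module.forward(inputs)
```

With Shark installed, a call such as `run_inference(model, (x,), device="cpu")` would be expected to behave like the snippet above. The `inference_cls` hook exists only so the call flow can be exercised with a stand-in class when Shark is not available.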
Model Tracking (Shark Inference)
| Hugging Face Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---|---|---|---|---|
| BERT | ✔️ (JIT) | ✔️ | | |
| Albert | ✔️ (JIT) | ✔️ | | |
| BigBird | ✔️ (AOT) | | | |
| DistilBERT | ✔️ (AOT) | | | |
| GPT2 | ❌ (AOT) | | | |
| TORCHVISION Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---|---|---|---|---|
| AlexNet | ✔️ (Script) | | | |
| DenseNet121 | ✔️ (Script) | | | |
| MNasNet1_0 | ✔️ (Script) | | | |
| MobileNetV2 | ✔️ (Script) | | | |
| MobileNetV3 | ✔️ (Script) | | | |
| Unet | ❌ (Script) | | | |
| Resnet18 | ✔️ (Script) | ✔️ | | |
| Resnet50 | ✔️ (Script) | ✔️ | | |
| Resnext50_32x4d | ✔️ (Script) | | | |
| ShuffleNet_v2 | ❌ (Script) | | | |
| SqueezeNet | ✔️ (Script) | ✔️ | | |
| EfficientNet | ✔️ (Script) | | | |
| Regnet | ✔️ (Script) | | | |
| Resnest | ❌ (Script) | | | |
| Vision Transformer | ✔️ (Script) | | | |
| VGG 16 | ✔️ (Script) | | | |
| Wide Resnet | ✔️ (Script) | ✔️ | | |
| RAFT | ❌ (JIT) | | | |
For more information, refer to the Model Tracking Sheet.
Shark Trainer API
| Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---|---|---|---|---|
| BERT | ❌ | ❌ | | |
| FullyConnected | ✔️ | ✔️ | | |