Shark Runner
The Shark Runner provides inference and training APIs to run deep learning models on Shark Runtime.
How to configure
Check out the code
```shell
git clone https://github.com/NodLabs/dSHARK.git
cd dSHARK
```
Set up your Python virtual environment and dependencies
```shell
# Set up the venv and install the required packages (torch-mlir, nodLabs/Shark, ...).
./setup_venv.sh
# Activate the venv once the installation finishes.
```
Run a demo script
```shell
# Pick one device: "cpu", "gpu" or "vulkan".
python -m shark.examples.resnet50_script --device="cpu"
```
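The `--device` flag selects exactly one of the three supported backends. Below is a minimal sketch of how a script could parse such a flag with Python's standard `argparse` module. It assumes the flag accepts only `cpu`, `gpu` or `vulkan` and defaults to `cpu`; the default value and the helper name `parse_device_args` are hypothetical and not taken from SHARK's own argument parser.

```python
import argparse

# Hypothetical sketch: SHARK's real argument parser may differ.
SUPPORTED_DEVICES = ("cpu", "gpu", "vulkan")


def parse_device_args(argv=None):
    """Parse a --device flag restricted to the backends listed in this README."""
    parser = argparse.ArgumentParser(description="Run a model with SHARK (sketch)")
    parser.add_argument(
        "--device",
        choices=SUPPORTED_DEVICES,  # argparse rejects any other value
        default="cpu",  # assumed default, not confirmed by the source
        help="device on which inference or training is run",
    )
    # argv=None makes argparse read sys.argv[1:]
    return parser.parse_args(argv)
```

Called as `parse_device_args(["--device=vulkan"])`, it returns a namespace whose `device` attribute is `"vulkan"`; an unsupported value such as `tpu` makes argparse print a usage error and exit. Using `choices` keeps the validation in one place instead of checking the string by hand.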
Shark Inference API
```python
import torch
from shark_runner import SharkInference

model = MyModel()  # any torch.nn.Module instance (MyModel is a placeholder)
inputs = (torch.randn(1, 3, 224, 224),)  # inputs to the model; must be torch tensors

shark_module = SharkInference(
    model,
    inputs,
    dynamic=False,  # False: pass the input shapes as static; True: as dynamic
    device="cpu",  # "cpu", "gpu" or "vulkan"
    tracing_required=False,  # True: JIT-trace with the given input (useful when jit.script fails)
)
result = shark_module.forward(inputs)
```
Shark Trainer API
Work in Progress
Model Tracking (Shark Inference)
| Hugging Face Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---|---|---|---|---|
| BERT | ✔️ (JIT) | ✔️ | | |
| Albert | ✔️ (JIT) | ✔️ | | |
| BigBird | ✔️ (AOT) | | | |
| DistilBERT | ✔️ (AOT) | | | |
| GPT2 | ❌ (AOT) | | | |
| TORCHVISION Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---|---|---|---|---|
| AlexNet | ✔️ (Script) | | | |
| DenseNet121 | ✔️ (Script) | | | |
| MNasNet1_0 | ✔️ (Script) | | | |
| MobileNetV2 | ✔️ (Script) | | | |
| MobileNetV3 | ✔️ (Script) | | | |
| Unet | ❌ (Script) | | | |
| Resnet18 | ✔️ (Script) | ✔️ | | |
| Resnet50 | ✔️ (Script) | ✔️ | | |
| Resnext50_32x4d | ✔️ (Script) | | | |
| ShuffleNet_v2 | ❌ (Script) | | | |
| SqueezeNet | ❌ (Script) | | | |
| EfficientNet | ✔️ (Script) | | | |
| Regnet | ✔️ (Script) | | | |
| Resnest | ❌ (Script) | | | |
| Vision Transformer | ✔️ (Script) | | | |
| VGG 16 | ✔️ (Script) | | | |
| Wide Resnet | ✔️ (Script) | ✔️ | | |
| RAFT | ❌ (JIT) | | | |
For more information, refer to the MODEL TRACKING SHEET.