mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-10 14:27:58 -05:00
1fa05fc7c8d140a099a84a6e52f319a87323ccb9
SHARK
High Performance Machine Learning and Data Analytics for CPUs, GPUs, Accelerators and Heterogeneous Clusters
Communication Channels
- Nod.ai SHARK Discord server: Real time discussions with the nod.ai team and other users
- GitHub issues: Feature requests, bugs, etc.
Installation (Linux and macOS)
Check out the code
git clone https://github.com/nod-ai/SHARK.git
Set up your Python virtual environment and dependencies
# Set up the venv and install the necessary packages (torch-mlir, nodLabs/Shark, ...).
./setup_venv.sh
# Please activate the venv after installation.
Run a demo script
python -m shark.examples.shark_inference.resnet50_script --device="cpu" # Use gpu | vulkan
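The demo command above selects a backend through its `--device` flag. How the script handles that flag is not shown here, so the following is only a hypothetical sketch of how such a flag could be validated with the standard-library `argparse` module. The helper name `parse_device` and the `cpu` default are assumptions made for illustration; only the three device names come from the command above.

```python
import argparse

# Device names taken from the demo command; everything else here is hypothetical.
SUPPORTED_DEVICES = ("cpu", "gpu", "vulkan")


def parse_device(argv):
    """Return the validated --device value from a list of CLI arguments."""
    parser = argparse.ArgumentParser(description="Hypothetical demo launcher")
    # `choices` makes argparse reject any device name outside the supported set.
    parser.add_argument("--device", choices=SUPPORTED_DEVICES, default="cpu")
    return parser.parse_args(argv).device


print(parse_device(["--device=vulkan"]))  # prints: vulkan
```

Passing an unsupported name makes `argparse` exit with a usage message, which surfaces a mistyped device before any model is loaded.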
Run all tests on CPU/GPU/Vulkan/Metal
pytest
# On Linux, for quicker results:
pytest --workers auto
Shark Inference API
from shark_runner import SharkInference
# The argument values below are placeholders that illustrate each parameter.
shark_module = SharkInference(
    module,                  # the model class
    (input,),                # inputs to the model (must be torch tensors)
    dynamic=False,           # bool: pass the input shapes as static (False) or dynamic (True)
    device="cpu",            # "cpu", "gpu", or "vulkan"
    tracing_required=False,  # bool: JIT-trace the module with the given input; useful when jit.script doesn't work
)
shark_module.set_frontend("pytorch") # Use tensorflow, mhlo, linalg, tosa
shark_module.compile()
result = shark_module.forward(inputs)
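The `dynamic` flag above controls whether input shapes are treated as static or dynamic. As a rough illustration of what that distinction looks like at the MLIR level, the sketch below formats a ranked tensor type either with concrete sizes or with `?`, MLIR's marker for a dimension known only at run time. The helper `mlir_tensor_type` is hypothetical and is not part of SHARK; whether SHARK marks every dimension dynamic, or only some, is not stated here, so treating all of them as dynamic is an assumption.

```python
def mlir_tensor_type(shape, dtype="f32", dynamic=False):
    """Format an MLIR ranked tensor type string, e.g. tensor<1x4xf32>.

    With dynamic=True every dimension is written as '?', which MLIR uses
    for a size that is only known at run time (assumption: all dims dynamic).
    """
    dims = ["?" if dynamic else str(d) for d in shape]
    return "tensor<" + "x".join(dims + [dtype]) + ">"


print(mlir_tensor_type((1, 4)))                # tensor<1x4xf32>
print(mlir_tensor_type((1, 4), dynamic=True))  # tensor<?x?xf32>
```

A static signature lets the compiler specialize for one shape, while a dynamic one keeps the compiled module usable across input sizes.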
Example demonstrating running MHLO IR.
from shark.shark_inference import SharkInference
import numpy as np
mhlo_ir = r"""builtin.module {
  func.func @forward(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<4x4xf32> {
    %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<4x4xf32>
    %1 = "mhlo.abs"(%0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
    return %1 : tensor<4x4xf32>
  }
}"""
arg0 = np.ones((1, 4)).astype(np.float32)
arg1 = np.ones((4, 1)).astype(np.float32)
shark_module = SharkInference(mhlo_ir, (arg0, arg1))
shark_module.set_frontend("mhlo")
shark_module.compile()
print(shark_module.forward((arg0, arg1)))
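To check what the IR above should compute, it can help to have a reference that does not depend on SHARK at all. The sketch below reimplements the two operations in plain Python: a broadcast add of a 1xN input with an Mx1 input, followed by an element-wise absolute value. The function name `broadcast_add_abs` is made up for this illustration, and the lists stand in for the NumPy arrays used in the example. The exact output format of `shark_module.forward` is not shown here, so only the values are compared.

```python
def broadcast_add_abs(a, b):
    """Reference for the IR: |a + b| with a of shape 1xN and b of shape Mx1."""
    row = a[0]  # the single row of the 1xN input
    # Output element [i][j] is |a[0][j] + b[i][0]|, giving an MxN result.
    return [[abs(x + col[0]) for x in row] for col in b]


arg0 = [[1.0, 1.0, 1.0, 1.0]]        # shape (1, 4), like np.ones((1, 4))
arg1 = [[1.0], [1.0], [1.0], [1.0]]  # shape (4, 1), like np.ones((4, 1))

expected = broadcast_add_abs(arg0, arg1)
for r in expected:
    print(r)  # each of the four rows is [2.0, 2.0, 2.0, 2.0]
```

With all-ones inputs every output element is 2.0, so the compiled module should produce a 4x4 result filled with that value, assuming it follows the IR's semantics.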
Model Tracking (Shark Inference)
| Hugging Face Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---|---|---|---|---|
| BERT | ✔️ (JIT) | ✔️ | | |
| Albert | ✔️ (JIT) | ✔️ | | |
| BigBird | ✔️ (AOT) | | | |
| DistilBERT | ✔️ (AOT) | | | |
| GPT2 | ❌ (AOT) | | | |

| TORCHVISION Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---|---|---|---|---|
| AlexNet | ✔️ (Script) | ✔️ | ✔️ | |
| DenseNet121 | ✔️ (Script) | | | |
| MNasNet1_0 | ✔️ (Script) | | | |
| MobileNetV2 | ✔️ (Script) | | | |
| MobileNetV3 | ✔️ (Script) | | | |
| Unet | ❌ (Script) | | | |
| Resnet18 | ✔️ (Script) | ✔️ | ✔️ | |
| Resnet50 | ✔️ (Script) | ✔️ | ✔️ | |
| Resnet101 | ✔️ (Script) | ✔️ | ✔️ | |
| Resnext50_32x4d | ✔️ (Script) | | | |
| ShuffleNet_v2 | ❌ (Script) | | | |
| SqueezeNet | ✔️ (Script) | ✔️ | ✔️ | |
| EfficientNet | ✔️ (Script) | | | |
| Regnet | ✔️ (Script) | | | |
| Resnest | ❌ (Script) | | | |
| Vision Transformer | ✔️ (Script) | | | |
| VGG 16 | ✔️ (Script) | ✔️ | ✔️ | |
| Wide Resnet | ✔️ (Script) | ✔️ | ✔️ | |
| RAFT | ❌ (JIT) | | | |
For more information, refer to the MODEL TRACKING SHEET.
Shark Trainer API
| Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---|---|---|---|---|
| BERT | ❌ | ❌ | | |
| FullyConnected | ✔️ | ✔️ | | |
Related Project Channels
- Upstream IREE issues: Feature requests, bugs, and other work tracking
- Upstream IREE Discord server: Daily development discussions with the core team and collaborators
- iree-discuss email list: Announcements, general and low-priority discussion
- MLIR topic within LLVM Discourse: IREE is enabled by and heavily relies on MLIR, and it is sometimes referred to in MLIR discussions. Useful if you are also interested in MLIR's evolution.
License
nod.ai SHARK is licensed under the terms of the Apache 2.0 License with LLVM Exceptions. See LICENSE for more information.