Shark Runner

The Shark Runner provides inference and training APIs to run deep learning models on Shark Runtime.

How to configure

Check out the code

git clone https://github.com/NodLabs/dSHARK.git 

Set up your Python virtual environment and dependencies

# Setup venv and install necessary packages (torch-mlir, nodLabs/Shark, ...).
./setup_venv.sh
# Please activate the venv after installation.

Run a demo script

python -m shark.examples.resnet50_script --device="cpu"  # or "gpu" / "vulkan"
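To try the demo on each backend from Python rather than the shell, a small wrapper like the one below may help. It is only a sketch: `demo_command` and `run_demo` are hypothetical helper names, not part of SHARK, and the code assumes the virtual environment is already active so that `sys.executable` points at the interpreter with SHARK installed.

```python
import subprocess
import sys

# Device names accepted by the demo script, as listed in this README.
DEVICES = ("cpu", "gpu", "vulkan")


def demo_command(device="cpu"):
    """Build the argv list for the ResNet50 demo on the given device."""
    if device not in DEVICES:
        raise ValueError(f"unsupported device {device!r}; expected one of {DEVICES}")
    return [
        sys.executable,
        "-m",
        "shark.examples.resnet50_script",
        f"--device={device}",
    ]


def run_demo(device="cpu"):
    """Run the demo in a child process and return its exit code."""
    return subprocess.run(demo_command(device), check=False).returncode
```

Calling `run_demo("vulkan")` launches the same command as above with the Vulkan backend selected. A nonzero return value means the script failed on that device.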

Shark Inference API

from shark_runner import SharkInference

shark_module = SharkInference(
        module,            # the torch.nn.Module to run
        (input,),          # inputs to the model (each must be a torch tensor)
        dynamic,           # boolean: whether the input shapes are passed as dynamic rather than static
        device,            # "cpu", "gpu" or "vulkan"
        tracing_required,  # boolean: JIT-trace the module with the given input; useful when torch.jit.script doesn't work
)

result = shark_module.forward(inputs)
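The parameter list above can be turned into a small end-to-end sketch. Several things here are assumptions rather than documented behavior: that `SharkInference` takes the module instance and the input tuple positionally and the remaining options as keywords, that `FullyConnected` (a hypothetical toy model) lowers cleanly, and that `torch` and `shark_runner` are importable in the active environment. `inference_options` and `run_fully_connected` are illustrative helper names, not part of SHARK.

```python
SUPPORTED_DEVICES = ("cpu", "gpu", "vulkan")


def inference_options(device="cpu", dynamic=False, tracing_required=False):
    """Validate and collect the options described in this README."""
    if device not in SUPPORTED_DEVICES:
        raise ValueError(
            f"unsupported device {device!r}; expected one of {SUPPORTED_DEVICES}"
        )
    return {
        "dynamic": bool(dynamic),
        "device": device,
        "tracing_required": bool(tracing_required),
    }


def run_fully_connected(device="cpu"):
    """Run a toy model through SharkInference (needs torch and shark_runner)."""
    # Imported lazily so the helper above works without these packages.
    import torch
    from shark_runner import SharkInference

    class FullyConnected(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(8, 4)

        def forward(self, x):
            return self.linear(x)

    inputs = (torch.randn(1, 8),)
    # Assumption: positional (module, inputs), then keyword options.
    shark_module = SharkInference(
        FullyConnected(), inputs, **inference_options(device=device)
    )
    return shark_module.forward(inputs)
```

Validating the device name before building the module keeps a typo such as `"vulcan"` from surfacing later as a less obvious compile error.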

Model Tracking (Shark Inference)

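| Hugging Face Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---------------------|----------------------|-----------|------------|-------------|
| BERT                | ✔️ (JIT)             | ✔️        |            |             |
| Albert              | ✔️ (JIT)             | ✔️        |            |             |
| BigBird             | ✔️ (AOT)             |           |            |             |
| DistilBERT          | ✔️ (AOT)             |           |            |             |
| GPT2                | (AOT)                |           |            |             |
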
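| TORCHVISION Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|--------------------|----------------------|-----------|------------|-------------|
| AlexNet            | ✔️ (Script)          |           |            |             |
| DenseNet121        | ✔️ (Script)          |           |            |             |
| MNasNet1_0         | ✔️ (Script)          |           |            |             |
| MobileNetV2        | ✔️ (Script)          |           |            |             |
| MobileNetV3        | ✔️ (Script)          |           |            |             |
| Unet               | (Script)             |           |            |             |
| Resnet18           | ✔️ (Script)          | ✔️        |            |             |
| Resnet50           | ✔️ (Script)          | ✔️        |            |             |
| Resnext50_32x4d    | ✔️ (Script)          |           |            |             |
| ShuffleNet_v2      | (Script)             |           |            |             |
| SqueezeNet         | ✔️ (Script)          | ✔️        |            |             |
| EfficientNet       | ✔️ (Script)          |           |            |             |
| Regnet             | ✔️ (Script)          |           |            |             |
| Resnest            | (Script)             |           |            |             |
| Vision Transformer | ✔️ (Script)          |           |            |             |
| VGG 16             | ✔️ (Script)          |           |            |             |
| Wide Resnet        | ✔️ (Script)          | ✔️        |            |             |
| RAFT               | (JIT)                |           |            |             |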

For more information, refer to the MODEL TRACKING SHEET.

Shark Trainer API

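| Models         | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|----------------|----------------------|-----------|------------|-------------|
| BERT           |                      |           |            |             |
| FullyConnected | ✔️                   | ✔️        |            |             |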