2022-04-13 13:19:18 -07:00

Shark Runner

The Shark Runner provides inference and training APIs to run deep learning models on Shark Runtime.

How to configure

Check out the code

git clone https://github.com/NodLabs/dSHARK.git 

Set up your Python virtual environment and dependencies

python -m venv shark_venv
source shark_venv/bin/activate
# Some older pip installs may not be able to handle the recent PyTorch deps
python -m pip install --upgrade pip
# Install latest PyTorch nightlies and build requirements.
python -m pip install -r requirements.txt
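
Before installing further packages, it can help to confirm that the virtual environment is actually active. The following is a minimal sketch using only the Python standard library; the function name in_virtualenv is a hypothetical helper for illustration and is not part of Shark.

```python
import sys


def in_virtualenv() -> bool:
    """Return True when the interpreter runs inside a venv-style environment."""
    # Inside a venv, sys.prefix points at the environment directory while
    # sys.base_prefix still points at the base Python installation.
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    print("virtual environment active:", in_virtualenv())
    print("interpreter:", sys.executable)
```

If this reports that no environment is active, re-run the source shark_venv/bin/activate step before continuing.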

Install dependent packages

# Install latest torch-mlir release.
python -m pip install --find-links https://github.com/llvm/torch-mlir/releases torch-mlir

# Install latest SHARK release.
python -m pip install --find-links https://github.com/NodLabs/SHARK/releases iree-compiler iree-runtime

or

# Install latest IREE release
python -m pip install --find-links https://github.com/google/iree/releases iree-compiler iree-runtime

# Install functorch 
python -m pip install ninja
python -m pip install "git+https://github.com/pytorch/functorch.git"

# Install shark_runner from the current path.
cd shark
python -m pip install .
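
Once the installs above have finished, a quick check of which distributions pip actually registered can save debugging time. This is a sketch under stated assumptions: the distribution names are the ones used in the pip commands above, plus torch, which is assumed to be pulled in by requirements.txt. The helper installed_versions is hypothetical and not part of Shark.

```python
from importlib import metadata

# Distribution names from the pip commands in this README.
# "torch" is an assumption: it is expected to come from requirements.txt.
REQUIRED = ["torch", "torch-mlir", "iree-compiler", "iree-runtime", "functorch"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(REQUIRED).items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```

Any entry reported as NOT INSTALLED points at the install command that should be re-run.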

Run a demo script

python -m shark.examples.resnet50_script

Shark Inference API

from shark_runner import SharkInference

shark_module = SharkInference(
        module,            # the torch.nn.Module to run
        (input,),          # inputs to the model (must be torch tensors)
        dynamic,           # boolean: pass the input shapes as static or dynamic
        device,            # `cpu`, `gpu` or `vulkan` is supported
        tracing_required,  # boolean: JIT-trace the module with the given input; useful when jit.script doesn't work
)

result = shark_module.forward(inputs)
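
The argument constraints described above can be checked before a module is built. The sketch below is only an illustration: check_inference_args is a hypothetical helper that is not part of shark_runner, and it encodes nothing beyond what this README states (inputs passed as a tuple, two boolean flags, and three supported device names).

```python
SUPPORTED_DEVICES = ("cpu", "gpu", "vulkan")  # device names listed in this README


def check_inference_args(inputs, dynamic, device, tracing_required):
    """Hypothetical pre-flight check; this helper is not part of shark_runner."""
    if not isinstance(inputs, tuple):
        raise TypeError("inputs must be a tuple such as (input,)")
    if device not in SUPPORTED_DEVICES:
        raise ValueError(
            f"device must be one of {SUPPORTED_DEVICES}, got {device!r}"
        )
    for name, value in (("dynamic", dynamic),
                        ("tracing_required", tracing_required)):
        if not isinstance(value, bool):
            raise TypeError(f"{name} must be a boolean, got {type(value).__name__}")
    return True


if __name__ == "__main__":
    # A plain float stands in for a torch tensor so the sketch runs without torch.
    print(check_inference_args((1.0,), False, "cpu", False))
```

Failing early on a misspelled device name or a bare tensor passed instead of a tuple is usually easier to diagnose than an error raised deep inside compilation.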

Shark Trainer API

Work in Progress

Model Tracking (Shark Inference)

Hugging Face Models Torch-MLIR lowerable SHARK-CPU SHARK-CUDA SHARK-VULKAN
BERT ✔️ (JIT)
Albert ✔️ (JIT)
BigBird ✔️ (AOT)
DistilBERT ✔️ (AOT)
GPT2 (AOT)

TORCHVISION Models Torch-MLIR lowerable SHARK-CPU SHARK-CUDA SHARK-VULKAN
AlexNet ✔️ (Script)
DenseNet121 ✔️ (Script)
MNasNet1_0 ✔️ (Script)
MobileNetV2 ✔️ (Script)
MobileNetV3 ✔️ (Script)
Unet (Script)
Resnet18 ✔️ (Script) ✔️
Resnet50 ✔️ (Script) ✔️
Resnext50_32x4d ✔️ (Script)
ShuffleNet_v2 (Script)
SqueezeNet (Script)
EfficientNet ✔️ (Script)
Regnet ✔️ (Script)
Resnest (Script)
Vision Transformer ✔️ (Script)
VGG 16 ✔️ (Script)
Wide Resnet ✔️ (Script) ✔️
RAFT (JIT)

For more information, refer to the MODEL TRACKING SHEET.