mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-06 20:43:50 -05:00
Supported and Validated Models
PyTorch HuggingFace Models
| PyTorch Language Models | Torch-MLIR lowerable | AMDSHARK-CPU | AMDSHARK-CUDA | AMDSHARK-METAL |
|---|---|---|---|---|
| BERT | 💚 (JIT) | 💚 | 💚 | 💚 |
| Albert | 💚 (JIT) | 💚 | 💚 | 💚 |
| BigBird | 💚 (AOT) | |||
| dbmdz/ConvBERT | 💚 | 💚 | 💚 | 💚 |
| DistilBERT | 💔 (JIT) | |||
| GPT2 | 💚 | 💚 | 💚 | 💚 |
| MobileBert | 💚 (JIT) | 💚 | 💚 | 💚 |
| microsoft/beit | 💚 | 💚 | 💔 | 💔 |
| facebook/deit | 💚 | 💚 | 💔 | 💔 |
| facebook/convnext | 💚 | 💚 | 💚 | 💚 |
Torchvision Models
| Torchvision Models | Torch-MLIR lowerable | AMDSHARK-CPU | AMDSHARK-CUDA | AMDSHARK-METAL |
|---|---|---|---|---|
| AlexNet | 💚 (Script) | 💚 | 💚 | 💚 |
| MobileNetV2 | 💚 (Script) | 💚 | 💚 | 💚 |
| MobileNetV3 | 💚 (Script) | 💚 | 💚 | 💚 |
| Unet | 💚 (Script) | 💚 | 💚 | 💚 |
| Resnet18 | 💚 (Script) | 💚 | 💚 | 💚 |
| Resnet50 | 💚 (Script) | 💚 | 💚 | 💚 |
| Resnet101 | 💚 (Script) | 💚 | 💚 | 💚 |
| Resnext50_32x4d | 💚 (Script) | |||
| SqueezeNet | 💚 (Script) | 💚 | 💔 | 💔 |
| EfficientNet | 💚 (Script) | |||
| Regnet | 💚 (Script) | |||
| Resnest | 💔 (Script) | |||
| Vision Transformer | 💚 (Script) | 💚 | 💚 | 💚 |
| VGG 16 | 💚 (Script) | 💚 | 💚 | |
| Wide Resnet | 💚 (Script) | 💚 | 💚 | 💚 |
| RAFT | 💔 (JIT) | | | |
For more information, refer to the Model Tracking Sheet.
TensorFlow Models (Inference)
| Hugging Face Models | tf-mhlo lowerable | AMDSHARK-CPU | AMDSHARK-CUDA | AMDSHARK-METAL |
|---|---|---|---|---|
| BERT | 💚 | 💚 | 💚 | 💚 |
| MiniLM | 💚 | 💚 | 💚 | 💚 |
| albert-base-v2 | 💚 | 💚 | 💚 | 💚 |
| DistilBERT | 💚 | 💚 | 💚 | 💚 |
| CamemBert | 💚 | 💚 | 💚 | 💚 |
| ConvBert | 💚 | 💚 | 💚 | 💚 |
| Deberta | ||||
| electra | 💚 | 💚 | 💚 | 💚 |
| funnel | ||||
| layoutlm | 💚 | 💚 | 💚 | 💚 |
| longformer | ||||
| mobile-bert | 💚 | 💚 | 💚 | 💚 |
| rembert | ||||
| tapas | ||||
| flaubert | 💔 | 💚 | 💚 | 💚 |
| roberta | 💚 | 💚 | 💚 | 💚 |
| xlm-roberta | 💚 | 💚 | 💚 | 💚 |
| mpnet | 💚 | 💚 | 💚 | 💚 |
PyTorch Training Models
| Models | Torch-MLIR lowerable | AMDSHARK-CPU | AMDSHARK-CUDA | AMDSHARK-METAL |
|---|---|---|---|---|
| BERT | 💚 | 💚 | ||
| FullyConnected | 💚 | 💚 | ||
JAX Models
| Models | JAX-MHLO lowerable | AMDSHARK-CPU | AMDSHARK-CUDA | AMDSHARK-METAL |
|---|---|---|---|---|
| DALL-E | 💔 | 💔 | ||
| FullyConnected | 💚 | 💚 | ||
TFLite Models
| Models | TOSA/LinAlg | AMDSHARK-CPU | AMDSHARK-CUDA | AMDSHARK-METAL |
|---|---|---|---|---|
| BERT | 💔 | 💔 | ||
| FullyConnected | 💚 | 💚 | ||
| albert | 💚 | 💚 | ||
| asr_conformer | 💚 | 💚 | ||
| bird_classifier | 💚 | 💚 | ||
| cartoon_gan | 💚 | 💚 | ||
| craft_text | 💚 | 💚 | ||
| deeplab_v3 | 💚 | 💚 | ||
| densenet | 💚 | 💚 | ||
| east_text_detector | 💚 | 💚 | ||
| efficientnet_lite0_int8 | 💚 | 💚 | ||
| efficientnet | 💚 | 💚 | ||
| gpt2 | 💚 | 💚 | ||
| image_stylization | 💚 | 💚 | ||
| inception_v4 | 💚 | 💚 | ||
| inception_v4_uint8 | 💚 | 💚 | ||
| lightning_fp16 | 💚 | 💚 | ||
| lightning_i8 | 💚 | 💚 | ||
| lightning | 💚 | 💚 | ||
| magenta | 💚 | 💚 | ||
| midas | 💚 | 💚 | ||
| mirnet | 💚 | 💚 | ||
| mnasnet | 💚 | 💚 | ||
| mobilebert_edgetpu_s_float | 💚 | 💚 | ||
| mobilebert_edgetpu_s_quant | 💚 | 💚 | ||
| mobilebert | 💚 | 💚 | ||
| mobilebert_tf2_float | 💚 | 💚 | ||
| mobilebert_tf2_quant | 💚 | 💚 | ||
| mobilenet_ssd_quant | 💚 | 💚 | ||
| mobilenet_v1 | 💚 | 💚 | ||
| mobilenet_v1_uint8 | 💚 | 💚 | ||
| mobilenet_v2_int8 | 💚 | 💚 | ||
| mobilenet_v2 | 💚 | 💚 | ||
| mobilenet_v2_uint8 | 💚 | 💚 | ||
| mobilenet_v3-large | 💚 | 💚 | ||
| mobilenet_v3-large_uint8 | 💚 | 💚 | ||
| mobilenet_v35-int8 | 💚 | 💚 | ||
| nasnet | 💚 | 💚 | ||
| person_detect | 💚 | 💚 | ||
| posenet | 💚 | 💚 | ||
| resnet_50_int8 | 💚 | 💚 | ||
| rosetta | 💚 | 💚 | ||
| spice | 💚 | 💚 | ||
| squeezenet | 💚 | 💚 | ||
| ssd_mobilenet_v1 | 💚 | 💚 | ||
| ssd_mobilenet_v1_uint8 | 💚 | 💚 | ||
| ssd_mobilenet_v2_fpnlite | 💚 | 💚 | ||
| ssd_mobilenet_v2_fpnlite_uint8 | 💚 | 💚 | ||
| ssd_mobilenet_v2_int8 | 💚 | 💚 | ||
| ssd_mobilenet_v2 | 💚 | 💚 | ||
| ssd_spaghettinet_large | 💚 | 💚 | ||
| ssd_spaghettinet_large_uint8 | 💚 | 💚 | ||
| visual_wake_words_i8 | 💚 | 💚 | ||
Testing and Benchmarks
Run all model tests on CPU/GPU/Vulkan/Metal
For a list of models included in our pytest model suite, see https://github.com/nod-ai/AMD-SHARK-Studio/blob/main/tank/all_models.csv
pytest tank/test_models.py
# Models included in the pytest suite are listed in all_models.csv.
# On Linux, run the tests in parallel across CPU cores for faster results:
pytest tank/test_models.py -n auto
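The `-n` option is provided by the pytest-xdist plugin, so it only works when that plugin is installed in the active environment. A minimal sketch of choosing between the two commands above follows. `model_test_cmd` and the `PYTHON` variable are hypothetical names introduced for illustration and are not part of the repository.

```shell
# Hypothetical helper: print the pytest command to use.
# "-n auto" is added only when the xdist module (pytest-xdist) imports cleanly.
model_test_cmd() {
  if "${PYTHON:-python3}" -c "import xdist" >/dev/null 2>&1; then
    echo "pytest tank/test_models.py -n auto"
  else
    echo "pytest tank/test_models.py"
  fi
}
```

Running `$(model_test_cmd)` from the repository root then launches the suite with or without parallel workers, depending on what is installed.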
Running specific tests
# Search for test cases by keyword; the keyword may match all or part of a test case's name:
pytest tank/test_models.py -k "keyword"
# Test case names follow the format test_module_<model_name_underscores_only>_<torch/tf>_<static/dynamic>_<device>.
# Example: Test all models on NVIDIA GPUs:
pytest tank/test_models.py -k "cuda"
# Example: Test all TensorFlow ResNet models on the Vulkan backend:
pytest tank/test_models.py -k "resnet and tf and vulkan"
# Exclude a test case:
pytest tank/test_models.py -k "not ..."
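The naming format above can be turned into small helpers for building `-k` expressions. The sketch below is illustrative only: `tank_test_name` and `tank_keyword_expr` are hypothetical helpers, and the rule used to reduce a model name to underscores (every character other than a letter or digit becomes `_`) is an assumption about how the suite derives its names.

```shell
# Hypothetical helper: compose a test-case name from its parts, following
# test_module_<model_name_underscores_only>_<torch/tf>_<static/dynamic>_<device>.
# Arguments: model name, framework, static|dynamic, device.
tank_test_name() (
  # Assumption: anything that is not a letter or digit is replaced with "_".
  model=$(printf '%s' "$1" | tr -c 'A-Za-z0-9' '_')
  printf 'test_module_%s_%s_%s_%s\n' "$model" "$2" "$3" "$4"
)

# Hypothetical helper: join keywords with "and" to form a -k expression.
tank_keyword_expr() (
  joined=""
  for term in "$@"; do
    if [ -z "$joined" ]; then
      joined="$term"
    else
      joined="$joined and $term"
    fi
  done
  printf '%s\n' "$joined"
)
```

For example, `pytest tank/test_models.py -k "$(tank_keyword_expr resnet tf vulkan)"` reproduces the Vulkan example above, and `tank_test_name bert-base-uncased tf static cpu` prints `test_module_bert_base_uncased_tf_static_cpu`.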
### Run benchmarks on AMDSHARK tank pytests and generate bench_results.csv with results.
(the following requires source installation with `IMPORTER=1 ./setup_venv.sh`)
```shell
pytest --benchmark tank/test_models.py
# Just do static GPU benchmarks for PyTorch tests:
pytest --benchmark tank/test_models.py -k "pytorch and static and cuda"
```
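The layout of `bench_results.csv` is not described here, so a reasonable first step after a benchmark run is to inspect its header. This is a minimal sketch, assuming the first row is a header that contains no quoted commas. `csv_header_fields` is a hypothetical helper, not part of AMDSHARK.

```shell
# Hypothetical helper: print each header field of a CSV file with its position.
# Assumption: the first line is a header and no field contains a quoted comma.
csv_header_fields() {
  head -n 1 "$1" | tr ',' '\n' | awk '{printf "%d\t%s\n", NR, $0}'
}
```

After a benchmark run, `csv_header_fields bench_results.csv` lists the available columns without assuming any of their names.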
Benchmark Resnet50, MiniLM on CPU
(requires source installation with IMPORTER=1 ./setup_venv.sh)
# Before benchmarking on CPU, we suggest running the following commands as root to take hyperthread sibling CPUs offline and disable Intel turbo boost:
cat /sys/devices/system/cpu/cpu*/topology/thread_siblings_list | awk -F, '{print $2}' | sort -n | uniq | ( while read X ; do echo $X ; echo 0 > /sys/devices/system/cpu/cpu$X/online ; done )
echo 1 > /sys/devices/system/cpu/intel_pstate/no_turbo
# Benchmark canonical Resnet50 on CPU via pytest
pytest --benchmark tank/test_models.py -k "resnet50 and tf_static_cpu"
# Benchmark canonical MiniLM on CPU via pytest
pytest --benchmark tank/test_models.py -k "MiniLM and cpu"
# Benchmark MiniLM on CPU via transformer-benchmarks:
git clone --recursive https://github.com/nod-ai/transformer-benchmarks.git
cd transformer-benchmarks
./perf-ci.sh -n
# Check detail.csv for MLIR/IREE results.
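The root commands in the CPU benchmark steps above change system state that stays in effect until it is reverted or the machine reboots. The sketch below, under stated assumptions, previews which CPUs would be taken offline and restores the settings afterwards. `sibling_cpus` and `restore_cpus` are hypothetical helpers. The `NF > 1` filter is an addition that skips sibling lists without a comma, and `restore_cpus` brings every CPU that has an `online` file back up, including any that were offline for other reasons.

```shell
# Hypothetical helper: read thread_siblings_list contents on stdin (one list
# per line) and print the CPU ids the offline command would act on.
# Nothing is modified; lines without a comma are skipped.
sibling_cpus() {
  awk -F, 'NF > 1 {print $2}' | sort -n | uniq
}

# Hypothetical helper: bring CPUs back online and re-enable turbo boost.
# The optional argument overrides the sysfs CPU directory (useful for testing).
restore_cpus() {
  cpu_root=${1:-/sys/devices/system/cpu}
  for f in "$cpu_root"/cpu[0-9]*/online; do
    if [ -e "$f" ]; then
      echo 1 > "$f"
    fi
  done
  if [ -e "$cpu_root/intel_pstate/no_turbo" ]; then
    echo 0 > "$cpu_root/intel_pstate/no_turbo"
  fi
}
```

Preview with `cat /sys/devices/system/cpu/cpu*/topology/thread_siblings_list | sibling_cpus`, and run `restore_cpus` as root once benchmarking is done.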
To run the fine-tuning example, run the following from the root AMDSHARK directory:
IMPORTER=1 ./setup_venv.sh
source amdshark.venv/bin/activate
pip install jupyter tf-models-nightly tf-datasets
jupyter-notebook
If running from a Google Cloud VM, you can view Jupyter notebooks on your local system with:
gcloud compute ssh <YOUR_INSTANCE_DETAILS> --ssh-flag="-N -L localhost:8888:localhost:8888"