Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-09 22:07:55 -05:00)
Cleanup tank directory and move instructions to tank/README.md (#401)
README.md (222 changed lines)
@@ -42,6 +42,12 @@ pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f http
 ```
 If you are on an Intel macOS machine you need this [workaround](https://github.com/nod-ai/SHARK/issues/102) for an upstream issue.
 
+### Run shark tank model tests.
+```shell
+pytest tank/test_models.py
+```
+See tank/README.md for a more detailed walkthrough of our pytest suite and CLI.
+
 ### Download and run Resnet50 sample
 
 ```shell
@@ -114,69 +120,7 @@ pytest tank/test_models.py -k "MiniLM"
 <details>
 <summary>Testing and Benchmarks</summary>
 
-### Run all model tests on CPU/GPU/VULKAN/Metal
+See tank/README.md for instructions on how to run model tests and benchmarks from the SHARK tank.
-```shell
-pytest tank/test_models.py
-
-# If on Linux for multithreading on CPU (faster results):
-pytest tank/test_models.py -n auto
-```
-
-### Running specific tests
-```shell
-
-# Search for test cases by including a keyword that matches all or part of the test case's name;
-pytest tank/test_models.py -k "keyword"
-
-# Test cases are named uniformly by format test_module_<model_name_underscores_only>_<torch/tf>_<static/dynamic>_<device>.
-
-# Example: Test all models on nvidia gpu:
-pytest tank/test_models.py -k "cuda"
-
-# Example: Test all tensorflow resnet models on Vulkan backend:
-pytest tank/test_models.py -k "resnet and tf and vulkan"
-
-# Exclude a test case:
-pytest tank/test_models.py -k "not ..."
-
-### Run benchmarks on SHARK tank pytests and generate bench_results.csv with results.
-
-(the following requires source installation with `IMPORTER=1 ./setup_venv.sh`)
-
-```shell
-pytest --benchmark tank/test_models.py
-
-# Just do static GPU benchmarks for PyTorch tests:
-pytest --benchmark tank/test_models.py -k "pytorch and static and cuda"
-
-```
-
-### Benchmark Resnet50, MiniLM on CPU
-
-(requires source installation with `IMPORTER=1 ./setup_venv.sh`)
-
-```shell
-# We suggest running the following commands as root before running benchmarks on CPU:
-
-cat /sys/devices/system/cpu/cpu*/topology/thread_siblings_list | awk -F, '{print $2}' | sort -n | uniq | ( while read X ; do echo $X ; echo 0 > /sys/devices/system/cpu/cpu$X/online ; done )
-echo 1 > /sys/devices/system/cpu/intel_pstate/no_turbo
-
-# Benchmark canonical Resnet50 on CPU via pytest
-pytest --benchmark tank/test_models -k "resnet50 and tf_static_cpu"
-
-# Benchmark canonical MiniLM on CPU via pytest
-pytest --benchmark tank/test_models -k "MiniLM and cpu"
-
-# Benchmark MiniLM on CPU via transformer-benchmarks:
-git clone --recursive https://github.com/nod-ai/transformer-benchmarks.git
-cd transformer-benchmarks
-./perf-ci.sh -n
-# Check detail.csv for MLIR/IREE results.
-
-```
-
-</details>
-
-
 <details>
 <summary>API Reference</summary>
@@ -231,157 +175,7 @@ result = shark_module.forward((arg0, arg1))
 
 ## Supported and Validated Models
 
-<details>
+For a comprehensive list of the models supported in SHARK, please see tank/README.md.
-<summary>PyTorch Models</summary>
-
-### Huggingface PyTorch Models
-
-| Hugging Face Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
-|---------------------|----------------------|----------|----------|-------------|
-| BERT | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
-| Albert | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
-| BigBird | :green_heart: (AOT) | | | |
-| DistilBERT | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
-| GPT2 | :broken_heart: (AOT) | | | |
-| MobileBert | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
-
-### Torchvision Models
-
-| TORCHVISION Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
-|--------------------|----------------------|----------|----------|-------------|
-| AlexNet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
-| DenseNet121 | :green_heart: (Script) | | | |
-| MNasNet1_0 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
-| MobileNetV2 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
-| MobileNetV3 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
-| Unet | :broken_heart: (Script) | | | |
-| Resnet18 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
-| Resnet50 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
-| Resnet101 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
-| Resnext50_32x4d | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
-| ShuffleNet_v2 | :broken_heart: (Script) | | | |
-| SqueezeNet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
-| EfficientNet | :green_heart: (Script) | | | |
-| Regnet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
-| Resnest | :broken_heart: (Script) | | | |
-| Vision Transformer | :green_heart: (Script) | | | |
-| VGG 16 | :green_heart: (Script) | :green_heart: | :green_heart: | |
-| Wide Resnet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
-| RAFT | :broken_heart: (JIT) | | | |
-
-For more information refer to [MODEL TRACKING SHEET](https://docs.google.com/spreadsheets/d/15PcjKeHZIrB5LfDyuw7DGEEE8XnQEX2aX8lm8qbxV8A/edit#gid=0)
-
-### PyTorch Training Models
-
-| Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
-|---------------------|----------------------|----------|----------|-------------|
-| BERT | :broken_heart: | :broken_heart: | | |
-| FullyConnected | :green_heart: | :green_heart: | | |
-
-</details>
-
-<details>
-<summary>JAX Models</summary>
-
-
-### JAX Models
-
-| Models | JAX-MHLO lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
-|---------------------|----------------------|----------|----------|-------------|
-| DALL-E | :broken_heart: | :broken_heart: | | |
-| FullyConnected | :green_heart: | :green_heart: | | |
-
-</details>
-
-<details>
-<summary>TFLite Models</summary>
-
-### TFLite Models
-
-| Models | TOSA/LinAlg | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
-|---------------------|----------------------|----------|----------|-------------|
-| BERT | :broken_heart: | :broken_heart: | | |
-| FullyConnected | :green_heart: | :green_heart: | | |
-| albert | :green_heart: | :green_heart: | | |
-| asr_conformer | :green_heart: | :green_heart: | | |
-| bird_classifier | :green_heart: | :green_heart: | | |
-| cartoon_gan | :green_heart: | :green_heart: | | |
-| craft_text | :green_heart: | :green_heart: | | |
-| deeplab_v3 | :green_heart: | :green_heart: | | |
-| densenet | :green_heart: | :green_heart: | | |
-| east_text_detector | :green_heart: | :green_heart: | | |
-| efficientnet_lite0_int8 | :green_heart: | :green_heart: | | |
-| efficientnet | :green_heart: | :green_heart: | | |
-| gpt2 | :green_heart: | :green_heart: | | |
-| image_stylization | :green_heart: | :green_heart: | | |
-| inception_v4 | :green_heart: | :green_heart: | | |
-| inception_v4_uint8 | :green_heart: | :green_heart: | | |
-| lightning_fp16 | :green_heart: | :green_heart: | | |
-| lightning_i8 | :green_heart: | :green_heart: | | |
-| lightning | :green_heart: | :green_heart: | | |
-| magenta | :green_heart: | :green_heart: | | |
-| midas | :green_heart: | :green_heart: | | |
-| mirnet | :green_heart: | :green_heart: | | |
-| mnasnet | :green_heart: | :green_heart: | | |
-| mobilebert_edgetpu_s_float | :green_heart: | :green_heart: | | |
-| mobilebert_edgetpu_s_quant | :green_heart: | :green_heart: | | |
-| mobilebert | :green_heart: | :green_heart: | | |
-| mobilebert_tf2_float | :green_heart: | :green_heart: | | |
-| mobilebert_tf2_quant | :green_heart: | :green_heart: | | |
-| mobilenet_ssd_quant | :green_heart: | :green_heart: | | |
-| mobilenet_v1 | :green_heart: | :green_heart: | | |
-| mobilenet_v1_uint8 | :green_heart: | :green_heart: | | |
-| mobilenet_v2_int8 | :green_heart: | :green_heart: | | |
-| mobilenet_v2 | :green_heart: | :green_heart: | | |
-| mobilenet_v2_uint8 | :green_heart: | :green_heart: | | |
-| mobilenet_v3-large | :green_heart: | :green_heart: | | |
-| mobilenet_v3-large_uint8 | :green_heart: | :green_heart: | | |
-| mobilenet_v35-int8 | :green_heart: | :green_heart: | | |
-| nasnet | :green_heart: | :green_heart: | | |
-| person_detect | :green_heart: | :green_heart: | | |
-| posenet | :green_heart: | :green_heart: | | |
-| resnet_50_int8 | :green_heart: | :green_heart: | | |
-| rosetta | :green_heart: | :green_heart: | | |
-| spice | :green_heart: | :green_heart: | | |
-| squeezenet | :green_heart: | :green_heart: | | |
-| ssd_mobilenet_v1 | :green_heart: | :green_heart: | | |
-| ssd_mobilenet_v1_uint8 | :green_heart: | :green_heart: | | |
-| ssd_mobilenet_v2_fpnlite | :green_heart: | :green_heart: | | |
-| ssd_mobilenet_v2_fpnlite_uint8 | :green_heart: | :green_heart: | | |
-| ssd_mobilenet_v2_int8 | :green_heart: | :green_heart: | | |
-| ssd_mobilenet_v2 | :green_heart: | :green_heart: | | |
-| ssd_spaghettinet_large | :green_heart: | :green_heart: | | |
-| ssd_spaghettinet_large_uint8 | :green_heart: | :green_heart: | | |
-| visual_wake_words_i8 | :green_heart: | :green_heart: | | |
-
-</details>
-
-<details>
-<summary>TF Models</summary>
-
-### Tensorflow Models (Inference)
-
-| Hugging Face Models | tf-mhlo lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
-|---------------------|----------------------|----------|----------|-------------|
-| BERT | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
-| albert-base-v2 | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
-| DistilBERT | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
-| CamemBert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
-| ConvBert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
-| Deberta | | | | |
-| electra | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
-| funnel | | | | |
-| layoutlm | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
-| longformer | | | | |
-| mobile-bert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
-| remembert | | | | |
-| tapas | | | | |
-| flaubert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
-| roberta | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
-| xlm-roberta | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
-| mpnet | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
-
-</details>
-
 ## Related Projects
 
@@ -205,14 +205,14 @@ if __name__ == "__main__":
     parser.add_argument(
         "--torch_model_csv",
         type=lambda x: is_valid_file(x),
-        default="./tank/pytorch/torch_model_list.csv",
+        default="./tank/torch_model_list.csv",
         help="""Contains the file with torch_model name and args.
-        Please see: https://github.com/nod-ai/SHARK/blob/main/tank/pytorch/torch_model_list.csv""",
+        Please see: https://github.com/nod-ai/SHARK/blob/main/tank/torch_model_list.csv""",
     )
     parser.add_argument(
         "--tf_model_csv",
         type=lambda x: is_valid_file(x),
-        default="./tank/tf/tf_model_list.csv",
+        default="./tank/tf_model_list.csv",
         help="Contains the file with tf model name and args.",
     )
     parser.add_argument(
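The relocated defaults resolve relative to the repository root. As a quick, hypothetical sanity check (not part of this commit), you can confirm the flattened CSV locations before running anything that consumes these flags:

```shell
# Verify the model lists now live directly under tank/ (run from the SHARK repo root).
ls ./tank/torch_model_list.csv ./tank/tf_model_list.csv
```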
tank/README.md (222 changed lines)
@@ -1,3 +1,72 @@
+## SHARK Tank
+<details>
+<summary>Testing and Benchmarks</summary>
+
+### Run all model tests on CPU/GPU/VULKAN/Metal
+```shell
+pytest tank/test_models.py
+
+# Models included in the pytest suite can be found listed in all_models.csv.
+
+# If on Linux for multithreading on CPU (faster results):
+pytest tank/test_models.py -n auto
+```
+
+### Running specific tests
+```shell
+
+# Search for test cases by including a keyword that matches all or part of the test case's name;
+pytest tank/test_models.py -k "keyword"
+
+# Test cases are named uniformly by format test_module_<model_name_underscores_only>_<torch/tf>_<static/dynamic>_<device>.
+
+# Example: Test all models on nvidia gpu:
+pytest tank/test_models.py -k "cuda"
+
+# Example: Test all tensorflow resnet models on Vulkan backend:
+pytest tank/test_models.py -k "resnet and tf and vulkan"
+
+# Exclude a test case:
+pytest tank/test_models.py -k "not ..."
+
+### Run benchmarks on SHARK tank pytests and generate bench_results.csv with results.
+
+(the following requires source installation with `IMPORTER=1 ./setup_venv.sh`)
+
+```shell
+pytest --benchmark tank/test_models.py
+
+# Just do static GPU benchmarks for PyTorch tests:
+pytest --benchmark tank/test_models.py -k "pytorch and static and cuda"
+
+```
+
+### Benchmark Resnet50, MiniLM on CPU
+
+(requires source installation with `IMPORTER=1 ./setup_venv.sh`)
+
+```shell
+# We suggest running the following commands as root before running benchmarks on CPU:
+
+cat /sys/devices/system/cpu/cpu*/topology/thread_siblings_list | awk -F, '{print $2}' | sort -n | uniq | ( while read X ; do echo $X ; echo 0 > /sys/devices/system/cpu/cpu$X/online ; done )
+echo 1 > /sys/devices/system/cpu/intel_pstate/no_turbo
+
+# Benchmark canonical Resnet50 on CPU via pytest
+pytest --benchmark tank/test_models -k "resnet50 and tf_static_cpu"
+
+# Benchmark canonical MiniLM on CPU via pytest
+pytest --benchmark tank/test_models -k "MiniLM and cpu"
+
+# Benchmark MiniLM on CPU via transformer-benchmarks:
+git clone --recursive https://github.com/nod-ai/transformer-benchmarks.git
+cd transformer-benchmarks
+./perf-ci.sh -n
+# Check detail.csv for MLIR/IREE results.
+
+```
+
+</details>
+
 To run the fine tuning example, from the root SHARK directory, run:
 
 ```shell
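A note on the benchmark preparation in the hunk above: the thread-sibling offlining and turbo toggle are not undone automatically. A minimal sketch of restoring the machine afterwards, assuming the same intel_pstate/sysfs layout and run as root (a reboot also restores the defaults):

```shell
# Re-enable turbo boost for the intel_pstate driver.
echo 0 > /sys/devices/system/cpu/intel_pstate/no_turbo
# Bring the offlined sibling threads back online (cpu0 has no online file and stays up).
for X in /sys/devices/system/cpu/cpu[1-9]*/online ; do echo 1 > $X ; done
```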
@@ -11,3 +80,156 @@ if running from a google vm, you can view jupyter notebooks on your local system
 gcloud compute ssh <YOUR_INSTANCE_DETAILS> --ssh-flag="-N -L localhost:8888:localhost:8888"
 ```
 
+## Supported and Validated Models
+
+<details>
+<summary>PyTorch Models</summary>
+
+### Huggingface PyTorch Models
+
+| Hugging Face Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
+|---------------------|----------------------|----------|----------|-------------|
+| BERT | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
+| Albert | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
+| BigBird | :green_heart: (AOT) | | | |
+| DistilBERT | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
+| GPT2 | :broken_heart: (AOT) | | | |
+| MobileBert | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
+
+### Torchvision Models
+
+| TORCHVISION Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
+|--------------------|----------------------|----------|----------|-------------|
+| AlexNet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
+| DenseNet121 | :green_heart: (Script) | | | |
+| MNasNet1_0 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
+| MobileNetV2 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
+| MobileNetV3 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
+| Unet | :broken_heart: (Script) | | | |
+| Resnet18 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
+| Resnet50 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
+| Resnet101 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
+| Resnext50_32x4d | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
+| ShuffleNet_v2 | :broken_heart: (Script) | | | |
+| SqueezeNet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
+| EfficientNet | :green_heart: (Script) | | | |
+| Regnet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
+| Resnest | :broken_heart: (Script) | | | |
+| Vision Transformer | :green_heart: (Script) | | | |
+| VGG 16 | :green_heart: (Script) | :green_heart: | :green_heart: | |
+| Wide Resnet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
+| RAFT | :broken_heart: (JIT) | | | |
+
+For more information refer to [MODEL TRACKING SHEET](https://docs.google.com/spreadsheets/d/15PcjKeHZIrB5LfDyuw7DGEEE8XnQEX2aX8lm8qbxV8A/edit#gid=0)
+
+### PyTorch Training Models
+
+| Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
+|---------------------|----------------------|----------|----------|-------------|
+| BERT | :broken_heart: | :broken_heart: | | |
+| FullyConnected | :green_heart: | :green_heart: | | |
+
+</details>
+
+<details>
+<summary>JAX Models</summary>
+
+
+### JAX Models
+
+| Models | JAX-MHLO lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
+|---------------------|----------------------|----------|----------|-------------|
+| DALL-E | :broken_heart: | :broken_heart: | | |
+| FullyConnected | :green_heart: | :green_heart: | | |
+
+</details>
+
+<details>
+<summary>TFLite Models</summary>
+
+### TFLite Models
+
+| Models | TOSA/LinAlg | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
+|---------------------|----------------------|----------|----------|-------------|
+| BERT | :broken_heart: | :broken_heart: | | |
+| FullyConnected | :green_heart: | :green_heart: | | |
+| albert | :green_heart: | :green_heart: | | |
+| asr_conformer | :green_heart: | :green_heart: | | |
+| bird_classifier | :green_heart: | :green_heart: | | |
+| cartoon_gan | :green_heart: | :green_heart: | | |
+| craft_text | :green_heart: | :green_heart: | | |
+| deeplab_v3 | :green_heart: | :green_heart: | | |
+| densenet | :green_heart: | :green_heart: | | |
+| east_text_detector | :green_heart: | :green_heart: | | |
+| efficientnet_lite0_int8 | :green_heart: | :green_heart: | | |
+| efficientnet | :green_heart: | :green_heart: | | |
+| gpt2 | :green_heart: | :green_heart: | | |
+| image_stylization | :green_heart: | :green_heart: | | |
+| inception_v4 | :green_heart: | :green_heart: | | |
+| inception_v4_uint8 | :green_heart: | :green_heart: | | |
+| lightning_fp16 | :green_heart: | :green_heart: | | |
+| lightning_i8 | :green_heart: | :green_heart: | | |
+| lightning | :green_heart: | :green_heart: | | |
+| magenta | :green_heart: | :green_heart: | | |
+| midas | :green_heart: | :green_heart: | | |
+| mirnet | :green_heart: | :green_heart: | | |
+| mnasnet | :green_heart: | :green_heart: | | |
+| mobilebert_edgetpu_s_float | :green_heart: | :green_heart: | | |
+| mobilebert_edgetpu_s_quant | :green_heart: | :green_heart: | | |
+| mobilebert | :green_heart: | :green_heart: | | |
+| mobilebert_tf2_float | :green_heart: | :green_heart: | | |
+| mobilebert_tf2_quant | :green_heart: | :green_heart: | | |
+| mobilenet_ssd_quant | :green_heart: | :green_heart: | | |
+| mobilenet_v1 | :green_heart: | :green_heart: | | |
+| mobilenet_v1_uint8 | :green_heart: | :green_heart: | | |
+| mobilenet_v2_int8 | :green_heart: | :green_heart: | | |
+| mobilenet_v2 | :green_heart: | :green_heart: | | |
+| mobilenet_v2_uint8 | :green_heart: | :green_heart: | | |
+| mobilenet_v3-large | :green_heart: | :green_heart: | | |
+| mobilenet_v3-large_uint8 | :green_heart: | :green_heart: | | |
+| mobilenet_v35-int8 | :green_heart: | :green_heart: | | |
+| nasnet | :green_heart: | :green_heart: | | |
+| person_detect | :green_heart: | :green_heart: | | |
+| posenet | :green_heart: | :green_heart: | | |
+| resnet_50_int8 | :green_heart: | :green_heart: | | |
+| rosetta | :green_heart: | :green_heart: | | |
+| spice | :green_heart: | :green_heart: | | |
+| squeezenet | :green_heart: | :green_heart: | | |
+| ssd_mobilenet_v1 | :green_heart: | :green_heart: | | |
+| ssd_mobilenet_v1_uint8 | :green_heart: | :green_heart: | | |
+| ssd_mobilenet_v2_fpnlite | :green_heart: | :green_heart: | | |
+| ssd_mobilenet_v2_fpnlite_uint8 | :green_heart: | :green_heart: | | |
+| ssd_mobilenet_v2_int8 | :green_heart: | :green_heart: | | |
+| ssd_mobilenet_v2 | :green_heart: | :green_heart: | | |
+| ssd_spaghettinet_large | :green_heart: | :green_heart: | | |
+| ssd_spaghettinet_large_uint8 | :green_heart: | :green_heart: | | |
+| visual_wake_words_i8 | :green_heart: | :green_heart: | | |
+
+</details>
+
+<details>
+<summary>TF Models</summary>
+
+### Tensorflow Models (Inference)
+
+| Hugging Face Models | tf-mhlo lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
+|---------------------|----------------------|----------|----------|-------------|
+| BERT | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
+| albert-base-v2 | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
+| DistilBERT | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
+| CamemBert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
+| ConvBert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
+| Deberta | | | | |
+| electra | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
+| funnel | | | | |
+| layoutlm | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
+| longformer | | | | |
+| mobile-bert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
+| remembert | | | | |
+| tapas | | | | |
+| flaubert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
+| roberta | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
+| xlm-roberta | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
+| mpnet | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
+
+</details>
tank/examples/bert_tf/seq_classification.py (new executable file, 83 lines)
@@ -0,0 +1,83 @@
+#!/usr/bin/env python
+from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
+import tensorflow as tf
+from shark.shark_inference import SharkInference
+from shark.parser import shark_args
+import argparse
+
+
+seq_parser = argparse.ArgumentParser(
+    description="Shark Sequence Classification."
+)
+seq_parser.add_argument(
+    "--hf_model_name",
+    type=str,
+    default="bert-base-uncased",
+    help="Hugging face model to run sequence classification.",
+)
+
+seq_args, unknown = seq_parser.parse_known_args()
+
+
+BATCH_SIZE = 1
+MAX_SEQUENCE_LENGTH = 16
+
+# Create a set of input signature.
+inputs_signature = [
+    tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
+    tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
+]
+
+# For supported models please see here:
+# https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForSequenceClassification
+
+
+def preprocess_input(text="This is just used to compile the model"):
+    tokenizer = AutoTokenizer.from_pretrained(seq_args.hf_model_name)
+    inputs = tokenizer(
+        text,
+        padding="max_length",
+        return_tensors="tf",
+        truncation=True,
+        max_length=MAX_SEQUENCE_LENGTH,
+    )
+    return inputs
+
+
+class SeqClassification(tf.Module):
+    def __init__(self, model_name):
+        super(SeqClassification, self).__init__()
+        self.m = TFAutoModelForSequenceClassification.from_pretrained(
+            model_name, output_attentions=False, num_labels=2
+        )
+        self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)[0]
+
+    @tf.function(input_signature=inputs_signature)
+    def forward(self, input_ids, attention_mask):
+        return tf.math.softmax(
+            self.m.predict(input_ids, attention_mask), axis=-1
+        )
+
+
+if __name__ == "__main__":
+    inputs = preprocess_input()
+    shark_module = SharkInference(
+        SeqClassification(seq_args.hf_model_name),
+        (inputs["input_ids"], inputs["attention_mask"]),
+    )
+    shark_module.set_frontend("tensorflow")
+    shark_module.compile()
+    print(f"Model has been successfully compiled on {shark_args.device}")
+
+    while True:
+        input_text = input(
+            "Enter the text to classify (press q or nothing to exit): "
+        )
+        if not input_text or input_text == "q":
+            break
+        inputs = preprocess_input(input_text)
+        print(
+            shark_module.forward(
+                (inputs["input_ids"], inputs["attention_mask"])
+            )
+        )
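The README removed further below shows how this script is invoked; a typical run with the script's default model (swap the device for "gpu" or "vulkan" to target other backends) looks roughly like:

```shell
# Compile bert-base-uncased for sequence classification on CPU, then type
# sentences at the prompt to get softmax scores back.
./seq_classification.py --hf_model_name="bert-base-uncased" --device="cpu"
```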
@@ -1,15 +0,0 @@
-## Running SharkInference on CPUs, GPUs and MAC.
-
-
-### Run the binary sequence_classification.
-#### The models supported are: [hugging face sequence classification](https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForSequenceClassification)
-```shell
-./seq_classification.py --hf_model_name="hf_model" --device="cpu" # Use gpu | vulkan
-```
-
-Once the model is compiled to run on the device mentioned, we can pass in text and
-get the logits.
-
-
-
-
@@ -1 +0,0 @@
-hf-internal-testing/tiny-random-flaubert,hf