Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-04-20 03:00:34 -04:00)

Compare commits: 20230812.8...20230928.9 (79 commits)
Commits in this range (SHA1):

d7399c8ee7  b6f8993dcc  94594542a9  82f833e87d  c9d6870105  4fec03a6cc  9a27f51378  ad1a0f35ff
6773278ec2  9a0efffcca  61c6f153d9  effd42e8f5  b5fbb1a8a0  ded74d09cd  79267931c1  9eceba69b7
ca609afb6a  11bdce9790  684943a4a6  b817bb8455  780f520f02  c61b6f8d65  c854208d49  c5dcfc1f13
bde63ee8ae  9681d494eb  ede6bf83e2  2c2693fb7d  1d31b2b2c6  d2f64eefa3  87ae14b6ff  1ccafa1fc1
4c3d8a0a7f  3601dc7c3b  671881cf87  4e9be6be59  9c8cbaf498  9e348a114e  51f90a4d56  310d5d0a49
9697981004  450c231171  07f6f4a2f7  610813c72f  8e3860c9e6  e37d6720eb  16160d9a7d  79075a1a07
db990826d3  7ee3e4ba5d  05889a8fe1  b87efe7686  82b462de3a  d8f0f7bade  79bd0b84a1  8738571d1e
a4c354ce54  cc53efa89f  9ae8bc921e  32eb78f0f9  cb509343d9  6da391c9b1  9dee7ae652  343dfd901c
57260b9c37  18e7d2d061  51a1009796  045c3c3852  0139dd58d9  c96571855a  4f61d69d86  531d447768
16f46f8de9  c4723f469f  d804f45a61  d22177f936  75e68f02f4  4dc9c59611  18801dcabc
.github/workflows/nightly.yml (vendored, 10 changed lines)

@@ -51,11 +51,11 @@ jobs:
run: |
./setup_venv.ps1
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe

- name: Upload Release Assets
id: upload-release-assets
@@ -104,7 +104,7 @@ jobs:
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html; fi
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
@@ -134,6 +134,8 @@ jobs:
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/${DATE}_$SHA
gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/nightly/
fi
export SHA=$(git log -1 --format='%h')
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/${DATE}_$SHA
rm -rf ./wheelhouse/nodai*

- name: Build and validate the SHARK Runtime package
@@ -144,7 +146,7 @@ jobs:
source shark.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models
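For reference, a rough Python equivalent of the `package_version` string the workflow computes with `printf '%(%Y%m%d)T.${{ github.run_number }}'` (a sketch only; `run_number` stands in for the GitHub-provided run counter):

```python
# Sketch: rebuild the nightly package version string locally.
# run_number is an assumed example value standing in for ${{ github.run_number }}.
from datetime import date

run_number = 9
package_version = f"{date.today():%Y%m%d}.{run_number}"
print(package_version)  # e.g. 20230928.9, matching the tag names in this compare view
```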
.gitignore (vendored, 6 changed lines)

@@ -193,3 +193,9 @@ stencil_annotator/
# For DocuChat
apps/language_models/langchain/user_path/
db_dir_UserData

# Embeded browser cache and other
apps/stable_diffusion/web/EBWebView/

# Llama2 tokenizer configs
llama2_tokenizer_configs/
.gitmodules (vendored, 2 changed lines)

@@ -1,4 +1,4 @@
[submodule "inference/thirdparty/shark-runtime"]
path = inference/thirdparty/shark-runtime
url =https://github.com/nod-ai/SHARK-Runtime.git
url =https://github.com/nod-ai/SRT.git
branch = shark-06032022
@@ -10,7 +10,7 @@ High Performance Machine Learning Distribution
<summary>Prerequisites - Drivers </summary>

#### Install your Windows hardware drivers
* [AMD RDNA Users] Download the latest driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-2-1).
* [AMD RDNA Users] Download the latest driver (23.2.1 is the oldest supported) [here](https://www.amd.com/en/support).
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)

@@ -170,7 +170,7 @@ python -m pip install --upgrade pip
This step pip installs SHARK and related packages on Linux Python 3.8, 3.10 and 3.11 and macOS / Windows Python 3.11

```shell
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
```

### Run shark tank model tests.
@@ -29,14 +29,8 @@ from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl


def brevitas〇matmul_rhs_group_quant〡shape(
lhs: List[int],
rhs: List[int],
rhs_scale: List[int],
rhs_zero_point: List[int],
rhs_bit_width: int,
rhs_group_size: int,
) -> List[int]:
# fmt: off
def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
if len(lhs) == 3 and len(rhs) == 2:
return [lhs[0], lhs[1], rhs[0]]
elif len(lhs) == 2 and len(rhs) == 2:
@@ -45,30 +39,21 @@ def brevitas〇matmul_rhs_group_quant〡shape(
raise ValueError("Input shapes not supported.")


def brevitas〇matmul_rhs_group_quant〡dtype(
lhs_rank_dtype: Tuple[int, int],
rhs_rank_dtype: Tuple[int, int],
rhs_scale_rank_dtype: Tuple[int, int],
rhs_zero_point_rank_dtype: Tuple[int, int],
rhs_bit_width: int,
rhs_group_size: int,
) -> int:
def quant〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
# output dtype is the dtype of the lhs float input
lhs_rank, lhs_dtype = lhs_rank_dtype
return lhs_dtype


def brevitas〇matmul_rhs_group_quant〡has_value_semantics(
lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size
) -> None:
def quant〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
return


brevitas_matmul_rhs_group_quant_library = [
brevitas〇matmul_rhs_group_quant〡shape,
brevitas〇matmul_rhs_group_quant〡dtype,
brevitas〇matmul_rhs_group_quant〡has_value_semantics,
]
quant〇matmul_rhs_group_quant〡shape,
quant〇matmul_rhs_group_quant〡dtype,
quant〇matmul_rhs_group_quant〡has_value_semantics]
# fmt: on

global_device = "cuda"
global_precision = "fp16"
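As a quick sanity check of the shape rule encoded above (a standalone sketch, not part of the diff; the body of the 2-D branch is elided in the hunk, so only the 3-D batched case is exercised here):

```python
# Standalone sketch mirroring the 3-D case of the matmul_rhs_group_quant shape function above.
def rhs_group_quant_matmul_shape(lhs, rhs):
    if len(lhs) == 3 and len(rhs) == 2:
        # batched lhs [B, M, K] against quantized rhs [N, K] -> [B, M, N]
        return [lhs[0], lhs[1], rhs[0]]
    raise ValueError("Input shapes not supported.")

assert rhs_group_quant_matmul_shape([4, 16, 64], [32, 64]) == [4, 16, 32]
```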
@@ -244,7 +229,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
ts_graph,
[*h2ogptCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
backend_legal_ops=["quant.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
@@ -252,7 +237,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
module,
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
@@ -46,6 +46,7 @@ def compile_stableLM(
model_vmfb_name,
device="cuda",
precision="fp32",
debug=False,
):
from shark.shark_inference import SharkInference

@@ -92,7 +93,7 @@ def compile_stableLM(
shark_module.compile()

path = shark_module.save_module(
vmfb_path.parent.absolute(), vmfb_path.stem
vmfb_path.parent.absolute(), vmfb_path.stem, debug=debug
)
print("Saved vmfb at ", str(path))
File diff suppressed because it is too large.

apps/language_models/shark_llama_cli.spec (new file, 94 lines)
@@ -0,0 +1,94 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import collect_submodules
from PyInstaller.utils.hooks import copy_metadata

import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)

datas = []
datas += collect_data_files('torch')
datas += copy_metadata('torch')
datas += copy_metadata('tqdm')
datas += copy_metadata('regex')
datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += copy_metadata('huggingface-hub')
datas += copy_metadata('sentencepiece')
datas += copy_metadata("pyyaml")
datas += collect_data_files("tokenizers")
datas += collect_data_files("tiktoken")
datas += collect_data_files("accelerate")
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('opencv-python')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('py-cpuinfo')
datas += collect_data_files("shark", include_py_files=True)
datas += collect_data_files("timm", include_py_files=True)
datas += collect_data_files("tqdm")
datas += collect_data_files("tkinter")
datas += collect_data_files("webview")
datas += collect_data_files("sentencepiece")
datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += collect_data_files("langchain")

binaries = []

block_cipher = None

hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]

a = Analysis(
['scripts/vicuna.py'],
pathex=['.'],
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)

exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='shark_llama_cli',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)
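A hypothetical way to drive the new spec file from Python rather than the `pyinstaller` CLI (the spec path and working directory are assumptions; the nightly workflow above invokes the stable-diffusion spec the same way via the CLI):

```python
# Hypothetical invocation of the new spec file; equivalent to running
# `pyinstaller apps/language_models/shark_llama_cli.spec` from the repository root.
import PyInstaller.__main__

PyInstaller.__main__.run(["apps/language_models/shark_llama_cli.spec"])
```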
@@ -47,7 +47,7 @@ from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
)
from apps.language_models.src.model_wrappers.vicuna_model import (
FirstVicuna,
SecondVicuna,
SecondVicuna7B,
)
from apps.language_models.utils import (
get_vmfb_from_path,
@@ -57,8 +57,6 @@ from shark.shark_importer import get_f16_inputs
from shark.shark_importer import import_with_fx
from shark.shark_inference import SharkInference

from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
@@ -1,9 +1,6 @@
import torch
from transformers import AutoModelForCausalLM

from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl


class FirstVicuna(torch.nn.Module):
def __init__(
@@ -21,12 +18,18 @@ class FirstVicuna(torch.nn.Module):
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)

print("First Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float32,
dtype=torch.float16 if precision == "int4" else torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
@@ -48,7 +51,7 @@ class FirstVicuna(torch.nn.Module):
return tuple(return_vals)


class SecondVicuna(torch.nn.Module):
class SecondVicuna7B(torch.nn.Module):
def __init__(
self,
model_path,
@@ -64,12 +67,18 @@ class SecondVicuna(torch.nn.Module):
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)

print("Second Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float32,
dtype=torch.float16 if precision == "int4" else torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
@@ -148,8 +157,6 @@ class SecondVicuna(torch.nn.Module):
i63,
i64,
):
# input_ids = input_tuple[0]
# input_tuple = torch.unbind(pkv, dim=0)
token = i0
past_key_values = (
(i1, i2),
@@ -290,6 +297,833 @@ class SecondVicuna(torch.nn.Module):
return tuple(return_vals)

class SecondVicuna13B(torch.nn.Module):
def __init__(
self,
model_path,
precision="int8",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)

print("Second Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float16 if precision == "int4" else torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")

def forward(
self,
i0,
i1,
i2,
# ... i3 through i78 continue in the same one-argument-per-line pattern (elided here for brevity) ...
i79,
i80,
):
token = i0
past_key_values = (
(i1, i2),
(i3, i4),
# ... pairs (i5, i6) through (i77, i78) continue in the same pattern; 40 (key, value) pairs in total ...
(i79, i80),
)
op = self.model(
input_ids=token, use_cache=True, past_key_values=past_key_values
)
return_vals = []
return_vals.append(op.logits)
temp_past_key_values = op.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0])
return_vals.append(item[1])
return tuple(return_vals)

class SecondVicuna70B(torch.nn.Module):
def __init__(
self,
model_path,
precision="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)

print("Second Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float16,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")

def forward(
self,
i0,
i1,
i2,
# ... i3 through i158 continue in the same one-argument-per-line pattern (elided here for brevity) ...
i159,
i160,
):
token = i0
past_key_values = (
(i1, i2),
(i3, i4),
# ... pairs (i5, i6) through (i157, i158) continue in the same pattern; 80 (key, value) pairs in total ...
(i159, i160),
)
op = self.model(
input_ids=token, use_cache=True, past_key_values=past_key_values
)
return_vals = []
return_vals.append(op.logits)
temp_past_key_values = op.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0])
return_vals.append(item[1])
return tuple(return_vals)

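The flattened-argument convention used by the classes above (one token tensor followed by 2 x num_layers key/value tensors) can be packed and unpacked generically; a small sketch with dummy tensors, where the layer count and shapes are illustrative only:

```python
# Sketch: pack/unpack a flattened KV cache the way SecondVicuna13B/70B do
# (40 layers for 13B, 80 for 70B; two tensors per layer). Shapes are dummies.
import torch

num_layers = 40  # illustrative; the 70B variant uses 80
flat = [torch.zeros(1)] + [torch.zeros(1, 2, 3) for _ in range(2 * num_layers)]

token = flat[0]
past_key_values = tuple(
    (flat[1 + 2 * i], flat[2 + 2 * i]) for i in range(num_layers)
)
assert len(past_key_values) == num_layers

# Re-flatten outputs the same way the forward() methods do.
return_vals = [torch.zeros(1)]  # stands in for op.logits
for k, v in past_key_values:
    return_vals.extend([k, v])
assert len(return_vals) == 1 + 2 * num_layers
```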
class CombinedModel(torch.nn.Module):
def __init__(
self,
@@ -298,7 +1132,8 @@ class CombinedModel(torch.nn.Module):
):
super().__init__()
self.first_vicuna = FirstVicuna(first_vicuna_model_path)
self.second_vicuna = SecondVicuna(second_vicuna_model_path)
# NOT using this path for 13B currently, hence using `SecondVicuna7B`.
self.second_vicuna = SecondVicuna7B(second_vicuna_model_path)

def forward(self, input_ids):
first_output = self.first_vicuna(input_ids=input_ids)
@@ -3,7 +3,10 @@ from abc import ABC, abstractmethod

class SharkLLMBase(ABC):
def __init__(
self, model_name, hf_model_path=None, max_num_tokens=512
self,
model_name,
hf_model_path=None,
max_num_tokens=512,
) -> None:
self.model_name = model_name
self.hf_model_path = hf_model_path
@@ -28,7 +28,9 @@ parser = argparse.ArgumentParser(
description="runs a falcon model",
)

parser.add_argument("--falcon_variant_to_use", default="7b", help="7b, 40b")
parser.add_argument(
"--falcon_variant_to_use", default="7b", help="7b, 40b, 180b"
)
parser.add_argument(
"--precision", "-p", default="fp16", help="fp32, fp16, int8, int4"
)
@@ -49,7 +51,7 @@ parser.add_argument(
)
parser.add_argument(
"--load_mlir_from_shark_tank",
default=False,
default=True,
action=argparse.BooleanOptionalAction,
help="download precompile mlir from shark tank",
)
@@ -59,32 +61,52 @@ parser.add_argument(
action=argparse.BooleanOptionalAction,
help="Run model in cli mode",
)
parser.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication token for falcon-180B model.",
)


class Falcon(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path,
hf_model_path="tiiuae/falcon-7b-instruct",
hf_auth_token: str = None,
max_num_tokens=150,
device="cuda",
precision="fp32",
falcon_mlir_path=None,
falcon_vmfb_path=None,
debug=False,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
print("hf_model_path: ", self.hf_model_path)

if "180b" in self.model_name and hf_auth_token == None:
raise ValueError(
""" HF auth token required for falcon-180b. Pass it using
--hf_auth_token flag. You can ask for the access to the model
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
)
self.hf_auth_token = hf_auth_token
self.max_padding_length = 100
self.device = device
self.precision = precision
self.falcon_vmfb_path = falcon_vmfb_path
self.falcon_mlir_path = falcon_mlir_path
self.debug = debug
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
self.src_model = self.get_src_model()
self.shark_model = self.compile()

def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path, trust_remote_code=True
self.hf_model_path,
trust_remote_code=True,
token=self.hf_auth_token,
)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = 11
@@ -92,13 +114,17 @@ class Falcon(SharkLLMBase):

def get_src_model(self):
print("Loading src model: ", self.model_name)
kwargs = {"torch_dtype": torch.float, "trust_remote_code": True}
kwargs = {
"torch_dtype": torch.float,
"trust_remote_code": True,
"token": self.hf_auth_token,
}
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
return falcon_model

def compile_falcon(self):
def compile(self):
if args.use_precompiled_model:
if not self.falcon_vmfb_path.exists():
# Downloading VMFB from shark_tank
@@ -120,37 +146,39 @@ class Falcon(SharkLLMBase):
if vmfb is not None:
return vmfb

print(
f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}. Trying to work with"
f"[DEBUG] mlir path { self.falcon_mlir_path} {'exists' if self.falcon_mlir_path.exists() else 'does not exist'}"
)
print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
if self.falcon_mlir_path.exists():
print(f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}")
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
else:
mlir_generated = False
# Downloading MLIR from shark_tank
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ ".mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
print(
f"[DEBUG] mlir not found at {self.falcon_mlir_path.absolute()}"
)
if self.falcon_mlir_path.exists():
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.falcon_mlir_path.absolute()}"
" after downloading! Please check path and try again"
if args.load_mlir_from_shark_tank:
# Downloading MLIR from shark_tank
print(f"[DEBUG] Trying to download mlir from shark_tank")
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ ".mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
)
if self.falcon_mlir_path.exists():
print(
f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}"
)
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True

if not mlir_generated:
print(f"[DEBUG] generating MLIR locally")
compilation_input_ids = torch.randint(
low=1, high=10000, size=(1, 100)
)
@@ -189,10 +217,9 @@ class Falcon(SharkLLMBase):
bytecode = bytecode_stream.getvalue()
del module

print(f"[DEBUG] writing mlir to file")
with open(f"{self.model_name}.mlir", "wb") as f_:
with redirect_stdout(f_):
print(module.operation.get_asm())
f_ = open(self.falcon_mlir_path, "wb")
f_.write(bytecode)
print("Saved falcon mlir at ", str(self.falcon_mlir_path))
f_.close()

shark_module = SharkInference(
@@ -202,22 +229,17 @@ class Falcon(SharkLLMBase):
self.falcon_vmfb_path.parent.absolute(),
self.falcon_vmfb_path.stem,
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--iree-spirv-index-bits=64",
],
debug=self.debug,
)
print("Saved falcon vmfb at ", str(path))
shark_module.load_module(path)

return shark_module

def compile(self):
falcon_shark_model = self.compile_falcon()
return falcon_shark_model

def generate(self, prompt):
model_inputs = self.tokenizer(
prompt,
@@ -466,11 +488,16 @@ if __name__ == "__main__":
else Path(args.falcon_vmfb_path)
)

if args.falcon_variant_to_use == "180b":
hf_model_path_value = "tiiuae/falcon-180B-chat"
else:
hf_model_path_value = (
"tiiuae/falcon-" + args.falcon_variant_to_use + "-instruct"
)

falcon = Falcon(
"falcon_" + args.falcon_variant_to_use,
hf_model_path="tiiuae/falcon-"
+ args.falcon_variant_to_use
+ "-instruct",
model_name="falcon_" + args.falcon_variant_to_use,
hf_model_path=hf_model_path_value,
device=args.device,
precision=args.precision,
falcon_mlir_path=falcon_mlir_path,
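The variant-to-checkpoint selection added above can be exercised on its own; a standalone sketch of the same logic, with variant strings following the `--falcon_variant_to_use` help text:

```python
# Standalone sketch of the hf_model_path selection introduced above.
def falcon_hf_model_path(variant: str) -> str:
    if variant == "180b":
        return "tiiuae/falcon-180B-chat"
    return "tiiuae/falcon-" + variant + "-instruct"

assert falcon_hf_model_path("7b") == "tiiuae/falcon-7b-instruct"
assert falcon_hf_model_path("40b") == "tiiuae/falcon-40b-instruct"
assert falcon_hf_model_path("180b") == "tiiuae/falcon-180B-chat"
```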
@@ -136,7 +136,8 @@ from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl


def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
# fmt: off
def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
if len(lhs) == 3 and len(rhs) == 2:
return [lhs[0], lhs[1], rhs[0]]
elif len(lhs) == 2 and len(rhs) == 2:
@@ -145,20 +146,21 @@ def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rh
raise ValueError("Input shapes not supported.")


def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
def quant〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
# output dtype is the dtype of the lhs float input
lhs_rank, lhs_dtype = lhs_rank_dtype
return lhs_dtype


def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
def quant〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
return


brevitas_matmul_rhs_group_quant_library = [
brevitas〇matmul_rhs_group_quant〡shape,
brevitas〇matmul_rhs_group_quant〡dtype,
brevitas〇matmul_rhs_group_quant〡has_value_semantics]
quant〇matmul_rhs_group_quant〡shape,
quant〇matmul_rhs_group_quant〡dtype,
quant〇matmul_rhs_group_quant〡has_value_semantics]
# fmt: on


def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
@@ -176,7 +178,7 @@ def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):


def compile_module(
shark_module, extended_model_name, generate_vmfb, extra_args=[]
shark_module, extended_model_name, generate_vmfb, extra_args=[], debug=False,
):
if generate_vmfb:
vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
@@ -188,7 +190,7 @@ def compile_module(
"No vmfb found. Compiling and saving to {}".format(vmfb_path)
)
path = shark_module.save_module(
os.getcwd(), extended_model_name, extra_args
os.getcwd(), extended_model_name, extra_args, debug=debug
)
shark_module.load_module(path, extra_args=extra_args)
else:
@@ -197,7 +199,7 @@ def compile_module(


def compile_int_precision(
model, inputs, precision, device, generate_vmfb, extended_model_name
model, inputs, precision, device, generate_vmfb, extended_model_name, debug=False
):
torchscript_module = import_with_fx(
model,
@@ -209,7 +211,7 @@ def compile_int_precision(
torchscript_module,
inputs,
output_type="torch",
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
backend_legal_ops=["quant.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
@@ -217,7 +219,7 @@ def compile_int_precision(
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
mlir_module,
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
from contextlib import redirect_stdout
@@ -249,6 +251,7 @@ def compile_int_precision(
extended_model_name=extended_model_name,
generate_vmfb=generate_vmfb,
extra_args=extra_args,
debug=debug,
),
bytecode,
)
@@ -292,6 +295,7 @@ def shark_compile_through_fx_int(
device,
generate_or_load_vmfb,
extended_model_name,
debug,
)
extra_args = [
"--iree-hal-dump-executable-sources-to=ies",
@@ -32,11 +32,13 @@ class SharkStableLM(SharkLLMBase):
max_num_tokens=512,
device="cuda",
precision="fp32",
debug="False",
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_len = 256
self.device = device
self.precision = precision
self.debug = debug
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()

@@ -111,7 +113,7 @@ class SharkStableLM(SharkLLMBase):
shark_module.compile()

path = shark_module.save_module(
vmfb_path.parent.absolute(), vmfb_path.stem
vmfb_path.parent.absolute(), vmfb_path.stem, debug=self.debug
)
print("Saved vmfb at ", str(path))
@@ -8,7 +8,7 @@ from shark.shark_downloader import download_public_file

# expects a Path / str as arg
# returns None if path not found or SharkInference module
def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
def get_vmfb_from_path(vmfb_path, device, mlir_dialect, device_id=None):
if not isinstance(vmfb_path, Path):
vmfb_path = Path(vmfb_path)

@@ -20,7 +20,7 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
print("Loading vmfb from: ", vmfb_path)
print("Device from get_vmfb_from_path - ", device)
shark_module = SharkInference(
None, device=device, mlir_dialect=mlir_dialect
None, device=device, mlir_dialect=mlir_dialect, device_idx=device_id
)
shark_module.load_module(vmfb_path)
print("Successfully loaded vmfb")
@@ -28,7 +28,13 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect):


def get_vmfb_from_config(
shark_container, model, precision, device, vmfb_path, padding=None
shark_container,
model,
precision,
device,
vmfb_path,
padding=None,
device_id=None,
):
vmfb_url = (
f"gs://shark_tank/{shark_container}/{model}_{precision}_{device}"
@@ -37,4 +43,6 @@ def get_vmfb_from_config(
vmfb_url = vmfb_url + f"_{padding}"
vmfb_url = vmfb_url + ".vmfb"
download_public_file(vmfb_url, vmfb_path.absolute(), single_file=True)
return get_vmfb_from_path(vmfb_path, device, "tm_tensor")
return get_vmfb_from_path(
vmfb_path, device, "tm_tensor", device_id=device_id
)
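The URL that get_vmfb_from_config builds follows a fixed shark_tank naming scheme; a standalone sketch of that scheme (the container, model, and precision values are illustrative, and the padding suffix appears to be appended only when padding is provided):

```python
# Sketch of the shark_tank vmfb URL scheme used by get_vmfb_from_config above.
def vmfb_url(shark_container, model, precision, device, padding=None):
    url = f"gs://shark_tank/{shark_container}/{model}_{precision}_{device}"
    if padding is not None:
        url = url + f"_{padding}"
    return url + ".vmfb"

print(vmfb_url("falcon", "falcon_7b", "fp16", "cuda", padding=100))
# gs://shark_tank/falcon/falcon_7b_fp16_cuda_100.vmfb
```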
@@ -7,16 +7,16 @@ Compile Commands FP32/FP16:

```shell
Vulkan AMD:
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux /path/to/input/mlir -o /path/to/output/vmfb

# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
# use --iree-input-type=auto or "mhlo_legacy" or "stablehlo" for TF models

CUDA NVIDIA:
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda /path/to/input/mlir -o /path/to/output/vmfb

CPU:
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu /path/to/input/mlir -o /path/to/output/vmfb
```
@@ -34,7 +34,7 @@ from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor
from diffusers.models.attention_processor import LoRAXFormersAttnProcessor

import torch_mlir
from torch_mlir.dynamo import make_simple_dynamo_backend
@@ -287,7 +287,7 @@ def lora_train(
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]

lora_attn_procs[name] = LoRACrossAttnProcessor(
lora_attn_procs[name] = LoRAXFormersAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
)
@@ -15,8 +15,8 @@ pathex = [

# datafiles for pyinstaller
datas = []
datas += collect_data_files("torch")
datas += copy_metadata("torch")
datas += copy_metadata("tokenizers")
datas += copy_metadata("tqdm")
datas += copy_metadata("regex")
datas += copy_metadata("requests")
@@ -30,20 +30,21 @@ datas += copy_metadata("safetensors")
datas += copy_metadata("Pillow")
datas += copy_metadata("sentencepiece")
datas += copy_metadata("pyyaml")
datas += copy_metadata("huggingface-hub")
datas += collect_data_files("torch")
datas += collect_data_files("tokenizers")
datas += collect_data_files("tiktoken")
datas += collect_data_files("accelerate")
datas += collect_data_files("diffusers")
datas += collect_data_files("transformers")
datas += collect_data_files("pytorch_lightning")
datas += collect_data_files("opencv_python")
datas += collect_data_files("skimage")
datas += collect_data_files("gradio")
datas += collect_data_files("gradio_client")
datas += collect_data_files("iree")
datas += collect_data_files("google_cloud_storage")
datas += collect_data_files("shark", include_py_files=True)
datas += collect_data_files("timm", include_py_files=True)
datas += collect_data_files("tqdm")
datas += collect_data_files("tkinter")
datas += collect_data_files("webview")
datas += collect_data_files("sentencepiece")
@@ -51,6 +52,7 @@ datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += collect_data_files("langchain")
datas += collect_data_files("cv2")
datas += [
("src/utils/resources/prompts.json", "resources"),
("src/utils/resources/model_db.json", "resources"),
@@ -72,7 +74,11 @@ datas += [
# hidden imports for pyinstaller
hiddenimports = ["shark", "shark.shark_inference", "apps"]
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
blacklist = ["tests", "convert"]
hiddenimports += [
x for x in collect_submodules("transformers") if "tests" not in x
x
for x in collect_submodules("transformers")
if not any(kw in x for kw in blacklist)
]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
hiddenimports += ["iree._runtime", "iree.compiler._mlir_libs._mlir.ir"]
@@ -177,9 +177,11 @@ class SharkifyStableDiffusionModel:
"unet",
"unet512",
"stencil_unet",
"stencil_unet_512",
"vae",
"vae_encode",
"stencil_adaptor",
"stencil_adaptor_512",
]
index = 0
for model in sub_model_list:
@@ -339,7 +341,7 @@ class SharkifyStableDiffusionModel:
)
return shark_vae, vae_mlir

def get_controlled_unet(self):
def get_controlled_unet(self, use_large=False):
class ControlledUnetModel(torch.nn.Module):
def __init__(
self,
@@ -415,6 +417,16 @@ class SharkifyStableDiffusionModel:
is_f16 = True if self.precision == "fp16" else False

inputs = tuple(self.inputs["unet"])
model_name = "stencil_unet"
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (
inputs[:2]
+ (torch.nn.functional.pad(inputs[2], pad),)
+ inputs[3:]
)
model_name = "stencil_unet_512"
input_mask = [
True,
True,
@@ -437,19 +449,19 @@ class SharkifyStableDiffusionModel:
shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["stencil_unet"],
extended_model_name=self.model_name[model_name],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="stencil_unet",
model_name=model_name,
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_controlled_unet, controlled_unet_mlir

def get_control_net(self):
def get_control_net(self, use_large=False):
class StencilControlNetModel(torch.nn.Module):
def __init__(
self, model_id=self.use_stencil, low_cpu_mem_usage=False
@@ -497,17 +509,34 @@ class SharkifyStableDiffusionModel:
is_f16 = True if self.precision == "fp16" else False

inputs = tuple(self.inputs["stencil_adaptor"])
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (
inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3],
)
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor_512"]
)
else:
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor"]
)
input_mask = [True, True, True, True]
model_name = "stencil_adaptor" if use_large else "stencil_adaptor_512"
shark_cnet, cnet_mlir = compile_through_fx(
scnet,
inputs,
extended_model_name=self.model_name["stencil_adaptor"],
extended_model_name=self.model_name[model_name],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="stencil_adaptor",
model_name=model_name,
precision=self.precision,
return_mlir=self.return_mlir,
)
@@ -748,7 +777,7 @@ class SharkifyStableDiffusionModel:
else:
return self.get_unet(use_large=use_large)
else:
return self.get_controlled_unet()
return self.get_controlled_unet(use_large=use_large)

def vae_encode(self):
try:
@@ -847,12 +876,14 @@ class SharkifyStableDiffusionModel:
except Exception as e:
sys.exit(e)

def controlnet(self):
def controlnet(self, use_large=False):
try:
self.inputs["stencil_adaptor"] = self.get_input_info_for(
base_models["stencil_adaptor"]
)
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net()
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net(
use_large=use_large
)

check_compilation(compiled_stencil_adaptor, "Stencil")
if self.return_mlir:
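The use_large branches above pad the text-embedding input's second dimension up to 512 before compiling the *_512 model variants; the same padding expression in isolation, with dummy shapes that are illustrative only:

```python
# Sketch of the pad-to-512 logic used for the stencil_unet_512 / stencil_adaptor_512 inputs.
import torch

text_embeddings = torch.zeros(2, 77, 768)  # dummy stand-in for inputs[2]
pad = (0, 0) * (len(text_embeddings.shape) - 2)
pad = pad + (0, 512 - text_embeddings.shape[1])
padded = torch.nn.functional.pad(text_embeddings, pad)
assert padded.shape == (2, 512, 768)
```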
@@ -84,13 +84,35 @@ class Image2ImagePipeline(StableDiffusionPipeline):
num_inference_steps,
strength,
dtype,
resample_type,
):
# Pre process image -> get image encoded -> process latents

# TODO: process with variable HxW combos

# Pre process image
image = image.resize((width, height))
# Pre-process image
if resample_type == "Lanczos":
resample_type = Image.LANCZOS
elif resample_type == "Nearest Neighbor":
resample_type = Image.NEAREST
elif resample_type == "Bilinear":
resample_type = Image.BILINEAR
elif resample_type == "Bicubic":
resample_type = Image.BICUBIC
elif resample_type == "Adaptive":
resample_type = Image.ADAPTIVE
elif resample_type == "Antialias":
resample_type = Image.ANTIALIAS
elif resample_type == "Box":
resample_type = Image.BOX
elif resample_type == "Affine":
resample_type = Image.AFFINE
elif resample_type == "Cubic":
resample_type = Image.CUBIC
else:  # Fallback to Lanczos
resample_type = Image.LANCZOS

image = image.resize((width, height), resample=resample_type)
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
image_arr = image_arr / 255.0
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
@@ -147,6 +169,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
resample_type,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
@@ -186,6 +209,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
num_inference_steps=num_inference_steps,
strength=strength,
dtype=dtype,
resample_type=resample_type,
)

# Get Image latents
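The if/elif chain above maps UI strings to PIL resampling filters; an equivalent dict-based sketch, restricted to the filter constants that exist across current Pillow releases (the remaining names in the chain are older PIL constants), with unknown names falling back to Lanczos as the chain does:

```python
# Sketch: dict lookup equivalent to the resample_type chain above,
# limited to filters available in current Pillow.
from PIL import Image

RESAMPLE_FILTERS = {
    "Lanczos": Image.LANCZOS,
    "Nearest Neighbor": Image.NEAREST,
    "Bilinear": Image.BILINEAR,
    "Bicubic": Image.BICUBIC,
    "Box": Image.BOX,
}

img = Image.new("RGB", (768, 512))
resample = RESAMPLE_FILTERS.get("Bicubic", Image.LANCZOS)
resized = img.resize((512, 512), resample=resample)
assert resized.size == (512, 512)
```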
@@ -58,6 +58,7 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
self.controlnet = None
|
||||
self.controlnet_512 = None
|
||||
|
||||
def load_controlnet(self):
|
||||
if self.controlnet is not None:
|
||||
@@ -68,6 +69,15 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
del self.controlnet
|
||||
self.controlnet = None
|
||||
|
||||
def load_controlnet_512(self):
|
||||
if self.controlnet_512 is not None:
|
||||
return
|
||||
self.controlnet_512 = self.sd_model.controlnet(use_large=True)
|
||||
|
||||
def unload_controlnet_512(self):
|
||||
del self.controlnet_512
|
||||
self.controlnet_512 = None
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
@@ -111,8 +121,12 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
latent_history = [latents]
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
self.load_unet()
|
||||
self.load_controlnet()
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
self.load_unet()
|
||||
self.load_controlnet()
|
||||
else:
|
||||
self.load_unet_512()
|
||||
self.load_controlnet_512()
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
timestep = torch.tensor([t]).to(dtype)
|
||||
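The switch above (and the matching one in the denoising loop below) exists presumably because the compiled UNet/ControlNet modules are shape-specialized: prompts whose embedding length fits `model_max_length` use the base modules, longer ones use the separately compiled `*_512` variants. A condensed sketch of that dispatch, with hypothetical helper names:

```python
def pick_unet_and_controlnet(pipe, token_len: int):
    # Compiled modules have fixed input shapes, so a longer text embedding
    # needs the variants that were compiled for the larger sequence length.
    if token_len <= pipe.model_max_length:
        pipe.load_unet()
        pipe.load_controlnet()
        return pipe.unet, pipe.controlnet
    pipe.load_unet_512()
    pipe.load_controlnet_512()
    return pipe.unet_512, pipe.controlnet_512
```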
@@ -135,43 +149,82 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
).to(dtype)
|
||||
else:
|
||||
latent_model_input_1 = latent_model_input
|
||||
control = self.controlnet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
control = self.controlnet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
else:
|
||||
control = self.controlnet_512(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
timestep = timestep.detach().numpy()
|
||||
# Profiling Unet.
|
||||
profile_device = start_profiling(file_path="unet.rdc")
|
||||
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
control[0],
|
||||
control[1],
|
||||
control[2],
|
||||
control[3],
|
||||
control[4],
|
||||
control[5],
|
||||
control[6],
|
||||
control[7],
|
||||
control[8],
|
||||
control[9],
|
||||
control[10],
|
||||
control[11],
|
||||
control[12],
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
control[0],
|
||||
control[1],
|
||||
control[2],
|
||||
control[3],
|
||||
control[4],
|
||||
control[5],
|
||||
control[6],
|
||||
control[7],
|
||||
control[8],
|
||||
control[9],
|
||||
control[10],
|
||||
control[11],
|
||||
control[12],
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
else:
|
||||
print(self.unet_512)
|
||||
noise_pred = self.unet_512(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
control[0],
|
||||
control[1],
|
||||
control[2],
|
||||
control[3],
|
||||
control[4],
|
||||
control[5],
|
||||
control[6],
|
||||
control[7],
|
||||
control[8],
|
||||
control[9],
|
||||
control[10],
|
||||
control[11],
|
||||
control[12],
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
end_profiling(profile_device)
|
||||
|
||||
if cpu_scheduling:
|
||||
@@ -191,7 +244,9 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
self.unload_unet_512()
|
||||
self.unload_controlnet()
|
||||
self.unload_controlnet_512()
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
@@ -218,6 +273,7 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
use_stencil,
|
||||
resample_type,
|
||||
):
|
||||
# Control Embedding check & conversion
|
||||
# TODO: 1. Change `num_images_per_prompt`.
|
||||
|
||||
@@ -109,7 +109,7 @@ def load_lower_configs(base_model_id=None):
spec = spec.split("-")[0]

if args.annotation_model == "vae":
if not spec or spec in ["rdna3", "sm_80"]:
if not spec or spec in ["sm_80"]:
config_name = (
f"{args.annotation_model}_{args.precision}_{device}.json"
)
@@ -158,9 +158,9 @@ def load_lower_configs(base_model_id=None):
f"{spec}.json"
)

full_gs_url = config_bucket + config_name
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
print("Loading lowering config file from ", lowering_config_dir)
full_gs_url = config_bucket + config_name
download_public_file(full_gs_url, lowering_config_dir, True)
return lowering_config_dir

@@ -132,6 +132,57 @@ p.add_argument(
"img2img.",
)

p.add_argument(
"--use_hiresfix",
type=bool,
default=False,
help="Use Hires Fix to do higher resolution images, while trying to "
"avoid the issues that come with it. This is accomplished by first "
"generating an image using txt2img, then running it through img2img.",
)

p.add_argument(
"--hiresfix_height",
type=int,
default=768,
choices=range(128, 769, 8),
help="The height of the Hires Fix image.",
)

p.add_argument(
"--hiresfix_width",
type=int,
default=768,
choices=range(128, 769, 8),
help="The width of the Hires Fix image.",
)

p.add_argument(
"--hiresfix_strength",
type=float,
default=0.6,
help="The denoising strength to apply for the Hires Fix.",
)

p.add_argument(
"--resample_type",
type=str,
default="Nearest Neighbor",
choices=[
"Lanczos",
"Nearest Neighbor",
"Bilinear",
"Bicubic",
"Adaptive",
"Antialias",
"Box",
"Affine",
"Cubic",
],
help="The resample type to use when resizing an image before being run "
"through stable diffusion.",
)

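Two properties of the new flags are worth spelling out: `choices=range(128, 769, 8)` restricts the Hires Fix dimensions to multiples of 8 between 128 and 768, and `--use_hiresfix` uses `type=bool`, which in argparse treats any non-empty string as True (so `--use_hiresfix False` would still enable it). A small standalone sketch, not the actual `stable_args.py`, using `BooleanOptionalAction` to sidestep that pitfall:

```python
import argparse

p = argparse.ArgumentParser()
# Paired --use_hiresfix / --no-use_hiresfix flags instead of type=bool.
p.add_argument("--use_hiresfix", default=False, action=argparse.BooleanOptionalAction)
p.add_argument("--hiresfix_height", type=int, default=768, choices=range(128, 769, 8))
p.add_argument("--hiresfix_width", type=int, default=768, choices=range(128, 769, 8))
p.add_argument("--hiresfix_strength", type=float, default=0.6)

args = p.parse_args(["--use_hiresfix", "--hiresfix_height", "640"])
print(args.use_hiresfix, args.hiresfix_height)  # True 640
# "--hiresfix_height 641" would be rejected: 641 is not in range(128, 769, 8).
```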
##############################################################################
|
||||
# Stable Diffusion Training Params
|
||||
##############################################################################
|
||||
@@ -519,6 +570,14 @@ p.add_argument(
|
||||
"in shark importer. Does nothing if import_mlir is false (the default).",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--compile_debug",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag to toggle debug assert/verify flags for imported IR in the"
|
||||
"iree-compiler. Default to false.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--iree_constant_folding",
|
||||
default=True,
|
||||
@@ -574,6 +633,13 @@ p.add_argument(
|
||||
help="Flag for enabling rest API.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--debug",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for enabling debugging log in WebUI.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_gallery",
|
||||
default=True,
|
||||
|
||||
@@ -25,7 +25,7 @@ from shark.iree_utils.vulkan_utils import (
get_iree_vulkan_runtime_flags,
)
from shark.iree_utils.metal_utils import get_metal_target_triple
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from shark.iree_utils.gpu_utils import get_cuda_sm_cc, get_iree_rocm_args
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.resources import opt_flags
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
@@ -78,7 +78,7 @@ def _compile_module(shark_module, model_name, extra_args=[]):
)
)
path = shark_module.save_module(
os.getcwd(), model_name, extra_args
os.getcwd(), model_name, extra_args, debug=args.compile_debug
)
shark_module.load_module(path, extra_args=extra_args)
else:
@@ -470,12 +470,25 @@ def get_available_devices():
set_iree_runtime_flags()

available_devices = []
vulkan_devices = get_devices_by_name("vulkan")
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
)

vulkaninfo_list = get_all_vulkan_devices()
vulkan_devices = []
id = 0
for device in vulkaninfo_list:
vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
id += 1
if id != 0:
print(f"vulkan devices are available.")
available_devices.extend(vulkan_devices)
metal_devices = get_devices_by_name("metal")
available_devices.extend(metal_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
rocm_devices = get_devices_by_name("rocm")
available_devices.extend(rocm_devices)
cpu_device = get_devices_by_name("cpu-sync")
available_devices.extend(cpu_device)
cpu_device = get_devices_by_name("cpu-task")
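The entries built here have the form `"<device name> => vulkan://<index>"`; other parts of the UI later split that string back into a driver name and a device id. A tiny round-trip sketch (illustrative helper, not part of the codebase):

```python
def parse_device_entry(entry: str):
    # "AMD Radeon RX 7900 XTX => vulkan://1" -> ("vulkan", 1)
    uri = entry.split("=>")[-1].strip()
    driver, _, index = uri.partition("://")
    return driver, int(index) if index else 0

print(parse_device_entry("AMD Radeon RX 7900 XTX => vulkan://1"))  # ('vulkan', 1)
```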
@@ -499,7 +512,10 @@ def get_opt_flags(model, precision="fp16"):
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
|
||||
if "rocm" in args.device:
|
||||
rocm_args = get_iree_rocm_args()
|
||||
iree_flags.extend(rocm_args)
|
||||
print(iree_flags)
|
||||
if args.iree_constant_folding == False:
|
||||
iree_flags.append("--iree-opt-const-expr-hoisting=False")
|
||||
iree_flags.append(
|
||||
@@ -572,7 +588,7 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
|
||||
)
|
||||
num_in_channels = 9 if is_inpaint else 4
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path=custom_weights,
|
||||
checkpoint_path_or_dict=custom_weights,
|
||||
extract_ema=extract_ema,
|
||||
from_safetensors=from_safetensors,
|
||||
num_in_channels=num_in_channels,
|
||||
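The change above tracks a rename in the diffusers checkpoint-conversion helper: `checkpoint_path` became `checkpoint_path_or_dict`, which also accepts an already-loaded state dict. A hedged call sketch against a recent diffusers release (the import path and option values can vary by version; the file name is a placeholder):

```python
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)

pipe = download_from_original_stable_diffusion_ckpt(
    checkpoint_path_or_dict="custom_model.safetensors",  # placeholder path
    from_safetensors=True,
    extract_ema=False,
    num_in_channels=4,  # 9 for inpainting checkpoints, as computed above
)
```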
@@ -822,6 +838,8 @@ def clear_all():
|
||||
elif os.name == "unix":
|
||||
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
|
||||
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
|
||||
if args.local_tank_cache != "":
|
||||
shutil.rmtree(args.local_tank_cache)
|
||||
|
||||
|
||||
def get_generated_imgs_path() -> Path:
|
||||
|
||||
@@ -1,6 +1,7 @@
from multiprocessing import Process, freeze_support
import os
import sys
import logging

if sys.platform == "darwin":
# import before IREE to avoid torch-MLIR library issues
@@ -37,10 +38,12 @@ def launch_app(address):
height=height,
text_select=True,
)
webview.start(private_mode=False)
webview.start(private_mode=False, storage_path=os.getcwd())


if __name__ == "__main__":
if args.debug:
logging.basicConfig(level=logging.DEBUG)
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
if args.api or "api" in args.ui.split(","):
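Two small behavioural notes on the change above: `--debug` just raises the root logging level before the UI starts, and `storage_path=os.getcwd()` makes pywebview persist cookies/local storage in the working directory instead of a temporary location. A minimal pywebview sketch (window title and URL are placeholders):

```python
import logging
import os
import webview  # pywebview

debug = True  # stand-in for args.debug
if debug:
    logging.basicConfig(level=logging.DEBUG)

window = webview.create_window("SHARK Studio", "http://localhost:8080")
# private_mode=False plus an explicit storage_path keeps web storage between runs.
webview.start(private_mode=False, storage_path=os.getcwd())
```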
@@ -3,6 +3,7 @@ import torch
|
||||
import time
|
||||
import gradio as gr
|
||||
import PIL
|
||||
from math import ceil
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
@@ -67,6 +68,7 @@ def img2img_inf(
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
resample_type: str,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
@@ -245,7 +247,7 @@ def img2img_inf(
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
ceil(steps / strength),
|
||||
strength,
|
||||
guidance_scale,
|
||||
seeds[current_batch],
|
||||
@@ -255,6 +257,7 @@ def img2img_inf(
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
use_stencil=use_stencil,
|
||||
resample_type=resample_type,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(
|
||||
@@ -348,6 +351,7 @@ def img2img_api(
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
resample_type="Lanczos",
|
||||
)
|
||||
|
||||
# Converts generator type to subscriptable
|
||||
@@ -432,7 +436,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
lines=2,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
|
||||
# TODO: make this import image prompt info if it exists
|
||||
img2img_init_image = gr.Image(
|
||||
label="Input Image",
|
||||
source="upload",
|
||||
@@ -550,15 +554,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
width = gr.Slider(
|
||||
384, 768, value=args.width, step=8, label="Width"
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=args.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
value=args.max_length,
|
||||
@@ -581,11 +576,35 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
step=0.01,
|
||||
label="Denoising Strength",
|
||||
)
|
||||
resample_type = gr.Dropdown(
|
||||
value=args.resample_type,
|
||||
choices=[
|
||||
"Lanczos",
|
||||
"Nearest Neighbor",
|
||||
"Bilinear",
|
||||
"Bicubic",
|
||||
"Adaptive",
|
||||
"Antialias",
|
||||
"Box",
|
||||
"Affine",
|
||||
"Cubic",
|
||||
],
|
||||
label="Resample Type",
|
||||
)
|
||||
ondemand = gr.Checkbox(
|
||||
value=args.ondemand,
|
||||
label="Low VRAM",
|
||||
interactive=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=args.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
guidance_scale = gr.Slider(
|
||||
@@ -695,6 +714,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
resample_type,
|
||||
],
|
||||
outputs=[img2img_gallery, std_output, img2img_status],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
|
||||
@@ -109,7 +109,7 @@ with gr.Blocks() as minigpt4_web:
|
||||
gr.Markdown(description)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=0.5):
|
||||
with gr.Column():
|
||||
image = gr.Image(type="pil")
|
||||
upload_button = gr.Button(
|
||||
value="Upload & Start Chat",
|
||||
|
||||
@@ -8,6 +8,7 @@ from transformers import (
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
from datetime import datetime as dt
|
||||
import json
|
||||
import time
|
||||
|
||||
|
||||
def user(message, history):
|
||||
@@ -23,12 +24,9 @@ past_key_values = None
|
||||
|
||||
model_map = {
|
||||
"llama2_7b": "meta-llama/Llama-2-7b-chat-hf",
|
||||
"llama2_13b": "meta-llama/Llama-2-13b-chat-hf",
|
||||
"llama2_70b": "meta-llama/Llama-2-70b-chat-hf",
|
||||
"codegen": "Salesforce/codegen25-7b-multi",
|
||||
"vicuna1p3": "lmsys/vicuna-7b-v1.3",
|
||||
"vicuna": "TheBloke/vicuna-7B-1.1-HF",
|
||||
"vicuna4": "TheBloke/vicuna-7B-1.1-HF",
|
||||
"StableLM": "stabilityai/stablelm-tuned-alpha-3b",
|
||||
}
|
||||
|
||||
# NOTE: Each `model_name` should have its own start message
|
||||
@@ -42,6 +40,15 @@ start_message = {
|
||||
"explain why instead of answering something not correct. If you don't know the "
|
||||
"answer to a question, please don't share false information."
|
||||
),
|
||||
"llama2_13b": (
|
||||
"System: You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
|
||||
"content. Please ensure that your responses are socially unbiased and positive "
|
||||
"in nature. If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. If you don't know the "
|
||||
"answer to a question, please don't share false information."
|
||||
),
|
||||
"llama2_70b": (
|
||||
"System: You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
@@ -51,60 +58,39 @@ start_message = {
|
||||
"explain why instead of answering something not correct. If you don't know the "
|
||||
"answer to a question, please don't share false information."
|
||||
),
|
||||
"StableLM": (
|
||||
"<|SYSTEM|># StableLM Tuned (Alpha version)"
|
||||
"\n- StableLM is a helpful and harmless open-source AI language model "
|
||||
"developed by StabilityAI."
|
||||
"\n- StableLM is excited to be able to help the user, but will refuse "
|
||||
"to do anything that could be considered harmful to the user."
|
||||
"\n- StableLM is more than just an information source, StableLM is also "
|
||||
"able to write poetry, short stories, and make jokes."
|
||||
"\n- StableLM will refuse to participate in anything that "
|
||||
"could harm a human."
|
||||
),
|
||||
"vicuna": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
),
|
||||
"vicuna4": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
),
|
||||
"vicuna1p3": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
),
|
||||
"codegen": "",
|
||||
}
|
||||
|
||||
|
||||
def create_prompt(model_name, history):
|
||||
system_message = start_message[model_name]
|
||||
|
||||
if model_name in [
|
||||
"StableLM",
|
||||
"vicuna",
|
||||
"vicuna4",
|
||||
"vicuna1p3",
|
||||
"llama2_7b",
|
||||
"llama2_70b",
|
||||
]:
|
||||
if "llama2" in model_name:
|
||||
B_INST, E_INST = "[INST]", "[/INST]"
|
||||
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
||||
conversation = "".join(
|
||||
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
|
||||
)
|
||||
msg = f"{B_INST} {B_SYS} {system_message} {E_SYS} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
|
||||
elif model_name in ["vicuna"]:
|
||||
conversation = "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
else:
|
||||
conversation = "".join(
|
||||
["".join([item[0], item[1]]) for item in history]
|
||||
)
|
||||
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
return msg
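For the llama2 models, the branch above assembles the prompt with Meta's instruction-tag convention: the system message is wrapped in `<<SYS>>` markers inside the first `[INST] ... [/INST]` pair, and each later turn is appended as its own `[INST] question [/INST] answer` segment. A small worked example with a hypothetical history (whitespace is approximate):

```python
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

system_message = "You are a helpful assistant."
history = [["Hi", "Hello!"], ["What is IREE?", ""]]

conversation = "".join(f"{B_INST} {q} {E_INST} {a} " for q, a in history[1:])
msg = (
    f"{B_INST} {B_SYS} {system_message} {E_SYS} "
    f"{history[0][0]} {E_INST} {history[0][1]} {conversation}"
)
# msg now looks roughly like:
# [INST] <<SYS>>
#  You are a helpful assistant.
# <</SYS>>
#  Hi [/INST] Hello! [INST] What is IREE? [/INST]
print(msg)
```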
|
||||
|
||||
|
||||
@@ -138,118 +124,140 @@ def get_default_config():
|
||||
c.split_into_layers()
|
||||
|
||||
|
||||
model_vmfb_key = ""
|
||||
|
||||
|
||||
# TODO: Make chat reusable for UI and API
|
||||
def chat(
|
||||
curr_system_message,
|
||||
history,
|
||||
model,
|
||||
devices,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
cli=True,
|
||||
cli=False,
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
global past_key_values
|
||||
|
||||
global model_vmfb_key
|
||||
global vicuna_model
|
||||
|
||||
device_id = None
|
||||
model_name, model_path = list(map(str.strip, model.split("=>")))
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
elif "sync" in device:
|
||||
device = "cpu-sync"
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device_id = int(device.split("://")[1])
|
||||
device = "vulkan"
|
||||
elif "rocm" in device:
|
||||
device = "rocm"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
|
||||
if model_name in [
|
||||
"vicuna",
|
||||
"vicuna4",
|
||||
"vicuna1p3",
|
||||
"codegen",
|
||||
"llama2_7b",
|
||||
"llama2_70b",
|
||||
]:
|
||||
from apps.language_models.scripts.vicuna import ShardedVicuna
|
||||
from apps.language_models.scripts.vicuna import UnshardedVicuna
|
||||
from apps.stable_diffusion.src import args
|
||||
from apps.language_models.scripts.vicuna import ShardedVicuna
|
||||
from apps.language_models.scripts.vicuna import UnshardedVicuna
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
if vicuna_model == 0:
|
||||
device = devices[0]
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
elif "sync" in device:
|
||||
device = "cpu-sync"
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device = "vulkan"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{precision}"
|
||||
if new_model_vmfb_key != model_vmfb_key:
|
||||
model_vmfb_key = new_model_vmfb_key
|
||||
max_toks = 128 if model_name == "codegen" else 512
|
||||
|
||||
max_toks = 128 if model_name == "codegen" else 512
|
||||
if model_name == "vicuna4":
|
||||
vicuna_model = ShardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
compressed=True,
|
||||
# get iree flags that need to be overridden, from commandline args
|
||||
_extra_args = []
|
||||
# vulkan target triple
|
||||
vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
get_all_vulkan_devices,
|
||||
get_vulkan_target_triple,
|
||||
)
|
||||
|
||||
if device == "vulkan":
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
if vulkan_target_triple == "":
|
||||
# We already have the device_id extracted via WebUI, so we directly use
|
||||
# that to find the target triple.
|
||||
vulkan_target_triple = get_vulkan_target_triple(
|
||||
vulkaninfo_list[device_id]
|
||||
)
|
||||
else:
|
||||
if len(devices) == 1 and config_file is None:
|
||||
vicuna_model = UnshardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
hf_auth_token=args.hf_auth_token,
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
)
|
||||
else:
|
||||
if config_file is not None:
|
||||
config_file = open(config_file)
|
||||
config_json = json.load(config_file)
|
||||
config_file.close()
|
||||
else:
|
||||
config_json = get_default_config()
|
||||
vicuna_model = ShardedVicuna(
|
||||
model_name,
|
||||
device=device,
|
||||
precision=precision,
|
||||
config_json=config_json,
|
||||
_extra_args.append(
|
||||
f"-iree-vulkan-target-triple={vulkan_target_triple}"
|
||||
)
|
||||
if "rdna" in vulkan_target_triple:
|
||||
flags_to_add = [
|
||||
"--iree-spirv-index-bits=64",
|
||||
]
|
||||
_extra_args = _extra_args + flags_to_add
|
||||
|
||||
if device_id is None:
|
||||
id = 0
|
||||
for device in vulkaninfo_list:
|
||||
target_triple = get_vulkan_target_triple(
|
||||
vulkaninfo_list[id]
|
||||
)
|
||||
if target_triple == vulkan_target_triple:
|
||||
device_id = id
|
||||
break
|
||||
id += 1
|
||||
|
||||
prompt = create_prompt(model_name, history)
|
||||
assert (
|
||||
device_id
|
||||
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
|
||||
|
||||
for partial_text in vicuna_model.generate(prompt, cli=cli):
|
||||
history[-1][1] = partial_text
|
||||
yield history
|
||||
print(f"Will use target triple : {vulkan_target_triple}")
|
||||
|
||||
return history
|
||||
if model_name == "vicuna4":
|
||||
vicuna_model = ShardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
compressed=True,
|
||||
extra_args_cmd=_extra_args,
|
||||
)
|
||||
else:
|
||||
# if config_file is None:
|
||||
vicuna_model = UnshardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
hf_auth_token=args.hf_auth_token,
|
||||
device=device,
|
||||
vulkan_target_triple=vulkan_target_triple,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
download_vmfb=download_vmfb,
|
||||
load_mlir_from_shark_tank=True,
|
||||
extra_args_cmd=_extra_args,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
# else Model is StableLM
|
||||
global sharkModel
|
||||
from apps.language_models.src.pipelines.stablelm_pipeline import (
|
||||
SharkStableLM,
|
||||
)
|
||||
|
||||
if sharkModel == 0:
|
||||
# max_new_tokens=512
|
||||
shark_slm = SharkStableLM(
|
||||
model_name
|
||||
) # pass elements from UI as required
|
||||
|
||||
# Construct the input message string for the model by concatenating the
|
||||
# current system message and conversation history
|
||||
if len(curr_system_message.split()) > 160:
|
||||
print("clearing context")
|
||||
prompt = create_prompt(model_name, history)
|
||||
generate_kwargs = dict(prompt=prompt)
|
||||
|
||||
words_list = shark_slm.generate(**generate_kwargs)
|
||||
|
||||
partial_text = ""
|
||||
for new_text in words_list:
|
||||
print(new_text)
|
||||
partial_text += new_text
|
||||
history[-1][1] = partial_text
|
||||
# Yield an empty string to clean up the message textbox and the updated
|
||||
# conversation history
|
||||
yield history
|
||||
return words_list
|
||||
count = 0
|
||||
start_time = time.time()
|
||||
for text, msg in progress.tqdm(
|
||||
vicuna_model.generate(prompt, cli=cli),
|
||||
desc="generating response",
|
||||
):
|
||||
count += 1
|
||||
if "formatted" in msg:
|
||||
history[-1][1] = text
|
||||
end_time = time.time()
|
||||
tokens_per_sec = count / (end_time - start_time)
|
||||
yield history, str(format(tokens_per_sec, ".2f")) + " tokens/sec"
|
||||
else:
|
||||
partial_text += text + " "
|
||||
history[-1][1] = partial_text
|
||||
yield history, ""
|
||||
|
||||
return history, ""
|
||||
|
||||
|
||||
def llm_chat_api(InputData: dict):
|
||||
@@ -285,6 +293,7 @@ def llm_chat_api(InputData: dict):
|
||||
UnshardedVicuna,
|
||||
)
|
||||
|
||||
device_id = None
|
||||
if vicuna_model == 0:
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
@@ -293,6 +302,7 @@ def llm_chat_api(InputData: dict):
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device_id = int(device.split("://")[1])
|
||||
device = "vulkan"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
@@ -303,6 +313,9 @@ def llm_chat_api(InputData: dict):
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
download_vmfb=True,
|
||||
load_mlir_from_shark_tank=True,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
# TODO: add role dict for different models
|
||||
@@ -373,32 +386,40 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
# show cpu-task device first in list for chatbot
|
||||
supported_devices = supported_devices[-1:] + supported_devices[:-1]
|
||||
supported_devices = [x for x in supported_devices if "sync" not in x]
|
||||
print(supported_devices)
|
||||
devices = gr.Dropdown(
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=supported_devices[0]
|
||||
if enabled
|
||||
else "Only CUDA Supported for now",
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
multiselect=True,
|
||||
# multiselect=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="fp16",
|
||||
value="int8",
|
||||
choices=[
|
||||
"int4",
|
||||
"int8",
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
download_vmfb = gr.Checkbox(
|
||||
label="Download vmfb from Shark tank if available",
|
||||
value=True,
|
||||
interactive=True,
|
||||
)
|
||||
tokens_time = gr.Textbox(label="Tokens generated per second")
|
||||
|
||||
with gr.Row(visible=False):
|
||||
with gr.Group():
|
||||
config_file = gr.File(label="Upload sharding configuration")
|
||||
json_view_button = gr.Button("View as JSON")
|
||||
json_view = gr.JSON(interactive=True)
|
||||
config_file = gr.File(
|
||||
label="Upload sharding configuration", visible=False
|
||||
)
|
||||
json_view_button = gr.Button(label="View as JSON", visible=False)
|
||||
json_view = gr.JSON(interactive=True, visible=False)
|
||||
json_view_button.click(
|
||||
fn=view_json_file, inputs=[config_file], outputs=[json_view]
|
||||
)
|
||||
@@ -425,16 +446,32 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model, devices, precision, config_file],
|
||||
outputs=[chatbot],
|
||||
inputs=[
|
||||
system_msg,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
queue=True,
|
||||
)
|
||||
submit_click_event = submit.click(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model, devices, precision, config_file],
|
||||
outputs=[chatbot],
|
||||
inputs=[
|
||||
system_msg,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
queue=True,
|
||||
)
|
||||
stop.click(
|
||||
|
||||
@@ -4,6 +4,7 @@ import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from math import ceil
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from fastapi.exceptions import HTTPException
|
||||
@@ -26,6 +27,7 @@ from apps.stable_diffusion.src import (
|
||||
utils,
|
||||
save_output_img,
|
||||
prompt_examples,
|
||||
Image2ImagePipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
@@ -62,6 +64,11 @@ def txt2img_inf(
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
use_hiresfix: bool,
|
||||
hiresfix_height: int,
|
||||
hiresfix_width: int,
|
||||
hiresfix_strength: float,
|
||||
resample_type: str,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
@@ -200,6 +207,81 @@ def txt2img_inf(
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
# TODO: allow user to save original image
|
||||
# TODO: add option to let user keep both pipelines loaded, and unload
|
||||
# either at will
|
||||
# TODO: add custom step value slider
|
||||
# TODO: add option to use secondary model for the img2img pass
|
||||
if use_hiresfix is True:
|
||||
new_config_obj = Config(
|
||||
"img2img",
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
precision,
|
||||
1,
|
||||
max_length,
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil="None",
|
||||
ondemand=ondemand,
|
||||
)
|
||||
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
set_init_device_flags()
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-1-base"
|
||||
)
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(args.scheduler)
|
||||
|
||||
global_obj.set_sd_obj(
|
||||
Image2ImagePipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
1,
|
||||
hiresfix_height,
|
||||
hiresfix_width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
|
||||
global_obj.set_sd_scheduler(args.scheduler)
|
||||
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
out_imgs[0],
|
||||
batch_size,
|
||||
hiresfix_height,
|
||||
hiresfix_width,
|
||||
ceil(steps / hiresfix_strength),
|
||||
hiresfix_strength,
|
||||
guidance_scale,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
use_stencil="None",
|
||||
resample_type=resample_type,
|
||||
)
|
||||
total_time = time.time() - start_time
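One non-obvious detail in the Hires Fix pass above is the step count: img2img schedulers typically run only about `num_inference_steps * strength` of the scheduled steps, so dividing the requested steps by the denoising strength (and taking the ceiling) appears intended to keep the effective step count close to what the user asked for. Illustrative arithmetic only:

```python
from math import ceil

steps = 30               # steps requested by the user
hiresfix_strength = 0.6

scheduled = ceil(steps / hiresfix_strength)     # 50 steps handed to img2img
effective = int(scheduled * hiresfix_strength)  # ~30 steps actually denoised
print(scheduled, effective)  # 50 30
```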
|
||||
text_output = get_generation_text_info(
|
||||
seeds[: current_batch + 1], device
|
||||
@@ -271,6 +353,11 @@ def txt2img_api(
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
use_hiresfix=False,
|
||||
hiresfix_height=512,
|
||||
hiresfix_width=512,
|
||||
hiresfix_strength=0.6,
|
||||
resample_type="Nearest Neighbor",
|
||||
)
|
||||
|
||||
# Convert Generator to Subscriptable
|
||||
@@ -460,6 +547,49 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
label="Low VRAM",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Group():
|
||||
with gr.Row():
|
||||
use_hiresfix = gr.Checkbox(
|
||||
value=args.use_hiresfix,
|
||||
label="Use Hires Fix",
|
||||
interactive=True,
|
||||
)
|
||||
resample_type = gr.Dropdown(
|
||||
value=args.resample_type,
|
||||
choices=[
|
||||
"Lanczos",
|
||||
"Nearest Neighbor",
|
||||
"Bilinear",
|
||||
"Bicubic",
|
||||
"Adaptive",
|
||||
"Antialias",
|
||||
"Box",
|
||||
"Affine",
|
||||
"Cubic",
|
||||
],
|
||||
label="Resample Type",
|
||||
)
|
||||
hiresfix_height = gr.Slider(
|
||||
384,
|
||||
768,
|
||||
value=args.hiresfix_height,
|
||||
step=8,
|
||||
label="Hires Fix Height",
|
||||
)
|
||||
hiresfix_width = gr.Slider(
|
||||
384,
|
||||
768,
|
||||
value=args.hiresfix_width,
|
||||
step=8,
|
||||
label="Hires Fix Width",
|
||||
)
|
||||
hiresfix_strength = gr.Slider(
|
||||
0,
|
||||
1,
|
||||
value=args.hiresfix_strength,
|
||||
step=0.01,
|
||||
label="Hires Fix Denoising Strength",
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
batch_count = gr.Slider(
|
||||
@@ -495,16 +625,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
with gr.Accordion(label="Prompt Examples!", open=False):
|
||||
ex = gr.Examples(
|
||||
examples=prompt_examples,
|
||||
@@ -530,6 +650,18 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
show_label=False,
|
||||
)
|
||||
txt2img_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
blank_thing_for_row = None
|
||||
with gr.Row():
|
||||
txt2img_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
txt2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
@@ -565,6 +697,11 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
use_hiresfix,
|
||||
hiresfix_height,
|
||||
hiresfix_width,
|
||||
hiresfix_strength,
|
||||
resample_type,
|
||||
],
|
||||
outputs=[txt2img_gallery, std_output, txt2img_status],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
|
||||
@@ -25,7 +25,7 @@ class Config:
device: str
use_lora: str
use_stencil: str
ondemand: str
ondemand: str # should this be expecting a bool instead?
|
||||
|
||||
custom_model_filetypes = (
|
||||
|
||||
@@ -24,13 +24,13 @@ def get_image(url, local_filename):
shutil.copyfileobj(res.raw, f)


def compare_images(new_filename, golden_filename):
def compare_images(new_filename, golden_filename, upload=False):
new = np.array(Image.open(new_filename)) / 255.0
golden = np.array(Image.open(golden_filename)) / 255.0
diff = np.abs(new - golden)
mean = np.mean(diff)
if mean > 0.1:
if os.name != "nt":
if os.name != "nt" and upload == True:
subprocess.run(
[
"gsutil",
@@ -39,7 +39,7 @@ def compare_images(new_filename, golden_filename):
"gs://shark_tank/testdata/builder/",
]
)
raise SystemExit("new and golden not close")
raise AssertionError("new and golden not close")
else:
print("SUCCESS")

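The gate in `compare_images` above is a mean absolute pixel difference on [0, 1]-scaled images with a 0.1 threshold; failures now raise `AssertionError` (so callers can decide whether to abort) and only upload the failing image when `upload=True`. A self-contained sketch of the same check:

```python
import numpy as np
from PIL import Image

def images_close(new_path: str, golden_path: str, threshold: float = 0.1) -> bool:
    new = np.asarray(Image.open(new_path), dtype=np.float64) / 255.0
    golden = np.asarray(Image.open(golden_path), dtype=np.float64) / 255.0
    # Mean absolute difference over all pixels and channels.
    return float(np.mean(np.abs(new - golden))) <= threshold
```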
@@ -1,5 +1,5 @@
#!/bin/bash

IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
IMPORTER=1 ./setup_venv.sh
source $GITHUB_WORKSPACE/shark.venv/bin/activate
python tank/generate_sharktank.py
python build_tools/stable_diffusion_testing.py --gen
|
||||
@@ -63,7 +63,14 @@ def get_inpaint_inputs():
open("./test_images/inputs/mask.png", "wb").write(mask.content)


def test_loop(device="vulkan", beta=False, extra_flags=[]):
def test_loop(
device="vulkan",
beta=False,
extra_flags=[],
upload_bool=True,
exit_on_fail=True,
do_gen=False,
):
# Get golden values from tank
shutil.rmtree("./test_images", ignore_errors=True)
model_metrics = []
@@ -81,6 +88,8 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
|
||||
if beta:
|
||||
extra_flags.append("--beta_models=True")
|
||||
extra_flags.append("--no-progress_bar")
|
||||
if do_gen:
|
||||
extra_flags.append("--import_debug")
|
||||
to_skip = [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"prompthero/openjourney",
|
||||
@@ -181,7 +190,14 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
|
||||
"./test_images/golden/" + model_name + "/*.png"
|
||||
)
|
||||
golden_file = glob(golden_path)[0]
|
||||
compare_images(test_file, golden_file)
|
||||
try:
|
||||
compare_images(
|
||||
test_file, golden_file, upload=upload_bool
|
||||
)
|
||||
except AssertionError as e:
|
||||
print(e)
|
||||
if exit_on_fail == True:
|
||||
raise
|
||||
else:
|
||||
print(command)
|
||||
print("failed to generate image for this configuration")
|
||||
@@ -200,6 +216,9 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
|
||||
extra_flags.remove(
|
||||
"--iree_vulkan_target_triple=rdna2-unknown-windows"
|
||||
)
|
||||
if do_gen:
|
||||
prepare_artifacts()
|
||||
|
||||
with open(os.path.join(os.getcwd(), "sd_testing_metrics.csv"), "w+") as f:
|
||||
header = "model_name;device;use_tune;import_opt;Clip Inference time(ms);Average Step (ms/it);VAE Inference time(ms);total image generation(s);command\n"
|
||||
f.write(header)
|
||||
@@ -218,15 +237,49 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
|
||||
f.write(";".join(output) + "\n")
|
||||
|
||||
|
||||
def prepare_artifacts():
|
||||
gen_path = os.path.join(os.getcwd(), "gen_shark_tank")
|
||||
if not os.path.isdir(gen_path):
|
||||
os.mkdir(gen_path)
|
||||
for dirname in os.listdir(os.getcwd()):
|
||||
for modelname in ["clip", "unet", "vae"]:
|
||||
if modelname in dirname and "vmfb" not in dirname:
|
||||
if not os.path.isdir(os.path.join(gen_path, dirname)):
|
||||
shutil.move(os.path.join(os.getcwd(), dirname), gen_path)
|
||||
print(f"Moved dir: {dirname} to {gen_path}.")
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("-d", "--device", default="vulkan")
|
||||
parser.add_argument(
|
||||
"-b", "--beta", action=argparse.BooleanOptionalAction, default=False
|
||||
)
|
||||
|
||||
parser.add_argument("-e", "--extra_args", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"-u", "--upload", action=argparse.BooleanOptionalAction, default=True
|
||||
)
|
||||
parser.add_argument(
|
||||
"-x", "--exit_on_fail", action=argparse.BooleanOptionalAction, default=True
|
||||
)
|
||||
parser.add_argument(
|
||||
"-g", "--gen", action=argparse.BooleanOptionalAction, default=False
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
test_loop(args.device, args.beta, [])
|
||||
extra_args = []
|
||||
if args.extra_args:
|
||||
for arg in args.extra_args.split(","):
|
||||
extra_args.append(arg)
|
||||
test_loop(
|
||||
args.device,
|
||||
args.beta,
|
||||
extra_args,
|
||||
args.upload,
|
||||
args.exit_on_fail,
|
||||
args.gen,
|
||||
)
|
||||
if args.gen:
|
||||
prepare_artifacts()
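The new command-line switches use `argparse.BooleanOptionalAction`, so each one automatically gets a paired `--no-` form; a quick standalone illustration of how the defaults and negations combine (not the test script itself):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-d", "--device", default="vulkan")
parser.add_argument("-u", "--upload", action=argparse.BooleanOptionalAction, default=True)
parser.add_argument("-x", "--exit_on_fail", action=argparse.BooleanOptionalAction, default=True)
parser.add_argument("-g", "--gen", action=argparse.BooleanOptionalAction, default=False)

args = parser.parse_args(["--no-upload", "--gen"])
print(args.device, args.upload, args.exit_on_fail, args.gen)  # vulkan False True True
```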
|
||||
|
||||
@@ -27,7 +27,7 @@ include(FetchContent)

FetchContent_Declare(
iree
GIT_REPOSITORY https://github.com/nod-ai/shark-runtime.git
GIT_REPOSITORY https://github.com/nod-ai/srt.git
GIT_TAG shark
GIT_SUBMODULES_RECURSE OFF
GIT_SHALLOW OFF
|
||||
@@ -40,7 +40,7 @@ cmake --build build/
|
||||
*Prepare the model*
|
||||
```bash
|
||||
wget https://storage.googleapis.com/shark_tank/latest/resnet50_tf/resnet50_tf.mlir
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux resnet50_tf.mlir -o resnet50_tf.vmfb
|
||||
```
|
||||
*Prepare the input*
|
||||
|
||||
@@ -65,18 +65,18 @@ A tool for benchmarking other models is built and can be invoked with a command
|
||||
see `./build/vulkan_gui/iree-vulkan-gui --help` for an explanation on the function input. For example, stable diffusion unet can be tested with the following commands:
|
||||
```bash
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/stable_diff_tf.mlir
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux stable_diff_tf.mlir -o stable_diff_tf.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=2x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32
|
||||
```
|
||||
VAE and Autoencoder are also available
|
||||
```bash
|
||||
# VAE
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/vae_tf/vae.mlir
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux vae.mlir -o vae.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x4x64x64xf32
|
||||
|
||||
# CLIP Autoencoder
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/clip_tf/clip_autoencoder.mlir
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux clip_autoencoder.mlir -o clip_autoencoder.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x77xi32 --function_input=1x77xi32
|
||||
```
|
||||
|
||||
@@ -55,7 +55,7 @@ The command line for compilation will start something like this, where the `-` n
|
||||
The `-o output_filename.vmfb` flag can be used to specify the location to save the compiled vmfb. Note that a dump of the
|
||||
dispatches that can be compiled + run in isolation can be generated by adding `--iree-hal-dump-executable-benchmarks-to=/some/directory`. Say, if they are in the `benchmarks` directory, the following compile/run commands would work for Vulkan on RDNA3.
|
||||
```
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna3-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.mlir -o benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna3-unknown-linux benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.mlir -o benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb
|
||||
|
||||
iree-benchmark-module --module=benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb --function=forward --device=vulkan
|
||||
```
|
||||
@@ -63,8 +63,8 @@ Where `${NUM}` is the dispatch number that you want to benchmark/profile in isol
|
||||
|
||||
### Enabling Tracy for Vulkan profiling
|
||||
|
||||
To begin profiling with Tracy, a build of IREE runtime with tracing enabled is needed. SHARK-Runtime builds an
|
||||
instrumented version alongside the normal version nightly (.whls typically found [here](https://github.com/nod-ai/SHARK-Runtime/releases)), however this is only available for Linux. For Windows, tracing can be enabled by enabling a CMake flag.
|
||||
To begin profiling with Tracy, a build of IREE runtime with tracing enabled is needed. SHARK-Runtime (SRT) builds an
|
||||
instrumented version alongside the normal version nightly (.whls typically found [here](https://github.com/nod-ai/SRT/releases)), however this is only available for Linux. For Windows, tracing can be enabled by enabling a CMake flag.
|
||||
```
|
||||
$env:IREE_ENABLE_RUNTIME_TRACING="ON"
|
||||
```
|
||||
|
||||
@@ -1,192 +0,0 @@
|
||||
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cmake_minimum_required(VERSION 3.17)
|
||||
|
||||
project(sharkbackend LANGUAGES C CXX)
|
||||
|
||||
#
|
||||
# Options
|
||||
#
|
||||
|
||||
option(TRITON_ENABLE_GPU "Enable GPU support in backend" ON)
|
||||
option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON)
|
||||
|
||||
set(TRITON_COMMON_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/common repo")
|
||||
set(TRITON_CORE_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/core repo")
|
||||
set(TRITON_BACKEND_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/backend repo")
|
||||
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
set(CMAKE_BUILD_TYPE Release)
|
||||
endif()
|
||||
|
||||
#
|
||||
# Dependencies
|
||||
#
|
||||
# FetchContent requires us to include the transitive closure of all
|
||||
# repos that we depend on so that we can override the tags.
|
||||
#
|
||||
include(FetchContent)
|
||||
|
||||
FetchContent_Declare(
|
||||
repo-common
|
||||
GIT_REPOSITORY https://github.com/triton-inference-server/common.git
|
||||
GIT_TAG ${TRITON_COMMON_REPO_TAG}
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_Declare(
|
||||
repo-core
|
||||
GIT_REPOSITORY https://github.com/triton-inference-server/core.git
|
||||
GIT_TAG ${TRITON_CORE_REPO_TAG}
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_Declare(
|
||||
repo-backend
|
||||
GIT_REPOSITORY https://github.com/triton-inference-server/backend.git
|
||||
GIT_TAG ${TRITON_BACKEND_REPO_TAG}
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_MakeAvailable(repo-common repo-core repo-backend)
|
||||
|
||||
#
|
||||
# The backend must be built into a shared library. Use an ldscript to
|
||||
# hide all symbols except for the TRITONBACKEND API.
|
||||
#
|
||||
configure_file(src/libtriton_dshark.ldscript libtriton_dshark.ldscript COPYONLY)
|
||||
|
||||
add_library(
|
||||
triton-dshark-backend SHARED
|
||||
src/dshark.cc
|
||||
#src/dshark_driver_module.c
|
||||
)
|
||||
|
||||
add_library(
|
||||
SharkBackend::triton-dshark-backend ALIAS triton-dshark-backend
|
||||
)
|
||||
|
||||
target_include_directories(
|
||||
triton-dshark-backend
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/src
|
||||
)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${PROJECT_BINARY_DIR}/lib/cmake/mlir")
|
||||
|
||||
add_subdirectory(thirdparty/shark-runtime EXCLUDE_FROM_ALL)
|
||||
|
||||
target_link_libraries(triton-dshark-backend PRIVATE iree_base_base
|
||||
iree_hal_hal
|
||||
iree_hal_cuda_cuda
|
||||
iree_hal_cuda_registration_registration
|
||||
iree_hal_vmvx_registration_registration
|
||||
iree_hal_dylib_registration_registration
|
||||
iree_modules_hal_hal
|
||||
iree_vm_vm
|
||||
iree_vm_bytecode_module
|
||||
iree_hal_local_loaders_system_library_loader
|
||||
iree_hal_local_loaders_vmvx_module_loader
|
||||
)
|
||||
|
||||
target_compile_features(triton-dshark-backend PRIVATE cxx_std_11)
|
||||
|
||||
|
||||
target_link_libraries(
|
||||
triton-dshark-backend
|
||||
PRIVATE
|
||||
triton-core-serverapi # from repo-core
|
||||
triton-core-backendapi # from repo-core
|
||||
triton-core-serverstub # from repo-core
|
||||
triton-backend-utils # from repo-backend
|
||||
)
|
||||
|
||||
if(WIN32)
|
||||
set_target_properties(
|
||||
triton-dshark-backend PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
OUTPUT_NAME triton_dshark
|
||||
)
|
||||
else()
|
||||
set_target_properties(
|
||||
triton-dshark-backend PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
OUTPUT_NAME triton_dshark
|
||||
LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libtriton_dshark.ldscript
|
||||
LINK_FLAGS "-Wl,--version-script libtriton_dshark.ldscript"
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
#
|
||||
# Install
|
||||
#
|
||||
include(GNUInstallDirs)
|
||||
set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/SharkBackend)
|
||||
|
||||
install(
|
||||
TARGETS
|
||||
triton-dshark-backend
|
||||
EXPORT
|
||||
triton-dshark-backend-targets
|
||||
LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
|
||||
RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
|
||||
)
|
||||
|
||||
install(
|
||||
EXPORT
|
||||
triton-dshark-backend-targets
|
||||
FILE
|
||||
SharkBackendTargets.cmake
|
||||
NAMESPACE
|
||||
SharkBackend::
|
||||
DESTINATION
|
||||
${INSTALL_CONFIGDIR}
|
||||
)
|
||||
|
||||
include(CMakePackageConfigHelpers)
|
||||
configure_package_config_file(
|
||||
${CMAKE_CURRENT_LIST_DIR}/cmake/SharkBackendConfig.cmake.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
|
||||
INSTALL_DESTINATION ${INSTALL_CONFIGDIR}
|
||||
)
|
||||
|
||||
install(
|
||||
FILES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
|
||||
DESTINATION ${INSTALL_CONFIGDIR}
|
||||
)
|
||||
|
||||
#
|
||||
# Export from build tree
|
||||
#
|
||||
export(
|
||||
EXPORT triton-dshark-backend-targets
|
||||
FILE ${CMAKE_CURRENT_BINARY_DIR}/SharkBackendTargets.cmake
|
||||
NAMESPACE SharkBackend::
|
||||
)
|
||||
|
||||
export(PACKAGE SharkBackend)
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
# SHARK Triton Backend
|
||||
|
||||
The triton backend for shark.
|
||||
|
||||
# Build
|
||||
|
||||
Install SHARK
|
||||
|
||||
```
|
||||
git clone https://github.com/nod-ai/SHARK.git
|
||||
# skip above step if dshark is already installed
|
||||
cd SHARK/inference
|
||||
```
|
||||
|
||||
Install dependencies
|
||||
|
||||
```
|
||||
apt-get install patchelf rapidjson-dev python3-dev
|
||||
git submodule update --init
|
||||
```
|
||||
|
||||
update the submodules of iree
|
||||
|
||||
```
|
||||
cd thirdparty/shark-runtime
|
||||
git submodule update --init
|
||||
```
|
||||
|
||||
Next, make the backend and install it
|
||||
|
||||
```
|
||||
cd ../..
|
||||
mkdir build && cd build
|
||||
cmake -DTRITON_ENABLE_GPU=ON \
|
||||
-DIREE_HAL_DRIVER_CUDA=ON \
|
||||
-DIREE_TARGET_BACKEND_CUDA=ON \
|
||||
-DMLIR_ENABLE_CUDA_RUNNER=ON \
|
||||
-DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install \
|
||||
-DTRITON_BACKEND_REPO_TAG=r22.02 \
|
||||
-DTRITON_CORE_REPO_TAG=r22.02 \
|
||||
-DTRITON_COMMON_REPO_TAG=r22.02 ..
|
||||
make install
|
||||
```
|
||||
|
||||
# Incorporating into Triton
|
||||
|
||||
There are much more in-depth explanations for the following steps in Triton's documentation:
|
||||
https://github.com/triton-inference-server/server/blob/main/docs/compose.md#triton-with-unsupported-and-custom-backends
|
||||
|
||||
There should be a file at /build/install/backends/dshark/libtriton_dshark.so. You will need to copy it into your triton server image.
|
||||
More documentation is in the link above, but to create the docker image, you need to run the compose.py command in the triton-backend server repo
|
||||
|
||||
|
||||
To first build your image, clone the tritonserver repo.
|
||||
|
||||
```
|
||||
git clone https://github.com/triton-inference-server/server.git
|
||||
```
|
||||
|
||||
then run `compose.py` to build a docker compose file
|
||||
```
|
||||
cd server
|
||||
python3 compose.py --repoagent checksum --dry-run
|
||||
```
|
||||
|
||||
Because dshark is a third-party backend, you will need to manually modify the generated `Dockerfile.compose` to include it. To do this, copy the following line into the produced `Dockerfile.compose`.
|
||||
The dshark backend will be located in the build folder from earlier, under `/build/install/backends`.
|
||||
|
||||
```
|
||||
COPY /path/to/build/install/backends/dshark /opt/tritonserver/backends/dshark
|
||||
```
|
||||
|
||||
Next run
|
||||
```
|
||||
docker build -t tritonserver_custom -f Dockerfile.compose .
|
||||
docker run -it --gpus=1 --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
|
||||
```
|
||||
|
||||
where `path/to/model_repos` is where you are storing the models you want to run
|
||||
|
||||
if your not using gpus, omit `--gpus=1`
|
||||
|
||||
```
|
||||
docker run -it --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
|
||||
```
|
||||
|
||||
# Setting up a model
|
||||
|
||||
to include a model in your backend, add a directory with your model name to your model repository directory. examples of models can be seen here: https://github.com/triton-inference-server/backend/tree/main/examples/model_repos/minimal_models
|
||||
|
||||
make sure to adjust the input correctly in the config.pbtxt file, and save a vmfb file under 1/model.vmfb
|
||||
|
||||
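As a rough sketch only (the model name, tensor names, dtypes, and dims below are placeholders, not values taken from this repo), a minimal model directory could look like:

```
model_repos/
└── my_dshark_model/
    ├── config.pbtxt
    └── 1/
        └── model.vmfb
```

with a `config.pbtxt` along these lines, where `backend: "dshark"` matches the backend directory name installed above:

```
name: "my_dshark_model"
backend: "dshark"
max_batch_size: 0
input [
  { name: "input0", data_type: TYPE_FP32, dims: [ 1, 3, 224, 224 ] }
]
output [
  { name: "output0", data_type: TYPE_FP32, dims: [ 1, 1000 ] }
]
```

Adjust the entries to match the signature of the module you compiled into `model.vmfb`.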
# CUDA

If you're having issues with CUDA, make sure the correct drivers are installed, that `nvidia-smi` works, and that the nvcc compiler is on the path.

@@ -1,39 +0,0 @@
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

include(CMakeFindDependencyMacro)

get_filename_component(
  SHARKBACKEND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH
)

list(APPEND CMAKE_MODULE_PATH ${SHARKBACKEND_CMAKE_DIR})

if(NOT TARGET SharkBackend::triton-dshark-backend)
  include("${SHARKBACKEND_CMAKE_DIR}/SharkBackendTargets.cmake")
endif()

set(SHARKBACKEND_LIBRARIES SharkBackend::triton-dshark-backend)
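For context, a downstream CMake project would typically consume this package through `find_package` once it is installed or exported from the build tree. The sketch below is only an illustration; the `my_tool` target and `main.cc` source are made-up placeholders, not files in this repo:

```
find_package(SharkBackend REQUIRED)

add_executable(my_tool main.cc)
# SharkBackendConfig.cmake sets SHARKBACKEND_LIBRARIES to the exported target,
# so either the variable or SharkBackend::triton-dshark-backend can be linked.
target_link_libraries(my_tool PRIVATE ${SHARKBACKEND_LIBRARIES})
```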
File diff suppressed because it is too large
@@ -1,30 +0,0 @@
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

{
  global:
    TRITONBACKEND_*;
  local: *;
};
Submodule inference/thirdparty/shark-runtime deleted from 7b82d90c72
@@ -6,15 +6,15 @@ from distutils.sysconfig import get_python_lib
|
||||
import fileinput
|
||||
from pathlib import Path
|
||||
|
||||
# Temorary workaround for transformers/__init__.py.
|
||||
path_to_tranformers_hook = Path(
|
||||
# Temporary workaround for transformers/__init__.py.
|
||||
path_to_transformers_hook = Path(
|
||||
get_python_lib()
|
||||
+ "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py"
|
||||
)
|
||||
if path_to_tranformers_hook.is_file():
|
||||
if path_to_transformers_hook.is_file():
|
||||
pass
|
||||
else:
|
||||
with open(path_to_tranformers_hook, "w") as f:
|
||||
with open(path_to_transformers_hook, "w") as f:
|
||||
f.write("module_collection_mode = 'pyz+py'")
|
||||
|
||||
path_to_skipfiles = Path(get_python_lib() + "/torch/_dynamo/skipfiles.py")
|
||||
|
||||
@@ -5,7 +5,7 @@ requires = [
|
||||
"packaging",
|
||||
|
||||
"numpy>=1.22.4",
|
||||
"torch-mlir>=20221021.633",
|
||||
"torch-mlir>=20230620.875",
|
||||
"iree-compiler>=20221022.190",
|
||||
"iree-runtime>=20221022.190",
|
||||
]
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
numpy>1.22.4
|
||||
pytorch-triton
|
||||
torchvision==0.16.0.dev20230322
|
||||
torchvision
|
||||
tabulate
|
||||
|
||||
tqdm
|
||||
@@ -15,8 +15,8 @@ iree-tools-tf
|
||||
|
||||
# TensorFlow and JAX.
|
||||
gin-config
|
||||
tensorflow>2.11
|
||||
keras
|
||||
tf-nightly
|
||||
keras-nightly
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
transformers
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
||||
--pre
|
||||
|
||||
setuptools
|
||||
wheel
|
||||
|
||||
@@ -15,13 +18,14 @@ Pillow
|
||||
parameterized
|
||||
|
||||
# Add transformers, diffusers and scipy since they are most commonly used
|
||||
tokenizers==0.13.3
|
||||
transformers
|
||||
diffusers
|
||||
#accelerate is now required for diffusers import from ckpt.
|
||||
accelerate
|
||||
scipy
|
||||
ftfy
|
||||
gradio
|
||||
gradio==3.44.3
|
||||
altair
|
||||
omegaconf
|
||||
# 0.3.2 doesn't have binaries for arm64
|
||||
|
||||
@@ -90,8 +90,8 @@ python -m pip install --upgrade pip
|
||||
pip install wheel
|
||||
pip install -r requirements.txt
|
||||
pip install --pre torch-mlir torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
|
||||
pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
|
||||
pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler iree-runtime
|
||||
Write-Host "Building SHARK..."
|
||||
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
|
||||
Write-Host "Build and installation completed successfully"
|
||||
Write-Host "Source your venv with ./shark.venv/Scripts/activate"
|
||||
|
||||
@@ -103,7 +103,7 @@ else
|
||||
fi
|
||||
if [[ -z "${USE_IREE}" ]]; then
|
||||
rm .use-iree
|
||||
RUNTIME="https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html"
|
||||
RUNTIME="https://nod-ai.github.io/SRT/pip-release-links.html"
|
||||
else
|
||||
touch ./.use-iree
|
||||
RUNTIME="https://openxla.github.io/iree/pip-release-links.html"
|
||||
@@ -128,16 +128,15 @@ if [[ ! -z "${IMPORTER}" ]]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/torch/
|
||||
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/cpu/
|
||||
|
||||
if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
|
||||
if [[ $(uname -s) = 'Linux' && ! -z "${IMPORTER}" ]]; then
|
||||
T_VER=$($PYTHON -m pip show torch | grep Version)
|
||||
TORCH_VERSION=${T_VER:9:17}
|
||||
T_VER_MIN=${T_VER:14:12}
|
||||
TV_VER=$($PYTHON -m pip show torchvision | grep Version)
|
||||
TV_VERSION=${TV_VER:9:18}
|
||||
$PYTHON -m pip uninstall -y torch torchvision
|
||||
$PYTHON -m pip install -U --pre --no-warn-conflicts triton
|
||||
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu118/torch-${TORCH_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu118/torchvision-${TV_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl
|
||||
TV_VER_MAJ=${TV_VER:9:6}
|
||||
$PYTHON -m pip uninstall -y torchvision
|
||||
$PYTHON -m pip install torchvision==${TV_VER_MAJ}${T_VER_MIN} --no-deps -f https://download.pytorch.org/whl/nightly/cpu/torchvision/
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch + cu118."
|
||||
else
|
||||
@@ -145,14 +144,8 @@ if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ ! -z "${ONNX}" ]]; then
|
||||
echo "${Yellow}Installing ONNX and onnxruntime for benchmarks..."
|
||||
$PYTHON -m pip install onnx onnxruntime psutil
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully installed ONNX and ONNX runtime."
|
||||
else
|
||||
echo "Could not install ONNX." >&2
|
||||
fi
|
||||
if [[ -z "${NO_BREVITAS}" ]]; then
|
||||
$PYTHON -m pip install git+https://github.com/Xilinx/brevitas.git@dev
|
||||
fi
|
||||
|
||||
if [[ -z "${CONDA_PREFIX}" && "$SKIP_VENV" != "1" ]]; then
|
||||
|
||||
@@ -43,9 +43,7 @@ if __name__ == "__main__":
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=False, tracing_required=True
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module = SharkInference(minilm_mlir)
|
||||
shark_module.compile()
|
||||
token_logits = torch.tensor(shark_module.forward(inputs))
|
||||
mask_id = torch.where(
|
||||
|
||||
@@ -52,6 +52,8 @@ def iree_device_map(device):
|
||||
)
|
||||
if len(uri_parts) == 1:
|
||||
return iree_driver
|
||||
elif "rocm" in uri_parts:
|
||||
return "rocm"
|
||||
else:
|
||||
return f"{iree_driver}://{uri_parts[1]}"
|
||||
|
||||
@@ -63,7 +65,6 @@ def get_supported_device_list():
|
||||
_IREE_DEVICE_MAP = {
|
||||
"cpu": "local-task",
|
||||
"cpu-task": "local-task",
|
||||
"AMD-AIE": "local-task",
|
||||
"cpu-sync": "local-sync",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
@@ -82,7 +83,6 @@ def iree_target_map(device):
|
||||
_IREE_TARGET_MAP = {
|
||||
"cpu": "llvm-cpu",
|
||||
"cpu-task": "llvm-cpu",
|
||||
"AMD-AIE": "llvm-cpu",
|
||||
"cpu-sync": "llvm-cpu",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
@@ -121,7 +121,10 @@ def check_device_drivers(device):
|
||||
return False
|
||||
elif device == "rocm":
|
||||
try:
|
||||
subprocess.check_output("rocminfo")
|
||||
if sys.platform == "win32":
|
||||
subprocess.check_output("hipinfo")
|
||||
else:
|
||||
subprocess.check_output("rocminfo")
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import iree.runtime.scripts.iree_benchmark_module as benchmark_module
|
||||
from shark.iree_utils._common import run_cmd, iree_device_map
|
||||
from shark.iree_utils.cpu_utils import get_cpu_count
|
||||
import numpy as np
|
||||
@@ -62,16 +61,12 @@ def build_benchmark_args(
|
||||
and whether it is training or not.
|
||||
Outputs: string that execute benchmark-module on target model.
|
||||
"""
|
||||
path = benchmark_module.__path__[0]
|
||||
path = os.path.join(os.environ["VIRTUAL_ENV"], "bin")
|
||||
if platform.system() == "Windows":
|
||||
benchmarker_path = os.path.join(
|
||||
path, "..", "..", "iree-benchmark-module.exe"
|
||||
)
|
||||
benchmarker_path = os.path.join(path, "iree-benchmark-module.exe")
|
||||
time_extractor = None
|
||||
else:
|
||||
benchmarker_path = os.path.join(
|
||||
path, "..", "..", "iree-benchmark-module"
|
||||
)
|
||||
benchmarker_path = os.path.join(path, "iree-benchmark-module")
|
||||
time_extractor = "| awk 'END{{print $2 $3}}'"
|
||||
benchmark_cl = [benchmarker_path, f"--module={input_file}"]
|
||||
# TODO: The function name can be passed as one of the args.
|
||||
@@ -106,15 +101,13 @@ def build_benchmark_args_non_tensor_input(
|
||||
and whether it is training or not.
|
||||
Outputs: string that execute benchmark-module on target model.
|
||||
"""
|
||||
path = benchmark_module.__path__[0]
|
||||
path = os.path.join(os.environ["VIRTUAL_ENV"], "bin")
|
||||
if platform.system() == "Windows":
|
||||
benchmarker_path = os.path.join(
|
||||
path, "..", "..", "iree-benchmark-module.exe"
|
||||
)
|
||||
benchmarker_path = os.path.join(path, "iree-benchmark-module.exe")
|
||||
time_extractor = None
|
||||
else:
|
||||
benchmarker_path = os.path.join(
|
||||
path, "..", "..", "iree-benchmark-module"
|
||||
)
|
||||
benchmarker_path = os.path.join(path, "iree-benchmark-module")
|
||||
time_extractor = "| awk 'END{{print $2 $3}}'"
|
||||
benchmark_cl = [benchmarker_path, f"--module={input_file}"]
|
||||
# TODO: The function name can be passed as one of the args.
|
||||
if function_name:
|
||||
@@ -139,7 +132,7 @@ def run_benchmark_module(benchmark_cl):
|
||||
benchmark_path = benchmark_cl[0]
|
||||
assert os.path.exists(
|
||||
benchmark_path
|
||||
), "Cannot find benchmark_module, Please contact SHARK maintainer on discord."
|
||||
), "Cannot find iree_benchmark_module, Please contact SHARK maintainer on discord."
|
||||
bench_stdout, bench_stderr = run_cmd(" ".join(benchmark_cl))
|
||||
try:
|
||||
regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)")
|
||||
|
||||
@@ -46,7 +46,7 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
if device_uri[0] == "cpu":
|
||||
from shark.iree_utils.cpu_utils import get_iree_cpu_args
|
||||
|
||||
data_tiling_flag = ["--iree-flow-enable-data-tiling"]
|
||||
data_tiling_flag = ["--iree-opt-data-tiling"]
|
||||
u_kernel_flag = ["--iree-llvmcpu-enable-microkernels"]
|
||||
stack_size_flag = ["--iree-llvmcpu-stack-allocation-limit=256000"]
|
||||
|
||||
@@ -84,7 +84,7 @@ def get_iree_frontend_args(frontend):
|
||||
elif frontend in ["tensorflow", "tf", "mhlo", "stablehlo"]:
|
||||
return [
|
||||
"--iree-llvmcpu-target-cpu-features=host",
|
||||
"--iree-flow-demote-i64-to-i32",
|
||||
"--iree-input-demote-i64-to-i32",
|
||||
]
|
||||
else:
|
||||
# Frontend not found.
|
||||
@@ -92,13 +92,27 @@ def get_iree_frontend_args(frontend):
|
||||
|
||||
|
||||
# Common args to be used given any frontend or device.
|
||||
def get_iree_common_args():
|
||||
return [
|
||||
"--iree-stream-resource-index-bits=64",
|
||||
"--iree-vm-target-index-bits=64",
|
||||
def get_iree_common_args(debug=False):
|
||||
common_args = [
|
||||
"--iree-stream-resource-max-allocation-size=4294967295",
|
||||
"--iree-vm-bytecode-module-strip-source-map=true",
|
||||
"--iree-util-zero-fill-elided-attrs",
|
||||
]
|
||||
if debug == True:
|
||||
common_args.extend(
|
||||
[
|
||||
"--iree-opt-strip-assertions=false",
|
||||
"--verify=true",
|
||||
]
|
||||
)
|
||||
else:
|
||||
common_args.extend(
|
||||
[
|
||||
"--iree-opt-strip-assertions=true",
|
||||
"--verify=false",
|
||||
]
|
||||
)
|
||||
return common_args
|
||||
|
||||
|
||||
# Args that are suitable only for certain models or groups of models.
|
||||
@@ -277,14 +291,16 @@ def compile_module_to_flatbuffer(
|
||||
model_config_path,
|
||||
extra_args,
|
||||
model_name="None",
|
||||
debug=False,
|
||||
):
|
||||
# Setup Compile arguments wrt to frontends.
|
||||
input_type = ""
|
||||
args = get_iree_frontend_args(frontend)
|
||||
args += get_iree_device_args(device, extra_args)
|
||||
args += get_iree_common_args()
|
||||
args += get_iree_common_args(debug=debug)
|
||||
args += get_model_specific_args()
|
||||
args += extra_args
|
||||
args += shark_args.additional_compile_args
|
||||
|
||||
if frontend in ["tensorflow", "tf"]:
|
||||
input_type = "auto"
|
||||
@@ -342,7 +358,8 @@ def load_vmfb_using_mmap(
|
||||
flatbuffer_blob_or_path, device: str, device_idx: int = None
|
||||
):
|
||||
print(f"Loading module {flatbuffer_blob_or_path}...")
|
||||
|
||||
if "rocm" in device:
|
||||
device = "rocm"
|
||||
with DetailLogger(timeout=2.5) as dl:
|
||||
# First get configs.
|
||||
if device_idx is not None:
|
||||
@@ -409,10 +426,11 @@ def get_iree_compiled_module(
|
||||
extra_args: list = [],
|
||||
device_idx: int = None,
|
||||
mmap: bool = False,
|
||||
debug: bool = False,
|
||||
):
|
||||
"""Given a module returns the compiled .vmfb and configs"""
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module, device, frontend, model_config_path, extra_args
|
||||
module, device, frontend, model_config_path, extra_args, debug
|
||||
)
|
||||
temp_file_to_unlink = None
|
||||
# TODO: Currently mmap=True control flow path has been switched off for mmap.
|
||||
@@ -468,10 +486,11 @@ def export_iree_module_to_vmfb(
|
||||
model_config_path: str = None,
|
||||
module_name: str = None,
|
||||
extra_args: list = [],
|
||||
debug: bool = False,
|
||||
):
|
||||
# Compiles the module given specs and saves it as .vmfb file.
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module, device, mlir_dialect, model_config_path, extra_args
|
||||
module, device, mlir_dialect, model_config_path, extra_args, debug
|
||||
)
|
||||
if module_name is None:
|
||||
device_name = (
|
||||
@@ -479,9 +498,9 @@ def export_iree_module_to_vmfb(
|
||||
)
|
||||
module_name = f"{mlir_dialect}_{device_name}"
|
||||
filename = os.path.join(directory, module_name + ".vmfb")
|
||||
print(f"Saved vmfb in {filename}.")
|
||||
with open(filename, "wb") as f:
|
||||
f.write(flatbuffer_blob)
|
||||
print(f"Saved vmfb in {filename}.")
|
||||
return filename
|
||||
|
||||
|
||||
@@ -547,9 +566,15 @@ def get_results(
|
||||
def get_iree_runtime_config(device):
|
||||
device = iree_device_map(device)
|
||||
haldriver = ireert.get_driver(device)
|
||||
if device == "metal" and shark_args.device_allocator == "caching":
|
||||
print(
|
||||
"[WARNING] metal devices can not have a `caching` allocator."
|
||||
"\nUsing default allocator `None`"
|
||||
)
|
||||
haldevice = haldriver.create_device_by_uri(
|
||||
device,
|
||||
allocators=shark_args.device_allocator,
|
||||
# metal devices have a failure with caching allocators atm. Blocking this until it gets fixed upstream.
|
||||
allocators=shark_args.device_allocator if device != "metal" else None,
|
||||
)
|
||||
config = ireert.Config(device=haldevice)
|
||||
return config
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
import functools
|
||||
import iree.runtime as ireert
|
||||
import ctypes
|
||||
import sys
|
||||
from shark.parser import shark_args
|
||||
|
||||
|
||||
@@ -42,21 +43,51 @@ def get_iree_gpu_args():
|
||||
@functools.cache
|
||||
def get_iree_rocm_args():
|
||||
ireert.flags.FUNCTION_INPUT_VALIDATION = False
|
||||
# get arch from rocminfo.
|
||||
# get arch from hipinfo.
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
rocm_arch = re.match(
|
||||
r".*(gfx\w+)",
|
||||
subprocess.check_output(
|
||||
"rocminfo | grep -i 'gfx'", shell=True, text=True
|
||||
),
|
||||
).group(1)
|
||||
print(f"Found rocm arch {rocm_arch}...")
|
||||
if sys.platform == "win32":
|
||||
if "HIP_PATH" in os.environ:
|
||||
rocm_path = os.environ["HIP_PATH"]
|
||||
print(f"Found a ROCm installation at {rocm_path}.")
|
||||
else:
|
||||
print("Failed to find ROCM_PATH. Defaulting to C:\\AMD\\ROCM\\5.5")
|
||||
rocm_path = "C:\\AMD\\ROCM\\5.5"
|
||||
else:
|
||||
if "ROCM_PATH" in os.environ:
|
||||
rocm_path = os.environ["ROCM_PATH"]
|
||||
print(f"Found a ROCm installation at {rocm_path}.")
|
||||
else:
|
||||
print("Failed to find ROCM_PATH. Defaulting to /opt/rocm")
|
||||
rocm_path = "/opt/rocm/"
|
||||
|
||||
try:
|
||||
if sys.platform == "win32":
|
||||
rocm_arch = re.search(
|
||||
r"gfx\d{3,}",
|
||||
subprocess.check_output("hipinfo", shell=True, text=True),
|
||||
).group(0)
|
||||
else:
|
||||
rocm_arch = re.match(
|
||||
r".*(gfx\w+)",
|
||||
subprocess.check_output(
|
||||
"rocminfo | grep -i 'gfx'", shell=True, text=True
|
||||
),
|
||||
).group(1)
|
||||
print(f"Found rocm arch {rocm_arch}...")
|
||||
except:
|
||||
print(
|
||||
"Failed to find ROCm architecture from hipinfo / rocminfo. Defaulting to gfx1100."
|
||||
)
|
||||
rocm_arch = "gfx1100"
|
||||
|
||||
bc_path = os.path.join(rocm_path, "amdgcn", "bitcode")
|
||||
return [
|
||||
f"--iree-rocm-target-chip={rocm_arch}",
|
||||
"--iree-rocm-link-bc=true",
|
||||
"--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
|
||||
f"--iree-rocm-bc-dir={bc_path}",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -57,11 +57,8 @@ def get_version(triple):
|
||||
@functools.cache
|
||||
def get_extensions(triple):
|
||||
def make_ext_list(ext_list):
|
||||
res = ""
|
||||
for e in ext_list:
|
||||
res += e + ", "
|
||||
res = f"[{res[:-2]}]"
|
||||
return res
|
||||
res = ", ".join(ext_list)
|
||||
return f"[{res}]"
|
||||
|
||||
arch, product, os = triple
|
||||
if arch == "m1":
|
||||
@@ -119,7 +116,7 @@ def get_extensions(triple):
|
||||
]
|
||||
|
||||
if get_vendor(triple) == "NVIDIA" or arch == "rdna3":
|
||||
ext.append("VK_NV_cooperative_matrix")
|
||||
ext.append("VK_KHR_cooperative_matrix")
|
||||
if get_vendor(triple) == ["NVIDIA", "AMD", "Intel"]:
|
||||
ext.append("VK_KHR_shader_integer_dot_product")
|
||||
return make_ext_list(ext_list=ext)
|
||||
@@ -247,7 +244,7 @@ def get_vulkan_target_capabilities(triple):
|
||||
if arch == "rdna3":
|
||||
# TODO: Get scope value
|
||||
cap["coopmatCases"] = [
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>"
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope<Subgroup>"
|
||||
]
|
||||
|
||||
if product == "rx5700xt":
|
||||
@@ -468,9 +465,9 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
cap["coopmatCases"] = [
|
||||
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, accSat = false, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, accSat = false, scope = #vk.scope<Subgroup>",
|
||||
]
|
||||
|
||||
elif arch == "adreno":
|
||||
@@ -531,7 +528,7 @@ def get_vulkan_target_capabilities(triple):
|
||||
cmc = ""
|
||||
for case in v:
|
||||
cmc += f"#vk.coop_matrix_props<{case}>, "
|
||||
res += f"cooperativeMatrixPropertiesNV = [{cmc[:-2]}], "
|
||||
res += f"cooperativeMatrixPropertiesKHR = [{cmc[:-2]}], "
|
||||
else:
|
||||
res += f"{k} = {get_comma_sep_str(v)}, "
|
||||
else:
|
||||
|
||||
@@ -23,11 +23,19 @@ from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
|
||||
from shark.parser import shark_args
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_all_vulkan_devices():
|
||||
from iree.runtime import get_driver
|
||||
|
||||
driver = get_driver("vulkan")
|
||||
device_list_src = driver.query_available_devices()
|
||||
device_list_src.sort(key=lambda d: d["path"])
|
||||
return [d["name"] for d in device_list_src]
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_vulkan_device_name(device_num=0):
|
||||
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
|
||||
vulkaninfo_dump = vulkaninfo_dump.split(linesep)
|
||||
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
if len(vulkaninfo_list) == 0:
|
||||
raise ValueError("No device name found in VulkanInfo!")
|
||||
if len(vulkaninfo_list) > 1:
|
||||
@@ -178,9 +186,7 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
|
||||
@functools.cache
|
||||
def get_iree_vulkan_runtime_flags():
|
||||
vulkan_runtime_flags = [
|
||||
f"--vulkan_large_heap_block_size={shark_args.vulkan_large_heap_block_size}",
|
||||
f"--vulkan_validation_layers={'true' if shark_args.vulkan_validation_layers else 'false'}",
|
||||
f"--vulkan_vma_allocator={'true' if shark_args.vulkan_vma_allocator else 'false'}",
|
||||
]
|
||||
return vulkan_runtime_flags
|
||||
|
||||
|
||||
@@ -14,8 +14,21 @@
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shlex
|
||||
import subprocess
|
||||
|
||||
|
||||
class SplitStrToListAction(argparse.Action):
|
||||
def __init__(self, option_strings, dest, *args, **kwargs):
|
||||
super(SplitStrToListAction, self).__init__(
|
||||
option_strings=option_strings, dest=dest, *args, **kwargs
|
||||
)
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
del parser, option_string
|
||||
setattr(namespace, self.dest, shlex.split(values[0]))
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="SHARK runner.")
|
||||
|
||||
parser.add_argument(
|
||||
@@ -24,6 +37,13 @@ parser.add_argument(
|
||||
default="cpu",
|
||||
help="Device on which shark_runner runs. options are cpu, cuda, and vulkan",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--additional_compile_args",
|
||||
default=list(),
|
||||
nargs=1,
|
||||
action=SplitStrToListAction,
|
||||
help="Additional arguments to pass to the compiler. These are appended as the last arguments.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_tf32",
|
||||
type=bool,
|
||||
@@ -133,13 +153,6 @@ parser.add_argument(
|
||||
help="Profiles vulkan device and collects the .rdc info.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vulkan_large_heap_block_size",
|
||||
default="2073741824",
|
||||
help="Flag for setting VMA preferredLargeHeapBlockSize for "
|
||||
"vulkan device, default is 4G.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vulkan_validation_layers",
|
||||
default=False,
|
||||
@@ -147,11 +160,4 @@ parser.add_argument(
|
||||
help="Flag for disabling vulkan validation layers when benchmarking.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vulkan_vma_allocator",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for enabling / disabling Vulkan VMA Allocator.",
|
||||
)
|
||||
|
||||
shark_args, unknown = parser.parse_known_args()
|
||||
|
||||
@@ -13,7 +13,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
from shark.shark_runner import SharkRunner
|
||||
from shark.iree_utils.compile_utils import export_iree_module_to_vmfb
|
||||
from shark.iree_utils.compile_utils import (
|
||||
export_iree_module_to_vmfb,
|
||||
load_flatbuffer,
|
||||
get_iree_runtime_config,
|
||||
)
|
||||
from shark.iree_utils.benchmark_utils import (
|
||||
build_benchmark_args,
|
||||
run_benchmark_module,
|
||||
@@ -79,22 +83,31 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
self.mlir_dialect = mlir_dialect
|
||||
self.extra_args = extra_args
|
||||
self.import_args = {}
|
||||
self.temp_file_to_unlink = None
|
||||
SharkRunner.__init__(
|
||||
self,
|
||||
mlir_module,
|
||||
device,
|
||||
self.mlir_dialect,
|
||||
self.extra_args,
|
||||
compile_vmfb=True,
|
||||
compile_vmfb=False,
|
||||
)
|
||||
if self.vmfb_file == None:
|
||||
self.vmfb_file = export_iree_module_to_vmfb(
|
||||
mlir_module,
|
||||
device,
|
||||
".",
|
||||
self.mlir_dialect,
|
||||
extra_args=self.extra_args,
|
||||
)
|
||||
self.vmfb_file = export_iree_module_to_vmfb(
|
||||
mlir_module,
|
||||
device,
|
||||
".",
|
||||
self.mlir_dialect,
|
||||
extra_args=self.extra_args,
|
||||
)
|
||||
params = load_flatbuffer(
|
||||
self.vmfb_file,
|
||||
device,
|
||||
mmap=True,
|
||||
)
|
||||
self.iree_compilation_module = params["vmfb"]
|
||||
self.iree_config = params["config"]
|
||||
self.temp_file_to_unlink = params["temp_file_to_unlink"]
|
||||
del params
|
||||
|
||||
def setup_cl(self, input_tensors):
|
||||
self.benchmark_cl = build_benchmark_args(
|
||||
@@ -111,42 +124,41 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
elif self.mlir_dialect in ["mhlo", "tf"]:
|
||||
return self.benchmark_tf(modelname)
|
||||
|
||||
def benchmark_torch(self, modelname):
|
||||
def benchmark_torch(self, modelname, device="cpu"):
|
||||
import torch
|
||||
from tank.model_utils import get_torch_model
|
||||
|
||||
if self.device == "cuda":
|
||||
torch.set_default_tensor_type(torch.cuda.FloatTensor)
|
||||
if self.enable_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
# TODO: Pass this as an arg. currently the best way is to setup with BENCHMARK=1 if we want to use torch+cuda, else use cpu.
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if device == "cuda":
|
||||
torch.set_default_device("cuda:0")
|
||||
# if self.enable_tf32:
|
||||
# torch.backends.cuda.matmul.allow_tf32 = True
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.FloatTensor)
|
||||
torch_device = torch.device(
|
||||
"cuda:0" if self.device == "cuda" else "cpu"
|
||||
)
|
||||
torch.set_default_dtype(torch.float32)
|
||||
torch.set_default_device("cpu")
|
||||
torch_device = torch.device("cuda:0" if device == "cuda" else "cpu")
|
||||
HFmodel, input = get_torch_model(modelname, self.import_args)[:2]
|
||||
frontend_model = HFmodel.model
|
||||
frontend_model.to(torch_device)
|
||||
input.to(torch_device)
|
||||
|
||||
# TODO: re-enable as soon as pytorch CUDA context issues are resolved
|
||||
try:
|
||||
frontend_model = torch.compile(
|
||||
frontend_model, mode="max-autotune", backend="inductor"
|
||||
)
|
||||
except RuntimeError:
|
||||
frontend_model = HFmodel.model
|
||||
if device == "cuda":
|
||||
frontend_model.cuda()
|
||||
input.to(torch.device("cuda:0"))
|
||||
print(input)
|
||||
else:
|
||||
frontend_model.cpu()
|
||||
input.cpu()
|
||||
|
||||
for i in range(shark_args.num_warmup_iterations):
|
||||
frontend_model.forward(input)
|
||||
|
||||
if self.device == "cuda":
|
||||
if device == "cuda":
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
begin = time.time()
|
||||
for i in range(shark_args.num_iterations):
|
||||
out = frontend_model.forward(input)
|
||||
end = time.time()
|
||||
if self.device == "cuda":
|
||||
if device == "cuda":
|
||||
stats = torch.cuda.memory_stats()
|
||||
device_peak_b = stats["allocated_bytes.all.peak"]
|
||||
frontend_model.to(torch.device("cpu"))
|
||||
@@ -158,7 +170,7 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
print(
|
||||
f"Torch benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
if self.device == "cuda":
|
||||
if device == "cuda":
|
||||
# Set device to CPU so we don't run into segfaults exiting pytest subprocesses.
|
||||
torch_device = torch.device("cpu")
|
||||
return [
|
||||
|
||||
@@ -11,14 +11,8 @@ from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡shape(
|
||||
lhs: List[int],
|
||||
rhs: List[int],
|
||||
rhs_scale: List[int],
|
||||
rhs_zero_point: List[int],
|
||||
rhs_bit_width: int,
|
||||
rhs_group_size: int,
|
||||
) -> List[int]:
|
||||
# fmt: off
|
||||
def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
|
||||
if len(lhs) == 3 and len(rhs) == 2:
|
||||
return [lhs[0], lhs[1], rhs[0]]
|
||||
elif len(lhs) == 2 and len(rhs) == 2:
|
||||
@@ -27,30 +21,21 @@ def brevitas〇matmul_rhs_group_quant〡shape(
|
||||
raise ValueError("Input shapes not supported.")
|
||||
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡dtype(
|
||||
lhs_rank_dtype: Tuple[int, int],
|
||||
rhs_rank_dtype: Tuple[int, int],
|
||||
rhs_scale_rank_dtype: Tuple[int, int],
|
||||
rhs_zero_point_rank_dtype: Tuple[int, int],
|
||||
rhs_bit_width: int,
|
||||
rhs_group_size: int,
|
||||
) -> int:
|
||||
def quant〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
|
||||
# output dtype is the dtype of the lhs float input
|
||||
lhs_rank, lhs_dtype = lhs_rank_dtype
|
||||
return lhs_dtype
|
||||
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡has_value_semantics(
|
||||
lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size
|
||||
) -> None:
|
||||
def quant〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
|
||||
return
|
||||
|
||||
|
||||
brevitas_matmul_rhs_group_quant_library = [
|
||||
brevitas〇matmul_rhs_group_quant〡shape,
|
||||
brevitas〇matmul_rhs_group_quant〡dtype,
|
||||
brevitas〇matmul_rhs_group_quant〡has_value_semantics,
|
||||
]
|
||||
quant〇matmul_rhs_group_quant〡shape,
|
||||
quant〇matmul_rhs_group_quant〡dtype,
|
||||
quant〇matmul_rhs_group_quant〡has_value_semantics]
|
||||
# fmt: on
|
||||
|
||||
|
||||
def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
|
||||
@@ -122,7 +107,7 @@ def compile_int_precision(
|
||||
torchscript_module,
|
||||
inputs,
|
||||
output_type="torch",
|
||||
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
|
||||
backend_legal_ops=["quant.matmul_rhs_group_quant"],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
@@ -130,7 +115,7 @@ def compile_int_precision(
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
mlir_module,
|
||||
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
@@ -138,7 +138,7 @@ if __name__ == "__main__":
|
||||
firstVicunaCompileInput = (compilation_input_ids,)
|
||||
from apps.language_models.src.model_wrappers.vicuna_model import (
|
||||
FirstVicuna,
|
||||
SecondVicuna,
|
||||
SecondVicuna7B,
|
||||
CombinedModel,
|
||||
)
|
||||
|
||||
|
||||
@@ -509,22 +509,6 @@ def import_with_fx(
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from typing import List
|
||||
from brevitas_examples.llm.llm_quant.export import (
|
||||
block_quant_layer_level_manager,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.export import (
|
||||
brevitas_layer_export_mode,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
|
||||
LinearWeightBlockQuantHandlerFwd,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.export import replace_call_fn_target
|
||||
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
|
||||
matmul_rhs_group_quant_placeholder,
|
||||
)
|
||||
from brevitas.backport.fx.experimental.proxy_tensor import (
|
||||
make_fx as brevitas_make_fx,
|
||||
)
|
||||
|
||||
golden_values = None
|
||||
if debug:
|
||||
@@ -596,8 +580,30 @@ def import_with_fx(
|
||||
torch.ops.aten.native_layer_norm,
|
||||
torch.ops.aten.masked_fill.Tensor,
|
||||
torch.ops.aten.masked_fill.Scalar,
|
||||
torch.ops.aten._scaled_dot_product_flash_attention.default,
|
||||
torch.ops.aten.index_add,
|
||||
torch.ops.aten.index_add_,
|
||||
]
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.export import (
|
||||
block_quant_layer_level_manager,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.export import (
|
||||
brevitas_layer_export_mode,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
|
||||
LinearWeightBlockQuantHandlerFwd,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.export import (
|
||||
replace_call_fn_target,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
|
||||
matmul_rhs_group_quant_placeholder,
|
||||
)
|
||||
from brevitas.backport.fx.experimental.proxy_tensor import (
|
||||
make_fx as brevitas_make_fx,
|
||||
)
|
||||
|
||||
export_context_manager = brevitas_layer_export_mode
|
||||
export_class = block_quant_layer_level_manager(
|
||||
export_handlers=[LinearWeightBlockQuantHandlerFwd]
|
||||
@@ -612,7 +618,7 @@ def import_with_fx(
|
||||
replace_call_fn_target(
|
||||
fx_g,
|
||||
src=matmul_rhs_group_quant_placeholder,
|
||||
target=torch.ops.brevitas.matmul_rhs_group_quant,
|
||||
target=torch.ops.quant.matmul_rhs_group_quant,
|
||||
)
|
||||
|
||||
fx_g.recompile()
|
||||
@@ -677,5 +683,5 @@ def import_with_fx(
|
||||
)
|
||||
return mlir_module, func_name
|
||||
|
||||
mlir_module, func_name = mlir_importer.import_mlir()
|
||||
mlir_module, func_name = mlir_importer.import_mlir(mlir_type=mlir_type)
|
||||
return mlir_module, func_name
|
||||
|
||||
@@ -141,6 +141,10 @@ class SharkInference:
|
||||
def __call__(self, function_name: str, inputs: tuple, send_to_host=True):
|
||||
return self.shark_runner.run(function_name, inputs, send_to_host)
|
||||
|
||||
# forward function.
|
||||
def forward(self, inputs: tuple, send_to_host=True):
|
||||
return self.shark_runner.run("forward", inputs, send_to_host)
|
||||
|
||||
# Get all function names defined within the compiled module.
|
||||
def get_functions_in_module(self):
|
||||
return self.shark_runner.get_functions_in_module()
|
||||
@@ -188,7 +192,9 @@ class SharkInference:
|
||||
|
||||
# TODO: Instead of passing directory and having names decided by the module
|
||||
# , user may want to save the module with manual names.
|
||||
def save_module(self, dir=os.getcwd(), module_name=None, extra_args=[]):
|
||||
def save_module(
|
||||
self, dir=os.getcwd(), module_name=None, extra_args=[], debug=False
|
||||
):
|
||||
return export_iree_module_to_vmfb(
|
||||
self.mlir_module,
|
||||
self.device,
|
||||
@@ -196,6 +202,7 @@ class SharkInference:
|
||||
self.mlir_dialect,
|
||||
module_name=module_name,
|
||||
extra_args=extra_args,
|
||||
debug=debug,
|
||||
)
|
||||
|
||||
# load and return the module.
|
||||
|
||||
@@ -69,7 +69,7 @@ class SharkTrainer:
|
||||
self.frontend = frontend
|
||||
|
||||
# Training function is needed in the case of torch_fn.
|
||||
def compile(self, training_fn=None, extra_args=[]):
|
||||
def compile(self, training_fn=None, mlir_type="linalg", extra_args=[]):
|
||||
if self.frontend in ["torch", "pytorch"]:
|
||||
packed_inputs = (
|
||||
dict(self.model.named_parameters()),
|
||||
@@ -77,7 +77,12 @@ class SharkTrainer:
|
||||
tuple(self.input),
|
||||
)
|
||||
mlir_module, func_name = import_with_fx(
|
||||
training_fn, packed_inputs, False, [], training=True
|
||||
training_fn,
|
||||
packed_inputs,
|
||||
False,
|
||||
[],
|
||||
training=True,
|
||||
mlir_type=mlir_type,
|
||||
)
|
||||
self.shark_runner = SharkRunner(
|
||||
mlir_module,
|
||||
|
||||
@@ -13,7 +13,6 @@ google/vit-base-patch16-224,stablehlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,False,False,
|
||||
microsoft/MiniLM-L12-H384-uncased,stablehlo,tf,1e-2,1e-3,tf_hf,None,True,False,False,"Fails during iree-compile.",""
|
||||
microsoft/layoutlm-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
microsoft/mpnet-base,stablehlo,tf,1e-2,1e-2,default,None,True,True,True,"",""
|
||||
albert-base-v2,linalg,torch,1e-2,1e-3,default,None,True,True,True,"issue with aten.tanh in torch-mlir",""
|
||||
alexnet,linalg,torch,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/879",""
|
||||
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
@@ -30,7 +29,7 @@ nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,True,"https://github
|
||||
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
|
||||
resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,False,"","macos"
|
||||
resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,False,True,"",""
|
||||
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,True,True,"Numerics issues, awaiting cuda-independent fp16 integration",""
|
||||
squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
|
||||
efficientnet-v2-s,stablehlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
|
||||
|
@@ -85,8 +85,6 @@ if __name__ == "__main__":
|
||||
args = [
|
||||
"--iree-llvmcpu-target-cpu-features=host",
|
||||
"--iree-mhlo-demote-i64-to-i32=false",
|
||||
"--iree-stream-resource-index-bits=64",
|
||||
"--iree-vm-target-index-bits=64",
|
||||
]
|
||||
backend_config = "dylib"
|
||||
# backend = "cuda"
|
||||
|
||||
@@ -1,3 +1,26 @@
# Running Different OPT Variants
# Run OPT for sentence completion through SHARK

To run different sizes of OPT, change the `OPT_MODEL` string in `opt_torch_test.py`. The default is 350m parameters. 66b cases also exist in the file; simply uncomment the test cases.
From the base SHARK directory, follow the instructions to set up a virtual environment with SHARK (`./setup_venv.sh` or `./setup_venv.ps1`).
Then you can run `opt_causallm.py` to get a very simple sentence completion application running through SHARK:
```
python opt_causallm.py
```

# Run OPT performance comparison on SHARK vs. PyTorch

```
python opt_perf_comparison.py --max-seq-len=512 --model-name=facebook/opt-1.3b \
    --platform=shark
```
Any OPT model from Hugging Face should work with this script, and you can choose between `--platform=shark` and `--platform=huggingface` to generate benchmarks of OPT inference on SHARK / PyTorch.

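For example, to benchmark the same configuration on both backends and compare the printed `# Summary:` lines, you could run the script twice; the flag values below are just an illustration:

```
python opt_perf_comparison.py --max-seq-len=128 --model-name=facebook/opt-125m \
    --platform=shark
python opt_perf_comparison.py --max-seq-len=128 --model-name=facebook/opt-125m \
    --platform=huggingface
```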
# Run a small suite of OPT models through the benchmark script

```
python opt_perf_comparison_batch.py
```
This script runs benchmarks over a suite of OPT configurations:
- Sequence Lengths: 32, 128, 256, 512
- Parameter Counts: 125m, 350m, 1.3b

Note: most of these scripts are written for use on CPU, as perf comparisons against PyTorch can be problematic across platforms otherwise.

@@ -59,7 +59,7 @@ def create_module(model_name, tokenizer, device):
|
||||
)
|
||||
|
||||
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
|
||||
shark_module.save_module(module_name=vmfb_name)
|
||||
shark_module.save_module(module_name=vmfb_name, debug=False)
|
||||
vmfb_path = vmfb_name + ".vmfb"
|
||||
return vmfb_path
|
||||
|
||||
|
||||
@@ -1,18 +1,46 @@
|
||||
"""
|
||||
Script for comparing OPT model performance between SHARK and Huggingface
|
||||
PyTorch.
|
||||
|
||||
Usage Example:
|
||||
|
||||
python opt_perf_comparison.py --max-seq-len=32 --model-name=facebook/opt-125m \
|
||||
--platform=shark
|
||||
|
||||
python opt_perf_comparison.py --max-seq-len=512 --model-name=facebook/opt-1.3b \
|
||||
--platform=shark
|
||||
|
||||
See parse_args() below for command line argument usage.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
import psutil
|
||||
import resource
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from transformers import AutoTokenizer, OPTForCausalLM
|
||||
from shark_opt_wrapper import OPTForCausalLMModel
|
||||
|
||||
MODEL_NAME = "facebook/opt-1.3b"
|
||||
OPT_MODELNAME = "opt-1.3b"
|
||||
OPT_FS_NAME = "opt_1-3b"
|
||||
MAX_SEQUENCE_LENGTH = 512
|
||||
DEVICE = "cpu"
|
||||
PLATFORM_SHARK = "shark"
|
||||
PLATFORM_HUGGINGFACE = "huggingface"
|
||||
|
||||
# Dict keys for reports.
|
||||
REPORT_PLATFORM = "platform"
|
||||
REPORT_MODEL_NAME = "model"
|
||||
REPORT_MAX_SEQ_LEN = "max_seq_len"
|
||||
REPORT_LOAD_TIME = "load_time_sec"
|
||||
REPORT_RUN_TIME = "run_time_sec"
|
||||
REPORT_LOAD_PHYSICAL_MEMORY_MB = "load_physical_MB"
|
||||
REPORT_LOAD_VIRTUAL_MEMORY_MB = "load_virtual_MB"
|
||||
REPORT_RUN_PHYSICAL_MEMORY_MB = "run_physical_MB"
|
||||
REPORT_RUN_VIRTUAL_MEMORY_MB = "run_virtual_MB"
|
||||
|
||||
PROMPTS = [
|
||||
"What is the meaning of life?",
|
||||
@@ -30,15 +58,27 @@ PROMPTS = [
|
||||
ModelWrapper = collections.namedtuple("ModelWrapper", ["model", "tokenizer"])
|
||||
|
||||
|
||||
def create_vmfb_module(model_name, tokenizer, device):
|
||||
opt_base_model = OPTForCausalLM.from_pretrained("facebook/" + model_name)
|
||||
def get_memory_info():
|
||||
pid = os.getpid()
|
||||
process = psutil.Process(pid)
|
||||
return process.memory_info()
|
||||
|
||||
|
||||
def create_vmfb_module(
|
||||
model_name: str,
|
||||
tokenizer,
|
||||
device: str,
|
||||
max_seq_len: int,
|
||||
recompile_shark: bool,
|
||||
):
|
||||
opt_base_model = OPTForCausalLM.from_pretrained(model_name)
|
||||
opt_base_model.eval()
|
||||
opt_model = OPTForCausalLMModel(opt_base_model)
|
||||
encoded_inputs = tokenizer(
|
||||
"What is the meaning of life?",
|
||||
PROMPTS[0],
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
max_length=max_seq_len,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = (
|
||||
@@ -48,8 +88,16 @@ def create_vmfb_module(model_name, tokenizer, device):
|
||||
# np.save("model_inputs_0.npy", inputs[0])
|
||||
# np.save("model_inputs_1.npy", inputs[1])
|
||||
|
||||
mlir_path = f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
|
||||
if os.path.isfile(mlir_path):
|
||||
opt_fs_name = get_opt_fs_name(model_name)
|
||||
mlir_path = f"./{opt_fs_name}_causallm_{max_seq_len}_torch.mlir"
|
||||
# If MLIR has already been loaded and recompilation is not requested, use
|
||||
# the loaded MLIR file.
|
||||
has_mlir = os.path.isfile(mlir_path)
|
||||
# The purpose of recompile_shark is to measure compilation time; the
|
||||
# compilation time can be correctly measured only when MLIR has already been
|
||||
# loaded.
|
||||
assert not recompile_shark or has_mlir
|
||||
if has_mlir:
|
||||
with open(mlir_path, "r") as f:
|
||||
model_mlir = f.read()
|
||||
print(f"Loaded .mlir from {mlir_path}")
|
||||
@@ -58,7 +106,7 @@ def create_vmfb_module(model_name, tokenizer, device):
|
||||
model=opt_model,
|
||||
inputs=inputs,
|
||||
is_f16=False,
|
||||
model_name=OPT_FS_NAME,
|
||||
model_name=opt_fs_name,
|
||||
return_str=True,
|
||||
)
|
||||
with open(mlir_path, "w") as f:
|
||||
@@ -72,18 +120,25 @@ def create_vmfb_module(model_name, tokenizer, device):
|
||||
is_benchmark=False,
|
||||
)
|
||||
|
||||
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{DEVICE}_tiled_ukernels"
|
||||
vmfb_name = (
|
||||
f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}_tiled_ukernels"
|
||||
)
|
||||
shark_module.save_module(module_name=vmfb_name)
|
||||
vmfb_path = vmfb_name + ".vmfb"
|
||||
return vmfb_path
|
||||
|
||||
|
||||
def load_shark_model() -> ModelWrapper:
|
||||
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{DEVICE}_tiled_ukernels.vmfb"
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
|
||||
if not os.path.isfile(vmfb_name):
|
||||
def load_shark_model(
|
||||
model_name: str, max_seq_len: int, recompile_shark: bool
|
||||
) -> ModelWrapper:
|
||||
opt_fs_name = get_opt_fs_name(model_name)
|
||||
vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}_tiled_ukernels.vmfb"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
||||
if recompile_shark or not os.path.isfile(vmfb_name):
|
||||
print(f"vmfb not found. compiling and saving to {vmfb_name}")
|
||||
create_vmfb_module(OPT_MODELNAME, tokenizer, DEVICE)
|
||||
create_vmfb_module(
|
||||
model_name, tokenizer, DEVICE, max_seq_len, recompile_shark
|
||||
)
|
||||
shark_module = SharkInference(mlir_module=None, device="cpu-task")
|
||||
shark_module.load_module(vmfb_name)
|
||||
return ModelWrapper(model=shark_module, tokenizer=tokenizer)
|
||||
@@ -94,20 +149,10 @@ def run_shark_model(model_wrapper: ModelWrapper, tokens):
|
||||
return model_wrapper.model("forward", tokens)
|
||||
|
||||
|
||||
def run_shark():
|
||||
model_wrapper = load_shark_model()
|
||||
|
||||
prompt = "What is the meaning of life?"
|
||||
logits = run_shark_model(model_wrapper, prompt)
|
||||
|
||||
# Print output logits to validate vs. pytorch + base transformers
|
||||
print(logits[0])
|
||||
|
||||
|
||||
def load_huggingface_model() -> ModelWrapper:
|
||||
def load_huggingface_model(model_name: str) -> ModelWrapper:
|
||||
return ModelWrapper(
|
||||
model=OPTForCausalLM.from_pretrained(MODEL_NAME),
|
||||
tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME),
|
||||
model=OPTForCausalLM.from_pretrained(model_name),
|
||||
tokenizer=AutoTokenizer.from_pretrained(model_name),
|
||||
)
|
||||
|
||||
|
||||
@@ -117,47 +162,71 @@ def run_huggingface_model(model_wrapper: ModelWrapper, tokens):
|
||||
)
|
||||
|
||||
|
||||
def run_huggingface():
|
||||
model_wrapper = load_huggingface_model()
|
||||
prompt = "What is the meaning of life?"
|
||||
logits = run_huggingface_model(model_wrapper, prompt)
|
||||
|
||||
print(logits[0])
|
||||
|
||||
|
||||
def save_json(data, filename):
|
||||
with open(filename, "w") as file:
|
||||
json.dump(data, file)
|
||||
|
||||
|
||||
def collect_huggingface_logits():
|
||||
def collect_huggingface_logits(
|
||||
model_name: str, max_seq_len: int, to_save_json: bool
|
||||
) -> Tuple[float, float]:
|
||||
# Load
|
||||
t0 = time.time()
|
||||
model_wrapper = load_huggingface_model()
|
||||
print("--- Took {} seconds to load Huggingface.".format(time.time() - t0))
|
||||
model_wrapper = load_huggingface_model(model_name)
|
||||
load_time = time.time() - t0
|
||||
print("--- Took {} seconds to load Huggingface.".format(load_time))
|
||||
load_memory_info = get_memory_info()
|
||||
|
||||
results = []
|
||||
tokenized_prompts = []
|
||||
for prompt in PROMPTS:
|
||||
tokens = model_wrapper.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
max_length=max_seq_len,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokenized_prompts.append(tokens)
|
||||
|
||||
# Run
|
||||
t0 = time.time()
|
||||
for idx, tokens in enumerate(tokenized_prompts):
|
||||
print("prompt: {}".format(PROMPTS[idx]))
|
||||
logits = run_huggingface_model(model_wrapper, tokens)
|
||||
results.append([PROMPTS[idx], logits[0].tolist()])
|
||||
print("--- Took {} seconds to run Huggingface.".format(time.time() - t0))
|
||||
save_json(results, "/tmp/huggingface.json")
|
||||
if to_save_json:
|
||||
results.append([PROMPTS[idx], logits[0].tolist()])
|
||||
run_time = time.time() - t0
|
||||
print("--- Took {} seconds to run Huggingface.".format(run_time))
|
||||
if to_save_json:
|
||||
save_json(results, "/tmp/huggingface.json")
|
||||
run_memory_info = get_memory_info()
|
||||
return {
|
||||
REPORT_PLATFORM: PLATFORM_HUGGINGFACE,
|
||||
REPORT_MODEL_NAME: model_name,
|
||||
REPORT_MAX_SEQ_LEN: max_seq_len,
|
||||
REPORT_LOAD_TIME: load_time,
|
||||
REPORT_RUN_TIME: run_time / len(PROMPTS),
|
||||
REPORT_LOAD_PHYSICAL_MEMORY_MB: load_memory_info.rss >> 20,
|
||||
REPORT_LOAD_VIRTUAL_MEMORY_MB: load_memory_info.vms >> 20,
|
||||
REPORT_RUN_PHYSICAL_MEMORY_MB: run_memory_info.rss >> 20,
|
||||
REPORT_RUN_VIRTUAL_MEMORY_MB: run_memory_info.vms >> 20,
|
||||
}
|
||||
|
||||
|
||||
def collect_shark_logits():
|
||||
def collect_shark_logits(
|
||||
model_name: str,
|
||||
max_seq_len: int,
|
||||
recompile_shark: bool,
|
||||
to_save_json: bool,
|
||||
) -> Tuple[float, float]:
|
||||
# Load
|
||||
t0 = time.time()
|
||||
model_wrapper = load_shark_model()
|
||||
print("--- Took {} seconds to load Shark.".format(time.time() - t0))
|
||||
model_wrapper = load_shark_model(model_name, max_seq_len, recompile_shark)
|
||||
load_time = time.time() - t0
|
||||
print("--- Took {} seconds to load Shark.".format(load_time))
|
||||
load_memory_info = get_memory_info()
|
||||
|
||||
results = []
|
||||
tokenized_prompts = []
|
||||
for prompt in PROMPTS:
|
||||
@@ -165,7 +234,7 @@ def collect_shark_logits():
|
||||
prompt,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
max_length=max_seq_len,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = (
|
||||
@@ -173,16 +242,100 @@ def collect_shark_logits():
|
||||
tokens["attention_mask"],
|
||||
)
|
||||
tokenized_prompts.append(inputs)
|
||||
|
||||
# Run
|
||||
t0 = time.time()
|
||||
for idx, tokens in enumerate(tokenized_prompts):
|
||||
print("prompt: {}".format(PROMPTS[idx]))
|
||||
logits = run_shark_model(model_wrapper, tokens)
|
||||
lst = [e.tolist() for e in logits]
|
||||
results.append([PROMPTS[idx], lst])
|
||||
print("--- Took {} seconds to run Shark.".format(time.time() - t0))
|
||||
save_json(results, "/tmp/shark.json")
|
||||
if to_save_json:
|
||||
results.append([PROMPTS[idx], lst])
|
||||
run_time = time.time() - t0
|
||||
print("--- Took {} seconds to run Shark.".format(run_time))
|
||||
if to_save_json:
|
||||
save_json(results, "/tmp/shark.json")
|
||||
platform_postfix = "-compile" if recompile_shark else "-precompiled"
|
||||
run_memory_info = get_memory_info()
|
||||
return {
|
||||
REPORT_PLATFORM: PLATFORM_SHARK + platform_postfix,
|
||||
REPORT_MODEL_NAME: model_name,
|
||||
REPORT_MAX_SEQ_LEN: max_seq_len,
|
||||
REPORT_LOAD_TIME: load_time,
|
||||
REPORT_RUN_TIME: run_time / len(PROMPTS),
|
||||
REPORT_LOAD_PHYSICAL_MEMORY_MB: load_memory_info.rss >> 20,
|
||||
REPORT_LOAD_VIRTUAL_MEMORY_MB: load_memory_info.vms >> 20,
|
||||
REPORT_RUN_PHYSICAL_MEMORY_MB: run_memory_info.rss >> 20,
|
||||
REPORT_RUN_VIRTUAL_MEMORY_MB: run_memory_info.vms >> 20,
|
||||
}
|
||||
|
||||
|
||||
def get_opt_fs_name(model_name: str) -> str:
|
||||
"""Cleanses the model name ino a file system-friendly name.
|
||||
|
||||
Example: get_opt_fs_name('facebook/opt-1.3b') == 'opt_1-3b'
|
||||
"""
|
||||
slash_split = model_name.split("/")
|
||||
assert 1 <= len(slash_split) <= 2, "There should be at most one slash."
|
||||
model_name = slash_split[-1]
|
||||
for src_pattern, dest_pattern in (("-", "_"), (".", "-")):
|
||||
model_name = model_name.replace(src_pattern, dest_pattern)
|
||||
return model_name


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save-json",
        help="If set, saves output JSON.",
        action=argparse.BooleanOptionalAction,
        default=False,
    )
    parser.add_argument(
        "--max-seq-len", help="Max sequence length", type=int, default=32
    )
    parser.add_argument(
        "--model-name",
        help="Model name",
        type=str,
        choices=[
            "facebook/opt-125m",
            "facebook/opt-350m",
            "facebook/opt-1.3b",
            "facebook/opt-6.7b",
        ],
        default="facebook/opt-1.3b",
    )
    parser.add_argument(
        "--recompile-shark",
        help="If set, recompiles MLIR",
        action=argparse.BooleanOptionalAction,
        default=False,
    )
    parser.add_argument(
        "--platform",
        help="Either shark or huggingface",
        type=str,
        choices=[PLATFORM_SHARK, PLATFORM_HUGGINGFACE],
        default=PLATFORM_SHARK,
    )
    args = parser.parse_args()
    print("args={}".format(args))
    return args
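Put together, the flags above admit invocations along these lines (a sketch; it assumes PLATFORM_SHARK and PLATFORM_HUGGINGFACE are the strings "shark" and "huggingface" as the help text suggests, and that the script is run from its own directory):

    # SHARK path: recompile the MLIR and write /tmp/shark.json
    python opt_perf_comparison.py --platform shark --model-name facebook/opt-125m --max-seq-len 128 --recompile-shark --save-json

    # Hugging Face baseline; BooleanOptionalAction also provides the --no-save-json negation
    python opt_perf_comparison.py --platform huggingface --model-name facebook/opt-125m --max-seq-len 128 --no-save-json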


if __name__ == "__main__":
    collect_shark_logits()
    collect_huggingface_logits()
    args = parse_args()
    if args.platform == PLATFORM_SHARK:
        shark_report = collect_shark_logits(
            args.model_name,
            args.max_seq_len,
            args.recompile_shark,
            args.save_json,
        )
        print("# Summary: {}".format(json.dumps(shark_report)))
    else:
        huggingface_report = collect_huggingface_logits(
            args.model_name, args.max_seq_len, args.save_json
        )
        print("# Summary: {}".format(json.dumps(huggingface_report)))

30
tank/examples/opt/opt_perf_comparison_batch.py
Normal file
@@ -0,0 +1,30 @@
"""
Script for running opt_perf_comparison.py in batch with a series of arguments.

Usage: python opt_perf_comparison_batch.py
"""

from typing import Iterable, List
import shlex
import subprocess


def make_commands() -> Iterable[List[str]]:
    command = shlex.split("python opt_perf_comparison.py --no-save-json")
    max_seq_lens = [32, 128, 256, 512]
    model_names = ["facebook/opt-" + e for e in ["125m", "350m", "1.3b"]]
    for max_seq_len in max_seq_lens:
        for model_name in model_names:
            yield command + [
                f"--max-seq-len={max_seq_len}",
                f"--model-name={model_name}",
            ]


def main():
    for command in make_commands():
        result = subprocess.run(command, check=True)


if __name__ == "__main__":
    main()
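make_commands() therefore yields 12 runs (4 sequence lengths × 3 model sizes), executed sequentially with check=True so a failing benchmark aborts the batch. The first command it produces is equivalent to:

    python opt_perf_comparison.py --no-save-json --max-seq-len=32 --model-name=facebook/opt-125m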
@@ -16,12 +16,6 @@ import subprocess as sp
import hashlib
import numpy as np
from pathlib import Path
from apps.stable_diffusion.src.models import (
    model_wrappers as mw,
)
from apps.stable_diffusion.src.utils.stable_args import (
    args,
)


def create_hash(file_name):
@@ -60,31 +54,6 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
        print("generating artifacts for: " + torch_model_name)
        model = None
        input = None
        if model_type == "stable_diffusion":
            args.use_tuned = False
            args.import_mlir = True
            args.local_tank_cache = local_tank_cache

            precision_values = ["fp16"]
            seq_lengths = [64, 77]
            for precision_value in precision_values:
                args.precision = precision_value
                for length in seq_lengths:
                    model = mw.SharkifyStableDiffusionModel(
                        model_id=torch_model_name,
                        custom_weights="",
                        precision=precision_value,
                        max_len=length,
                        width=512,
                        height=512,
                        use_base_vae=False,
                        custom_vae="",
                        debug=True,
                        sharktank_dir=local_tank_cache,
                        generate_vmfb=False,
                    )
                    model()
            continue
        if model_type == "vision":
            model, input, _ = get_vision_model(
                torch_model_name, import_args
@@ -103,10 +72,11 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
            model, input, _ = get_hf_img_cls_model(
                torch_model_name, import_args
            )
        elif model_type == "fp16":
            model, input, _ = get_fp16_model(torch_model_name, import_args)
        torch_model_name = torch_model_name.replace("/", "_")
        if import_args["batch_size"] != 1:
        if import_args["batch_size"] > 1:
            print(
                f"Batch size for this model set to {import_args['batch_size']}"
            )
        torch_model_dir = os.path.join(
            local_tank_cache,
            str(torch_model_name)
@@ -391,7 +361,7 @@ if __name__ == "__main__":

    # old_import_args = parser.parse_import_args()
    import_args = {
        "batch_size": "1",
        "batch_size": 1,
    }
    print(import_args)
    home = str(Path.home())
@@ -404,11 +374,6 @@ if __name__ == "__main__":
        os.path.dirname(__file__), "tflite", "tflite_model_list.csv"
    )

    save_torch_model(
        os.path.join(os.path.dirname(__file__), "torch_sd_list.csv"),
        WORKDIR,
        import_args,
    )
    save_torch_model(torch_model_csv, WORKDIR, import_args)
    save_tf_model(tf_model_csv, WORKDIR, import_args)
    save_tflite_model(tflite_model_csv, WORKDIR, import_args)
    # save_tf_model(tf_model_csv, WORKDIR, import_args)
    # save_tflite_model(tflite_model_csv, WORKDIR, import_args)

@@ -278,7 +278,7 @@ def get_vision_model(torch_model, import_args):
        int(import_args["batch_size"]), 3, *input_image_size
    )
    actual_out = model(test_input)
    if fp16_model is not None:
    if fp16_model == True:
        test_input_fp16 = test_input.to(
            device=torch.device("cuda"), dtype=torch.half
        )

@@ -145,6 +145,7 @@ class SharkModuleTester:
        shark_args.shark_prefix = self.shark_tank_prefix
        shark_args.local_tank_cache = self.local_tank_cache
        shark_args.dispatch_benchmarks = self.benchmark_dispatches
        shark_args.enable_tf32 = self.tf32

        if self.benchmark_dispatches is not None:
            _m = self.config["model_name"].split("/")
@@ -216,10 +217,12 @@ class SharkModuleTester:

        result = shark_module(func_name, inputs)
        golden_out, result = self.postprocess_outputs(golden_out, result)
        if self.tf32 == "true":
            print("Validating with relaxed tolerances.")
            atol = 1e-02
            rtol = 1e-03
        if self.tf32 == True:
            print(
                "Validating with relaxed tolerances for TensorFloat32 calculations."
            )
            self.config["atol"] = 1e-01
            self.config["rtol"] = 1e-02
        try:
            np.testing.assert_allclose(
                golden_out,
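For context on the relaxed bounds set above: np.testing.assert_allclose accepts an element when |actual - desired| <= atol + rtol * |desired|, so with the new config values a comparison like this hypothetical one passes:

    import numpy as np

    golden = np.array([1.00, 10.00])   # hypothetical golden outputs
    result = np.array([1.05, 10.05])   # hypothetical SHARK outputs, each off by 0.05
    np.testing.assert_allclose(result, golden, atol=1e-01, rtol=1e-02)  # within atol + rtol * |golden|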
@@ -254,9 +257,6 @@ class SharkModuleTester:
        model_config = {
            "batch_size": self.batch_size,
        }
        shark_args.enable_tf32 = self.tf32
        if shark_args.enable_tf32 == True:
            shark_module.compile()

        shark_args.onnx_bench = self.onnx_bench
        shark_module.shark_runner.benchmark_all_csv(
@@ -287,6 +287,9 @@ class SharkModuleTester:
            repro_path = os.path.join("reproducers", self.tmp_prefix, "*")

            bashCommand = f"gsutil cp -r {repro_path} gs://shark-public/builder/repro_artifacts/{self.ci_sha}/{self.tmp_prefix}/"
            print(
                f"Uploading reproducer {repro_path} to gs://shark-public/builder/repro_artifacts/{self.ci_sha}/{self.tmp_prefix}/"
            )
            process = subprocess.run(bashCommand.split())

    def postprocess_outputs(self, golden_out, result):

@@ -5,7 +5,6 @@ microsoft/MiniLM-L12-H384-uncased,True,hf,True,linalg,False,66M,"nlp;bert-varian
bert-base-uncased,True,hf,True,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-base-cased,True,hf,True,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
google/mobilebert-uncased,True,hf,True,linalg,False,25M,"nlp,bert-variant,transformer-encoder,mobile","24 layers, 512 hidden size, 128 embedding"
alexnet,False,vision,True,linalg,False,61M,"cnn,parallel-layers","The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod."
resnet18,False,vision,True,linalg,False,11M,"cnn,image-classification,residuals,resnet-variant","1 7x7 conv2d and the rest are 3x3 conv2d"
resnet50,False,vision,True,linalg,False,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
resnet101,False,vision,True,linalg,False,29M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
@@ -18,11 +17,9 @@ facebook/deit-small-distilled-patch16-224,True,hf_img_cls,False,linalg,False,22M
microsoft/beit-base-patch16-224-pt22k-ft22k,True,hf_img_cls,False,linalg,False,86M,"image-classification,transformer-encoder,bert-variant,vision-transformer",N/A
nvidia/mit-b0,True,hf_img_cls,False,linalg,False,3.7M,"image-classification,transformer-encoder",SegFormer
mnasnet1_0,False,vision,True,linalg,False,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
resnet50_fp16,False,vision,True,linalg,False,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
bert-base-uncased_fp16,True,fp16,False,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-large-uncased,True,hf,True,linalg,False,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
bert-base-uncased,True,hf,False,stablehlo,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
gpt2,True,hf_causallm,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
facebook/opt-125m,True,hf,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
distilgpt2,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
microsoft/deberta-v3-base,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
microsoft/deberta-v3-base,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"