mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
52 Commits
20230824.9
...
20231004.9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9c4610b9da | ||
|
|
a38cc9d216 | ||
|
|
1c382449ec | ||
|
|
7cc9b3f8e8 | ||
|
|
e54517e967 | ||
|
|
326327a799 | ||
|
|
785b65c7b0 | ||
|
|
0d16c81687 | ||
|
|
8dd7850c69 | ||
|
|
e930ba85b4 | ||
|
|
cd732e7a38 | ||
|
|
8e0f8b3227 | ||
|
|
b8210ef796 | ||
|
|
94594542a9 | ||
|
|
82f833e87d | ||
|
|
c9d6870105 | ||
|
|
4fec03a6cc | ||
|
|
9a27f51378 | ||
|
|
ad1a0f35ff | ||
|
|
6773278ec2 | ||
|
|
9a0efffcca | ||
|
|
61c6f153d9 | ||
|
|
effd42e8f5 | ||
|
|
b5fbb1a8a0 | ||
|
|
ded74d09cd | ||
|
|
79267931c1 | ||
|
|
9eceba69b7 | ||
|
|
ca609afb6a | ||
|
|
11bdce9790 | ||
|
|
684943a4a6 | ||
|
|
b817bb8455 | ||
|
|
780f520f02 | ||
|
|
c61b6f8d65 | ||
|
|
c854208d49 | ||
|
|
c5dcfc1f13 | ||
|
|
bde63ee8ae | ||
|
|
9681d494eb | ||
|
|
ede6bf83e2 | ||
|
|
2c2693fb7d | ||
|
|
1d31b2b2c6 | ||
|
|
d2f64eefa3 | ||
|
|
87ae14b6ff | ||
|
|
1ccafa1fc1 | ||
|
|
4c3d8a0a7f | ||
|
|
3601dc7c3b | ||
|
|
671881cf87 | ||
|
|
4e9be6be59 | ||
|
|
9c8cbaf498 | ||
|
|
9e348a114e | ||
|
|
51f90a4d56 | ||
|
|
310d5d0a49 | ||
|
|
9697981004 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -196,3 +196,6 @@ db_dir_UserData
|
||||
|
||||
# Embeded browser cache and other
|
||||
apps/stable_diffusion/web/EBWebView/
|
||||
|
||||
# Llama2 tokenizer configs
|
||||
llama2_tokenizer_configs/
|
||||
|
||||
@@ -10,7 +10,7 @@ High Performance Machine Learning Distribution
|
||||
<summary>Prerequisites - Drivers </summary>
|
||||
|
||||
#### Install your Windows hardware drivers
|
||||
* [AMD RDNA Users] Download the latest driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-2-1).
|
||||
* [AMD RDNA Users] Download the latest driver (23.2.1 is the oldest supported) [here](https://www.amd.com/en/support).
|
||||
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
|
||||
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
|
||||
|
||||
|
||||
@@ -237,7 +237,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
module,
|
||||
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -46,6 +46,7 @@ def compile_stableLM(
|
||||
model_vmfb_name,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
debug=False,
|
||||
):
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
@@ -92,7 +93,7 @@ def compile_stableLM(
|
||||
shark_module.compile()
|
||||
|
||||
path = shark_module.save_module(
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem, debug=debug
|
||||
)
|
||||
print("Saved vmfb at ", str(path))
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
94
apps/language_models/shark_llama_cli.spec
Normal file
94
apps/language_models/shark_llama_cli.spec
Normal file
@@ -0,0 +1,94 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
from PyInstaller.utils.hooks import collect_data_files
|
||||
from PyInstaller.utils.hooks import collect_submodules
|
||||
from PyInstaller.utils.hooks import copy_metadata
|
||||
|
||||
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
|
||||
|
||||
datas = []
|
||||
datas += collect_data_files('torch')
|
||||
datas += copy_metadata('torch')
|
||||
datas += copy_metadata('tqdm')
|
||||
datas += copy_metadata('regex')
|
||||
datas += copy_metadata('requests')
|
||||
datas += copy_metadata('packaging')
|
||||
datas += copy_metadata('filelock')
|
||||
datas += copy_metadata('numpy')
|
||||
datas += copy_metadata('tokenizers')
|
||||
datas += copy_metadata('importlib_metadata')
|
||||
datas += copy_metadata('torch-mlir')
|
||||
datas += copy_metadata('omegaconf')
|
||||
datas += copy_metadata('safetensors')
|
||||
datas += copy_metadata('huggingface-hub')
|
||||
datas += copy_metadata('sentencepiece')
|
||||
datas += copy_metadata("pyyaml")
|
||||
datas += collect_data_files("tokenizers")
|
||||
datas += collect_data_files("tiktoken")
|
||||
datas += collect_data_files("accelerate")
|
||||
datas += collect_data_files('diffusers')
|
||||
datas += collect_data_files('transformers')
|
||||
datas += collect_data_files('opencv-python')
|
||||
datas += collect_data_files('pytorch_lightning')
|
||||
datas += collect_data_files('skimage')
|
||||
datas += collect_data_files('gradio')
|
||||
datas += collect_data_files('gradio_client')
|
||||
datas += collect_data_files('iree')
|
||||
datas += collect_data_files('google-cloud-storage')
|
||||
datas += collect_data_files('py-cpuinfo')
|
||||
datas += collect_data_files("shark", include_py_files=True)
|
||||
datas += collect_data_files("timm", include_py_files=True)
|
||||
datas += collect_data_files("tqdm")
|
||||
datas += collect_data_files("tkinter")
|
||||
datas += collect_data_files("webview")
|
||||
datas += collect_data_files("sentencepiece")
|
||||
datas += collect_data_files("jsonschema")
|
||||
datas += collect_data_files("jsonschema_specifications")
|
||||
datas += collect_data_files("cpuinfo")
|
||||
datas += collect_data_files("langchain")
|
||||
|
||||
binaries = []
|
||||
|
||||
block_cipher = None
|
||||
|
||||
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
|
||||
|
||||
a = Analysis(
|
||||
['scripts/vicuna.py'],
|
||||
pathex=['.'],
|
||||
binaries=binaries,
|
||||
datas=datas,
|
||||
hiddenimports=hiddenimports,
|
||||
hookspath=[],
|
||||
hooksconfig={},
|
||||
runtime_hooks=[],
|
||||
excludes=[],
|
||||
win_no_prefer_redirects=False,
|
||||
win_private_assemblies=False,
|
||||
cipher=block_cipher,
|
||||
noarchive=False,
|
||||
)
|
||||
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||
|
||||
exe = EXE(
|
||||
pyz,
|
||||
a.scripts,
|
||||
a.binaries,
|
||||
a.zipfiles,
|
||||
a.datas,
|
||||
[],
|
||||
name='shark_llama_cli',
|
||||
debug=False,
|
||||
bootloader_ignore_signals=False,
|
||||
strip=False,
|
||||
upx=True,
|
||||
upx_exclude=[],
|
||||
runtime_tmpdir=None,
|
||||
console=True,
|
||||
disable_windowed_traceback=False,
|
||||
argv_emulation=False,
|
||||
target_arch=None,
|
||||
codesign_identity=None,
|
||||
entitlements_file=None,
|
||||
)
|
||||
@@ -57,8 +57,6 @@ from shark.shark_importer import get_f16_inputs
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaDecoderLayer,
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
class FirstVicuna(torch.nn.Module):
|
||||
def __init__(
|
||||
@@ -21,7 +18,13 @@ class FirstVicuna(torch.nn.Module):
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
get_model_impl,
|
||||
)
|
||||
|
||||
print("First Vicuna applying weight quantization..")
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
@@ -64,7 +67,13 @@ class SecondVicuna7B(torch.nn.Module):
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
get_model_impl,
|
||||
)
|
||||
|
||||
print("Second Vicuna applying weight quantization..")
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
@@ -148,8 +157,6 @@ class SecondVicuna7B(torch.nn.Module):
|
||||
i63,
|
||||
i64,
|
||||
):
|
||||
# input_ids = input_tuple[0]
|
||||
# input_tuple = torch.unbind(pkv, dim=0)
|
||||
token = i0
|
||||
past_key_values = (
|
||||
(i1, i2),
|
||||
@@ -294,7 +301,7 @@ class SecondVicuna13B(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_path,
|
||||
precision="fp32",
|
||||
precision="int8",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
hf_auth_token: str = None,
|
||||
@@ -307,6 +314,11 @@ class SecondVicuna13B(torch.nn.Module):
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
get_model_impl,
|
||||
)
|
||||
|
||||
print("Second Vicuna applying weight quantization..")
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
@@ -406,8 +418,6 @@ class SecondVicuna13B(torch.nn.Module):
|
||||
i79,
|
||||
i80,
|
||||
):
|
||||
# input_ids = input_tuple[0]
|
||||
# input_tuple = torch.unbind(pkv, dim=0)
|
||||
token = i0
|
||||
past_key_values = (
|
||||
(i1, i2),
|
||||
@@ -580,6 +590,540 @@ class SecondVicuna13B(torch.nn.Module):
|
||||
return tuple(return_vals)
|
||||
|
||||
|
||||
class SecondVicuna70B(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_path,
|
||||
precision="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
hf_auth_token: str = None,
|
||||
):
|
||||
super().__init__()
|
||||
kwargs = {"torch_dtype": torch.float32}
|
||||
if "llama2" in model_name:
|
||||
kwargs["use_auth_token"] = hf_auth_token
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
get_model_impl,
|
||||
)
|
||||
|
||||
print("Second Vicuna applying weight quantization..")
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
get_model_impl(self.model).layers,
|
||||
dtype=torch.float16,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
quantize_weight_zero_point=False,
|
||||
)
|
||||
print("Weight quantization applied.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
i0,
|
||||
i1,
|
||||
i2,
|
||||
i3,
|
||||
i4,
|
||||
i5,
|
||||
i6,
|
||||
i7,
|
||||
i8,
|
||||
i9,
|
||||
i10,
|
||||
i11,
|
||||
i12,
|
||||
i13,
|
||||
i14,
|
||||
i15,
|
||||
i16,
|
||||
i17,
|
||||
i18,
|
||||
i19,
|
||||
i20,
|
||||
i21,
|
||||
i22,
|
||||
i23,
|
||||
i24,
|
||||
i25,
|
||||
i26,
|
||||
i27,
|
||||
i28,
|
||||
i29,
|
||||
i30,
|
||||
i31,
|
||||
i32,
|
||||
i33,
|
||||
i34,
|
||||
i35,
|
||||
i36,
|
||||
i37,
|
||||
i38,
|
||||
i39,
|
||||
i40,
|
||||
i41,
|
||||
i42,
|
||||
i43,
|
||||
i44,
|
||||
i45,
|
||||
i46,
|
||||
i47,
|
||||
i48,
|
||||
i49,
|
||||
i50,
|
||||
i51,
|
||||
i52,
|
||||
i53,
|
||||
i54,
|
||||
i55,
|
||||
i56,
|
||||
i57,
|
||||
i58,
|
||||
i59,
|
||||
i60,
|
||||
i61,
|
||||
i62,
|
||||
i63,
|
||||
i64,
|
||||
i65,
|
||||
i66,
|
||||
i67,
|
||||
i68,
|
||||
i69,
|
||||
i70,
|
||||
i71,
|
||||
i72,
|
||||
i73,
|
||||
i74,
|
||||
i75,
|
||||
i76,
|
||||
i77,
|
||||
i78,
|
||||
i79,
|
||||
i80,
|
||||
i81,
|
||||
i82,
|
||||
i83,
|
||||
i84,
|
||||
i85,
|
||||
i86,
|
||||
i87,
|
||||
i88,
|
||||
i89,
|
||||
i90,
|
||||
i91,
|
||||
i92,
|
||||
i93,
|
||||
i94,
|
||||
i95,
|
||||
i96,
|
||||
i97,
|
||||
i98,
|
||||
i99,
|
||||
i100,
|
||||
i101,
|
||||
i102,
|
||||
i103,
|
||||
i104,
|
||||
i105,
|
||||
i106,
|
||||
i107,
|
||||
i108,
|
||||
i109,
|
||||
i110,
|
||||
i111,
|
||||
i112,
|
||||
i113,
|
||||
i114,
|
||||
i115,
|
||||
i116,
|
||||
i117,
|
||||
i118,
|
||||
i119,
|
||||
i120,
|
||||
i121,
|
||||
i122,
|
||||
i123,
|
||||
i124,
|
||||
i125,
|
||||
i126,
|
||||
i127,
|
||||
i128,
|
||||
i129,
|
||||
i130,
|
||||
i131,
|
||||
i132,
|
||||
i133,
|
||||
i134,
|
||||
i135,
|
||||
i136,
|
||||
i137,
|
||||
i138,
|
||||
i139,
|
||||
i140,
|
||||
i141,
|
||||
i142,
|
||||
i143,
|
||||
i144,
|
||||
i145,
|
||||
i146,
|
||||
i147,
|
||||
i148,
|
||||
i149,
|
||||
i150,
|
||||
i151,
|
||||
i152,
|
||||
i153,
|
||||
i154,
|
||||
i155,
|
||||
i156,
|
||||
i157,
|
||||
i158,
|
||||
i159,
|
||||
i160,
|
||||
):
|
||||
token = i0
|
||||
past_key_values = (
|
||||
(i1, i2),
|
||||
(
|
||||
i3,
|
||||
i4,
|
||||
),
|
||||
(
|
||||
i5,
|
||||
i6,
|
||||
),
|
||||
(
|
||||
i7,
|
||||
i8,
|
||||
),
|
||||
(
|
||||
i9,
|
||||
i10,
|
||||
),
|
||||
(
|
||||
i11,
|
||||
i12,
|
||||
),
|
||||
(
|
||||
i13,
|
||||
i14,
|
||||
),
|
||||
(
|
||||
i15,
|
||||
i16,
|
||||
),
|
||||
(
|
||||
i17,
|
||||
i18,
|
||||
),
|
||||
(
|
||||
i19,
|
||||
i20,
|
||||
),
|
||||
(
|
||||
i21,
|
||||
i22,
|
||||
),
|
||||
(
|
||||
i23,
|
||||
i24,
|
||||
),
|
||||
(
|
||||
i25,
|
||||
i26,
|
||||
),
|
||||
(
|
||||
i27,
|
||||
i28,
|
||||
),
|
||||
(
|
||||
i29,
|
||||
i30,
|
||||
),
|
||||
(
|
||||
i31,
|
||||
i32,
|
||||
),
|
||||
(
|
||||
i33,
|
||||
i34,
|
||||
),
|
||||
(
|
||||
i35,
|
||||
i36,
|
||||
),
|
||||
(
|
||||
i37,
|
||||
i38,
|
||||
),
|
||||
(
|
||||
i39,
|
||||
i40,
|
||||
),
|
||||
(
|
||||
i41,
|
||||
i42,
|
||||
),
|
||||
(
|
||||
i43,
|
||||
i44,
|
||||
),
|
||||
(
|
||||
i45,
|
||||
i46,
|
||||
),
|
||||
(
|
||||
i47,
|
||||
i48,
|
||||
),
|
||||
(
|
||||
i49,
|
||||
i50,
|
||||
),
|
||||
(
|
||||
i51,
|
||||
i52,
|
||||
),
|
||||
(
|
||||
i53,
|
||||
i54,
|
||||
),
|
||||
(
|
||||
i55,
|
||||
i56,
|
||||
),
|
||||
(
|
||||
i57,
|
||||
i58,
|
||||
),
|
||||
(
|
||||
i59,
|
||||
i60,
|
||||
),
|
||||
(
|
||||
i61,
|
||||
i62,
|
||||
),
|
||||
(
|
||||
i63,
|
||||
i64,
|
||||
),
|
||||
(
|
||||
i65,
|
||||
i66,
|
||||
),
|
||||
(
|
||||
i67,
|
||||
i68,
|
||||
),
|
||||
(
|
||||
i69,
|
||||
i70,
|
||||
),
|
||||
(
|
||||
i71,
|
||||
i72,
|
||||
),
|
||||
(
|
||||
i73,
|
||||
i74,
|
||||
),
|
||||
(
|
||||
i75,
|
||||
i76,
|
||||
),
|
||||
(
|
||||
i77,
|
||||
i78,
|
||||
),
|
||||
(
|
||||
i79,
|
||||
i80,
|
||||
),
|
||||
(
|
||||
i81,
|
||||
i82,
|
||||
),
|
||||
(
|
||||
i83,
|
||||
i84,
|
||||
),
|
||||
(
|
||||
i85,
|
||||
i86,
|
||||
),
|
||||
(
|
||||
i87,
|
||||
i88,
|
||||
),
|
||||
(
|
||||
i89,
|
||||
i90,
|
||||
),
|
||||
(
|
||||
i91,
|
||||
i92,
|
||||
),
|
||||
(
|
||||
i93,
|
||||
i94,
|
||||
),
|
||||
(
|
||||
i95,
|
||||
i96,
|
||||
),
|
||||
(
|
||||
i97,
|
||||
i98,
|
||||
),
|
||||
(
|
||||
i99,
|
||||
i100,
|
||||
),
|
||||
(
|
||||
i101,
|
||||
i102,
|
||||
),
|
||||
(
|
||||
i103,
|
||||
i104,
|
||||
),
|
||||
(
|
||||
i105,
|
||||
i106,
|
||||
),
|
||||
(
|
||||
i107,
|
||||
i108,
|
||||
),
|
||||
(
|
||||
i109,
|
||||
i110,
|
||||
),
|
||||
(
|
||||
i111,
|
||||
i112,
|
||||
),
|
||||
(
|
||||
i113,
|
||||
i114,
|
||||
),
|
||||
(
|
||||
i115,
|
||||
i116,
|
||||
),
|
||||
(
|
||||
i117,
|
||||
i118,
|
||||
),
|
||||
(
|
||||
i119,
|
||||
i120,
|
||||
),
|
||||
(
|
||||
i121,
|
||||
i122,
|
||||
),
|
||||
(
|
||||
i123,
|
||||
i124,
|
||||
),
|
||||
(
|
||||
i125,
|
||||
i126,
|
||||
),
|
||||
(
|
||||
i127,
|
||||
i128,
|
||||
),
|
||||
(
|
||||
i129,
|
||||
i130,
|
||||
),
|
||||
(
|
||||
i131,
|
||||
i132,
|
||||
),
|
||||
(
|
||||
i133,
|
||||
i134,
|
||||
),
|
||||
(
|
||||
i135,
|
||||
i136,
|
||||
),
|
||||
(
|
||||
i137,
|
||||
i138,
|
||||
),
|
||||
(
|
||||
i139,
|
||||
i140,
|
||||
),
|
||||
(
|
||||
i141,
|
||||
i142,
|
||||
),
|
||||
(
|
||||
i143,
|
||||
i144,
|
||||
),
|
||||
(
|
||||
i145,
|
||||
i146,
|
||||
),
|
||||
(
|
||||
i147,
|
||||
i148,
|
||||
),
|
||||
(
|
||||
i149,
|
||||
i150,
|
||||
),
|
||||
(
|
||||
i151,
|
||||
i152,
|
||||
),
|
||||
(
|
||||
i153,
|
||||
i154,
|
||||
),
|
||||
(
|
||||
i155,
|
||||
i156,
|
||||
),
|
||||
(
|
||||
i157,
|
||||
i158,
|
||||
),
|
||||
(
|
||||
i159,
|
||||
i160,
|
||||
),
|
||||
)
|
||||
op = self.model(
|
||||
input_ids=token, use_cache=True, past_key_values=past_key_values
|
||||
)
|
||||
return_vals = []
|
||||
return_vals.append(op.logits)
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
return_vals.append(item[1])
|
||||
return tuple(return_vals)
|
||||
|
||||
|
||||
class CombinedModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -3,7 +3,10 @@ from abc import ABC, abstractmethod
|
||||
|
||||
class SharkLLMBase(ABC):
|
||||
def __init__(
|
||||
self, model_name, hf_model_path=None, max_num_tokens=512
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path=None,
|
||||
max_num_tokens=512,
|
||||
) -> None:
|
||||
self.model_name = model_name
|
||||
self.hf_model_path = hf_model_path
|
||||
|
||||
@@ -28,7 +28,9 @@ parser = argparse.ArgumentParser(
|
||||
description="runs a falcon model",
|
||||
)
|
||||
|
||||
parser.add_argument("--falcon_variant_to_use", default="7b", help="7b, 40b")
|
||||
parser.add_argument(
|
||||
"--falcon_variant_to_use", default="7b", help="7b, 40b, 180b"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision", "-p", default="fp16", help="fp32, fp16, int8, int4"
|
||||
)
|
||||
@@ -49,7 +51,7 @@ parser.add_argument(
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_mlir_from_shark_tank",
|
||||
default=False,
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="download precompile mlir from shark tank",
|
||||
)
|
||||
@@ -59,32 +61,52 @@ parser.add_argument(
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Run model in cli mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf_auth_token",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specify your own huggingface authentication token for falcon-180B model.",
|
||||
)
|
||||
|
||||
|
||||
class Falcon(SharkLLMBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path,
|
||||
hf_model_path="tiiuae/falcon-7b-instruct",
|
||||
hf_auth_token: str = None,
|
||||
max_num_tokens=150,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
falcon_mlir_path=None,
|
||||
falcon_vmfb_path=None,
|
||||
debug=False,
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
print("hf_model_path: ", self.hf_model_path)
|
||||
|
||||
if "180b" in self.model_name and hf_auth_token == None:
|
||||
raise ValueError(
|
||||
""" HF auth token required for falcon-180b. Pass it using
|
||||
--hf_auth_token flag. You can ask for the access to the model
|
||||
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
|
||||
)
|
||||
self.hf_auth_token = hf_auth_token
|
||||
self.max_padding_length = 100
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.falcon_vmfb_path = falcon_vmfb_path
|
||||
self.falcon_mlir_path = falcon_mlir_path
|
||||
self.debug = debug
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.shark_model = self.compile()
|
||||
self.src_model = self.get_src_model()
|
||||
self.shark_model = self.compile()
|
||||
|
||||
def get_tokenizer(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path, trust_remote_code=True
|
||||
self.hf_model_path,
|
||||
trust_remote_code=True,
|
||||
token=self.hf_auth_token,
|
||||
)
|
||||
tokenizer.padding_side = "left"
|
||||
tokenizer.pad_token_id = 11
|
||||
@@ -92,13 +114,18 @@ class Falcon(SharkLLMBase):
|
||||
|
||||
def get_src_model(self):
|
||||
print("Loading src model: ", self.model_name)
|
||||
kwargs = {"torch_dtype": torch.float, "trust_remote_code": True}
|
||||
kwargs = {
|
||||
"torch_dtype": torch.float,
|
||||
"trust_remote_code": True,
|
||||
"token": self.hf_auth_token,
|
||||
"device_map": "cpu" if args.device == "cpu" else "cuda:0",
|
||||
}
|
||||
falcon_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
)
|
||||
return falcon_model
|
||||
|
||||
def compile_falcon(self):
|
||||
def compile(self):
|
||||
if args.use_precompiled_model:
|
||||
if not self.falcon_vmfb_path.exists():
|
||||
# Downloading VMFB from shark_tank
|
||||
@@ -120,37 +147,39 @@ class Falcon(SharkLLMBase):
|
||||
if vmfb is not None:
|
||||
return vmfb
|
||||
|
||||
print(
|
||||
f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}. Trying to work with"
|
||||
f"[DEBUG] mlir path { self.falcon_mlir_path} {'exists' if self.falcon_mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
|
||||
if self.falcon_mlir_path.exists():
|
||||
print(f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}")
|
||||
with open(self.falcon_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
mlir_generated = False
|
||||
# Downloading MLIR from shark_tank
|
||||
download_public_file(
|
||||
"gs://shark_tank/falcon/"
|
||||
+ "falcon_"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "_"
|
||||
+ self.precision
|
||||
+ ".mlir",
|
||||
self.falcon_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
print(
|
||||
f"[DEBUG] mlir not found at {self.falcon_mlir_path.absolute()}"
|
||||
)
|
||||
if self.falcon_mlir_path.exists():
|
||||
with open(self.falcon_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
mlir_generated = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"MLIR not found at {self.falcon_mlir_path.absolute()}"
|
||||
" after downloading! Please check path and try again"
|
||||
if args.load_mlir_from_shark_tank:
|
||||
# Downloading MLIR from shark_tank
|
||||
print(f"[DEBUG] Trying to download mlir from shark_tank")
|
||||
download_public_file(
|
||||
"gs://shark_tank/falcon/"
|
||||
+ "falcon_"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "_"
|
||||
+ self.precision
|
||||
+ ".mlir",
|
||||
self.falcon_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if self.falcon_mlir_path.exists():
|
||||
print(
|
||||
f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}"
|
||||
)
|
||||
with open(self.falcon_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
mlir_generated = True
|
||||
|
||||
if not mlir_generated:
|
||||
print(f"[DEBUG] generating MLIR locally")
|
||||
compilation_input_ids = torch.randint(
|
||||
low=1, high=10000, size=(1, 100)
|
||||
)
|
||||
@@ -170,6 +199,7 @@ class Falcon(SharkLLMBase):
|
||||
is_f16=self.precision == "fp16",
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
is_gptq=self.precision == "int4",
|
||||
)
|
||||
del model
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
@@ -189,10 +219,9 @@ class Falcon(SharkLLMBase):
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
del module
|
||||
|
||||
print(f"[DEBUG] writing mlir to file")
|
||||
with open(f"{self.model_name}.mlir", "wb") as f_:
|
||||
with redirect_stdout(f_):
|
||||
print(module.operation.get_asm())
|
||||
f_ = open(self.falcon_mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
print("Saved falcon mlir at ", str(self.falcon_mlir_path))
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
@@ -202,22 +231,17 @@ class Falcon(SharkLLMBase):
|
||||
self.falcon_vmfb_path.parent.absolute(),
|
||||
self.falcon_vmfb_path.stem,
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
"--iree-spirv-index-bits=64",
|
||||
],
|
||||
debug=self.debug,
|
||||
)
|
||||
print("Saved falcon vmfb at ", str(path))
|
||||
shark_module.load_module(path)
|
||||
|
||||
return shark_module
|
||||
|
||||
def compile(self):
|
||||
falcon_shark_model = self.compile_falcon()
|
||||
return falcon_shark_model
|
||||
|
||||
def generate(self, prompt):
|
||||
model_inputs = self.tokenizer(
|
||||
prompt,
|
||||
@@ -466,11 +490,26 @@ if __name__ == "__main__":
|
||||
else Path(args.falcon_vmfb_path)
|
||||
)
|
||||
|
||||
if args.precision == "int4":
|
||||
if args.falcon_variant_to_use == "180b":
|
||||
hf_model_path_value = "TheBloke/Falcon-180B-Chat-GPTQ"
|
||||
else:
|
||||
hf_model_path_value = (
|
||||
"TheBloke/falcon-"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "-instruct-GPTQ"
|
||||
)
|
||||
else:
|
||||
if args.falcon_variant_to_use == "180b":
|
||||
hf_model_path_value = "tiiuae/falcon-180B-chat"
|
||||
else:
|
||||
hf_model_path_value = (
|
||||
"tiiuae/falcon-" + args.falcon_variant_to_use + "-instruct"
|
||||
)
|
||||
|
||||
falcon = Falcon(
|
||||
"falcon_" + args.falcon_variant_to_use,
|
||||
hf_model_path="tiiuae/falcon-"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "-instruct",
|
||||
model_name="falcon_" + args.falcon_variant_to_use,
|
||||
hf_model_path=hf_model_path_value,
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
falcon_mlir_path=falcon_mlir_path,
|
||||
@@ -497,7 +536,11 @@ if __name__ == "__main__":
|
||||
prompt = input("Please enter the prompt text: ")
|
||||
print("\nPrompt Text: ", prompt)
|
||||
|
||||
res_str = falcon.generate(prompt)
|
||||
prompt_template = f"""A helpful assistant who helps the user with any questions asked.
|
||||
User: {prompt}
|
||||
Assistant:"""
|
||||
|
||||
res_str = falcon.generate(prompt_template)
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
print(
|
||||
|
||||
@@ -178,7 +178,7 @@ def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
|
||||
|
||||
|
||||
def compile_module(
|
||||
shark_module, extended_model_name, generate_vmfb, extra_args=[]
|
||||
shark_module, extended_model_name, generate_vmfb, extra_args=[], debug=False,
|
||||
):
|
||||
if generate_vmfb:
|
||||
vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
|
||||
@@ -190,7 +190,7 @@ def compile_module(
|
||||
"No vmfb found. Compiling and saving to {}".format(vmfb_path)
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
os.getcwd(), extended_model_name, extra_args
|
||||
os.getcwd(), extended_model_name, extra_args, debug=debug
|
||||
)
|
||||
shark_module.load_module(path, extra_args=extra_args)
|
||||
else:
|
||||
@@ -199,7 +199,7 @@ def compile_module(
|
||||
|
||||
|
||||
def compile_int_precision(
|
||||
model, inputs, precision, device, generate_vmfb, extended_model_name
|
||||
model, inputs, precision, device, generate_vmfb, extended_model_name, debug=False
|
||||
):
|
||||
torchscript_module = import_with_fx(
|
||||
model,
|
||||
@@ -219,7 +219,7 @@ def compile_int_precision(
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
mlir_module,
|
||||
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
from contextlib import redirect_stdout
|
||||
@@ -251,6 +251,7 @@ def compile_int_precision(
|
||||
extended_model_name=extended_model_name,
|
||||
generate_vmfb=generate_vmfb,
|
||||
extra_args=extra_args,
|
||||
debug=debug,
|
||||
),
|
||||
bytecode,
|
||||
)
|
||||
@@ -294,6 +295,7 @@ def shark_compile_through_fx_int(
|
||||
device,
|
||||
generate_or_load_vmfb,
|
||||
extended_model_name,
|
||||
debug,
|
||||
)
|
||||
extra_args = [
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
|
||||
@@ -32,11 +32,13 @@ class SharkStableLM(SharkLLMBase):
|
||||
max_num_tokens=512,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
debug="False",
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_sequence_len = 256
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.debug = debug
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.shark_model = self.compile()
|
||||
|
||||
@@ -111,7 +113,7 @@ class SharkStableLM(SharkLLMBase):
|
||||
shark_module.compile()
|
||||
|
||||
path = shark_module.save_module(
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem, debug=self.debug
|
||||
)
|
||||
print("Saved vmfb at ", str(path))
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from shark.shark_downloader import download_public_file
|
||||
|
||||
# expects a Path / str as arg
|
||||
# returns None if path not found or SharkInference module
|
||||
def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
|
||||
def get_vmfb_from_path(vmfb_path, device, mlir_dialect, device_id=None):
|
||||
if not isinstance(vmfb_path, Path):
|
||||
vmfb_path = Path(vmfb_path)
|
||||
|
||||
@@ -20,7 +20,7 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
|
||||
print("Loading vmfb from: ", vmfb_path)
|
||||
print("Device from get_vmfb_from_path - ", device)
|
||||
shark_module = SharkInference(
|
||||
None, device=device, mlir_dialect=mlir_dialect
|
||||
None, device=device, mlir_dialect=mlir_dialect, device_idx=device_id
|
||||
)
|
||||
shark_module.load_module(vmfb_path)
|
||||
print("Successfully loaded vmfb")
|
||||
@@ -28,7 +28,13 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
|
||||
|
||||
|
||||
def get_vmfb_from_config(
|
||||
shark_container, model, precision, device, vmfb_path, padding=None
|
||||
shark_container,
|
||||
model,
|
||||
precision,
|
||||
device,
|
||||
vmfb_path,
|
||||
padding=None,
|
||||
device_id=None,
|
||||
):
|
||||
vmfb_url = (
|
||||
f"gs://shark_tank/{shark_container}/{model}_{precision}_{device}"
|
||||
@@ -37,4 +43,6 @@ def get_vmfb_from_config(
|
||||
vmfb_url = vmfb_url + f"_{padding}"
|
||||
vmfb_url = vmfb_url + ".vmfb"
|
||||
download_public_file(vmfb_url, vmfb_path.absolute(), single_file=True)
|
||||
return get_vmfb_from_path(vmfb_path, device, "tm_tensor")
|
||||
return get_vmfb_from_path(
|
||||
vmfb_path, device, "tm_tensor", device_id=device_id
|
||||
)
|
||||
|
||||
@@ -15,8 +15,8 @@ pathex = [
|
||||
|
||||
# datafiles for pyinstaller
|
||||
datas = []
|
||||
datas += collect_data_files("torch")
|
||||
datas += copy_metadata("torch")
|
||||
datas += copy_metadata("tokenizers")
|
||||
datas += copy_metadata("tqdm")
|
||||
datas += copy_metadata("regex")
|
||||
datas += copy_metadata("requests")
|
||||
@@ -31,18 +31,17 @@ datas += copy_metadata("Pillow")
|
||||
datas += copy_metadata("sentencepiece")
|
||||
datas += copy_metadata("pyyaml")
|
||||
datas += copy_metadata("huggingface-hub")
|
||||
datas += collect_data_files("torch")
|
||||
datas += collect_data_files("tokenizers")
|
||||
datas += collect_data_files("tiktoken")
|
||||
datas += collect_data_files("accelerate")
|
||||
datas += collect_data_files("diffusers")
|
||||
datas += collect_data_files("transformers")
|
||||
datas += collect_data_files("pytorch_lightning")
|
||||
datas += collect_data_files("opencv_python")
|
||||
datas += collect_data_files("skimage")
|
||||
datas += collect_data_files("gradio")
|
||||
datas += collect_data_files("gradio_client")
|
||||
datas += collect_data_files("iree")
|
||||
datas += collect_data_files("google_cloud_storage")
|
||||
datas += collect_data_files("shark", include_py_files=True)
|
||||
datas += collect_data_files("timm", include_py_files=True)
|
||||
datas += collect_data_files("tqdm")
|
||||
@@ -53,6 +52,7 @@ datas += collect_data_files("jsonschema")
|
||||
datas += collect_data_files("jsonschema_specifications")
|
||||
datas += collect_data_files("cpuinfo")
|
||||
datas += collect_data_files("langchain")
|
||||
datas += collect_data_files("cv2")
|
||||
datas += [
|
||||
("src/utils/resources/prompts.json", "resources"),
|
||||
("src/utils/resources/model_db.json", "resources"),
|
||||
@@ -75,7 +75,13 @@ datas += [
|
||||
hiddenimports = ["shark", "shark.shark_inference", "apps"]
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
hiddenimports += [
|
||||
x for x in collect_submodules("transformers") if "tests" not in x
|
||||
x for x in collect_submodules("diffusers") if "tests" not in x
|
||||
]
|
||||
blacklist = ["tests", "convert"]
|
||||
hiddenimports += [
|
||||
x
|
||||
for x in collect_submodules("transformers")
|
||||
if not any(kw in x for kw in blacklist)
|
||||
]
|
||||
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
|
||||
hiddenimports += ["iree._runtime", "iree._runtime_libs"]
|
||||
hiddenimports += ["iree._runtime", "iree.compiler._mlir_libs._mlir.ir"]
|
||||
|
||||
@@ -273,6 +273,7 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
use_stencil,
|
||||
resample_type,
|
||||
):
|
||||
# Control Embedding check & conversion
|
||||
# TODO: 1. Change `num_images_per_prompt`.
|
||||
|
||||
@@ -158,9 +158,9 @@ def load_lower_configs(base_model_id=None):
|
||||
f"{spec}.json"
|
||||
)
|
||||
|
||||
full_gs_url = config_bucket + config_name
|
||||
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
|
||||
print("Loading lowering config file from ", lowering_config_dir)
|
||||
full_gs_url = config_bucket + config_name
|
||||
download_public_file(full_gs_url, lowering_config_dir, True)
|
||||
return lowering_config_dir
|
||||
|
||||
@@ -281,13 +281,9 @@ def sd_model_annotation(mlir_model, model_name, base_model_id=None):
|
||||
if "rdna2" not in args.iree_vulkan_target_triple.split("-")[0]:
|
||||
use_winograd = True
|
||||
winograd_config_dir = load_winograd_configs()
|
||||
winograd_model = annotate_with_winograd(
|
||||
tuned_model = annotate_with_winograd(
|
||||
mlir_model, winograd_config_dir, model_name
|
||||
)
|
||||
lowering_config_dir = load_lower_configs(base_model_id)
|
||||
tuned_model = annotate_with_lower_configs(
|
||||
winograd_model, lowering_config_dir, model_name, use_winograd
|
||||
)
|
||||
else:
|
||||
tuned_model = mlir_model
|
||||
else:
|
||||
|
||||
@@ -458,6 +458,14 @@ p.add_argument(
|
||||
help="Specify your own huggingface authentication tokens for models like Llama2.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--device_allocator_heap_key",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify heap key for device caching allocator."
|
||||
"Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count"
|
||||
"Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
|
||||
)
|
||||
##############################################################################
|
||||
# IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
@@ -570,6 +578,14 @@ p.add_argument(
|
||||
"in shark importer. Does nothing if import_mlir is false (the default).",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--compile_debug",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag to toggle debug assert/verify flags for imported IR in the"
|
||||
"iree-compiler. Default to false.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--iree_constant_folding",
|
||||
default=True,
|
||||
@@ -625,6 +641,13 @@ p.add_argument(
|
||||
help="Flag for enabling rest API.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--debug",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for enabling debugging log in WebUI.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_gallery",
|
||||
default=True,
|
||||
|
||||
@@ -25,7 +25,7 @@ from shark.iree_utils.vulkan_utils import (
|
||||
get_iree_vulkan_runtime_flags,
|
||||
)
|
||||
from shark.iree_utils.metal_utils import get_metal_target_triple
|
||||
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
|
||||
from shark.iree_utils.gpu_utils import get_cuda_sm_cc, get_iree_rocm_args
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
from apps.stable_diffusion.src.utils.resources import opt_flags
|
||||
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
|
||||
@@ -78,7 +78,7 @@ def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
)
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
os.getcwd(), model_name, extra_args
|
||||
os.getcwd(), model_name, extra_args, debug=args.compile_debug
|
||||
)
|
||||
shark_module.load_module(path, extra_args=extra_args)
|
||||
else:
|
||||
@@ -184,12 +184,18 @@ def compile_through_fx(
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
|
||||
# TODO: This function should be device-agnostic and piped properly
|
||||
# to general runtime driver init.
|
||||
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
|
||||
if args.enable_rgp:
|
||||
vulkan_runtime_flags += [
|
||||
f"--enable_rgp=true",
|
||||
f"--vulkan_debug_utils=true",
|
||||
]
|
||||
if args.device_allocator_heap_key:
|
||||
vulkan_runtime_flags += [
|
||||
f"--device_allocator=caching:device_local={args.device_allocator_heap_key}",
|
||||
]
|
||||
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
|
||||
@@ -470,12 +476,25 @@ def get_available_devices():
|
||||
set_iree_runtime_flags()
|
||||
|
||||
available_devices = []
|
||||
vulkan_devices = get_devices_by_name("vulkan")
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
get_all_vulkan_devices,
|
||||
)
|
||||
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
vulkan_devices = []
|
||||
id = 0
|
||||
for device in vulkaninfo_list:
|
||||
vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
|
||||
id += 1
|
||||
if id != 0:
|
||||
print(f"vulkan devices are available.")
|
||||
available_devices.extend(vulkan_devices)
|
||||
metal_devices = get_devices_by_name("metal")
|
||||
available_devices.extend(metal_devices)
|
||||
cuda_devices = get_devices_by_name("cuda")
|
||||
available_devices.extend(cuda_devices)
|
||||
rocm_devices = get_devices_by_name("rocm")
|
||||
available_devices.extend(rocm_devices)
|
||||
cpu_device = get_devices_by_name("cpu-sync")
|
||||
available_devices.extend(cpu_device)
|
||||
cpu_device = get_devices_by_name("cpu-task")
|
||||
@@ -499,7 +518,10 @@ def get_opt_flags(model, precision="fp16"):
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
|
||||
if "rocm" in args.device:
|
||||
rocm_args = get_iree_rocm_args()
|
||||
iree_flags.extend(rocm_args)
|
||||
print(iree_flags)
|
||||
if args.iree_constant_folding == False:
|
||||
iree_flags.append("--iree-opt-const-expr-hoisting=False")
|
||||
iree_flags.append(
|
||||
@@ -572,7 +594,7 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
|
||||
)
|
||||
num_in_channels = 9 if is_inpaint else 4
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path=custom_weights,
|
||||
checkpoint_path_or_dict=custom_weights,
|
||||
extract_ema=extract_ema,
|
||||
from_safetensors=from_safetensors,
|
||||
num_in_channels=num_in_channels,
|
||||
@@ -822,6 +844,8 @@ def clear_all():
|
||||
elif os.name == "unix":
|
||||
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
|
||||
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
|
||||
if args.local_tank_cache != "":
|
||||
shutil.rmtree(args.local_tank_cache)
|
||||
|
||||
|
||||
def get_generated_imgs_path() -> Path:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from multiprocessing import Process, freeze_support
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
|
||||
if sys.platform == "darwin":
|
||||
# import before IREE to avoid torch-MLIR library issues
|
||||
@@ -41,6 +42,8 @@ def launch_app(address):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.debug:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
# required to do multiprocessing in a pyinstaller freeze
|
||||
freeze_support()
|
||||
if args.api or "api" in args.ui.split(","):
|
||||
@@ -153,9 +156,9 @@ if __name__ == "__main__":
|
||||
upscaler_sendto_img2img,
|
||||
upscaler_sendto_inpaint,
|
||||
upscaler_sendto_outpaint,
|
||||
lora_train_web,
|
||||
model_web,
|
||||
model_config_web,
|
||||
# lora_train_web,
|
||||
# model_web,
|
||||
# model_config_web,
|
||||
hf_models,
|
||||
modelmanager_sendto_txt2img,
|
||||
modelmanager_sendto_img2img,
|
||||
@@ -247,16 +250,16 @@ if __name__ == "__main__":
|
||||
upscaler_status,
|
||||
]
|
||||
)
|
||||
with gr.TabItem(label="Model Manager", id=6):
|
||||
model_web.render()
|
||||
with gr.TabItem(label="LoRA Training (Experimental)", id=7):
|
||||
lora_train_web.render()
|
||||
with gr.TabItem(label="Chat Bot (Experimental)", id=8):
|
||||
# with gr.TabItem(label="Model Manager", id=6):
|
||||
# model_web.render()
|
||||
# with gr.TabItem(label="LoRA Training (Experimental)", id=7):
|
||||
# lora_train_web.render()
|
||||
with gr.TabItem(label="Chat Bot", id=8):
|
||||
stablelm_chat.render()
|
||||
with gr.TabItem(
|
||||
label="Generate Sharding Config (Experimental)", id=9
|
||||
):
|
||||
model_config_web.render()
|
||||
# with gr.TabItem(
|
||||
# label="Generate Sharding Config (Experimental)", id=9
|
||||
# ):
|
||||
# model_config_web.render()
|
||||
with gr.TabItem(label="MultiModal (Experimental)", id=10):
|
||||
minigpt4_web.render()
|
||||
# with gr.TabItem(label="DocuChat Upload", id=11):
|
||||
|
||||
@@ -109,7 +109,7 @@ with gr.Blocks() as minigpt4_web:
|
||||
gr.Markdown(description)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=0.5):
|
||||
with gr.Column():
|
||||
image = gr.Image(type="pil")
|
||||
upload_button = gr.Button(
|
||||
value="Upload & Start Chat",
|
||||
|
||||
@@ -8,7 +8,7 @@ from transformers import (
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
from datetime import datetime as dt
|
||||
import json
|
||||
import time
|
||||
import sys
|
||||
|
||||
|
||||
def user(message, history):
|
||||
@@ -26,11 +26,7 @@ model_map = {
|
||||
"llama2_7b": "meta-llama/Llama-2-7b-chat-hf",
|
||||
"llama2_13b": "meta-llama/Llama-2-13b-chat-hf",
|
||||
"llama2_70b": "meta-llama/Llama-2-70b-chat-hf",
|
||||
"codegen": "Salesforce/codegen25-7b-multi",
|
||||
"vicuna1p3": "lmsys/vicuna-7b-v1.3",
|
||||
"vicuna": "TheBloke/vicuna-7B-1.1-HF",
|
||||
"vicuna4": "TheBloke/vicuna-7B-1.1-HF",
|
||||
"StableLM": "stabilityai/stablelm-tuned-alpha-3b",
|
||||
}
|
||||
|
||||
# NOTE: Each `model_name` should have its own start message
|
||||
@@ -62,61 +58,39 @@ start_message = {
|
||||
"explain why instead of answering something not correct. If you don't know the "
|
||||
"answer to a question, please don't share false information."
|
||||
),
|
||||
"StableLM": (
|
||||
"<|SYSTEM|># StableLM Tuned (Alpha version)"
|
||||
"\n- StableLM is a helpful and harmless open-source AI language model "
|
||||
"developed by StabilityAI."
|
||||
"\n- StableLM is excited to be able to help the user, but will refuse "
|
||||
"to do anything that could be considered harmful to the user."
|
||||
"\n- StableLM is more than just an information source, StableLM is also "
|
||||
"able to write poetry, short stories, and make jokes."
|
||||
"\n- StableLM will refuse to participate in anything that "
|
||||
"could harm a human."
|
||||
),
|
||||
"vicuna": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
),
|
||||
"vicuna4": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
),
|
||||
"vicuna1p3": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
),
|
||||
"codegen": "",
|
||||
}
|
||||
|
||||
|
||||
def create_prompt(model_name, history):
|
||||
system_message = start_message[model_name]
|
||||
|
||||
if model_name in [
|
||||
"StableLM",
|
||||
"vicuna",
|
||||
"vicuna4",
|
||||
"vicuna1p3",
|
||||
"llama2_7b",
|
||||
"llama2_13b",
|
||||
"llama2_70b",
|
||||
]:
|
||||
if "llama2" in model_name:
|
||||
B_INST, E_INST = "[INST]", "[/INST]"
|
||||
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
||||
conversation = "".join(
|
||||
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
|
||||
)
|
||||
msg = f"{B_INST} {B_SYS} {system_message} {E_SYS} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
|
||||
elif model_name in ["vicuna"]:
|
||||
conversation = "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
else:
|
||||
conversation = "".join(
|
||||
["".join([item[0], item[1]]) for item in history]
|
||||
)
|
||||
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
return msg
|
||||
|
||||
|
||||
@@ -160,14 +134,16 @@ def chat(
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
cli=False,
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
global past_key_values
|
||||
global model_vmfb_key
|
||||
|
||||
global vicuna_model
|
||||
|
||||
device_id = None
|
||||
model_name, model_path = list(map(str.strip, model.split("=>")))
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
@@ -176,125 +152,125 @@ def chat(
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device_id = int(device.split("://")[1])
|
||||
device = "vulkan"
|
||||
elif "rocm" in device:
|
||||
device = "rocm"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
|
||||
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{precision}"
|
||||
if model_name in [
|
||||
"vicuna",
|
||||
"vicuna4",
|
||||
"vicuna1p3",
|
||||
"codegen",
|
||||
"llama2_7b",
|
||||
"llama2_13b",
|
||||
"llama2_70b",
|
||||
]:
|
||||
from apps.language_models.scripts.vicuna import ShardedVicuna
|
||||
from apps.language_models.scripts.vicuna import UnshardedVicuna
|
||||
from apps.stable_diffusion.src import args
|
||||
from apps.language_models.scripts.vicuna import ShardedVicuna
|
||||
from apps.language_models.scripts.vicuna import UnshardedVicuna
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
if new_model_vmfb_key != model_vmfb_key:
|
||||
model_vmfb_key = new_model_vmfb_key
|
||||
max_toks = 128 if model_name == "codegen" else 512
|
||||
|
||||
# get iree flags that need to be overridden, from commandline args
|
||||
_extra_args = []
|
||||
# vulkan target triple
|
||||
if args.iree_vulkan_target_triple != "":
|
||||
_extra_args.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
|
||||
if model_name == "vicuna4":
|
||||
vicuna_model = ShardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
compressed=True,
|
||||
extra_args_cmd=_extra_args,
|
||||
)
|
||||
else:
|
||||
# if config_file is None:
|
||||
vicuna_model = UnshardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
hf_auth_token=args.hf_auth_token,
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
extra_args_cmd=_extra_args,
|
||||
)
|
||||
# else:
|
||||
# if config_file is not None:
|
||||
# config_file = open(config_file)
|
||||
# config_json = json.load(config_file)
|
||||
# config_file.close()
|
||||
# else:
|
||||
# config_json = get_default_config()
|
||||
# vicuna_model = ShardedVicuna(
|
||||
# model_name,
|
||||
# device=device,
|
||||
# precision=precision,
|
||||
# config_json=config_json,
|
||||
# )
|
||||
|
||||
prompt = create_prompt(model_name, history)
|
||||
|
||||
partial_text = ""
|
||||
count = 0
|
||||
start_time = time.time()
|
||||
for text, msg in progress.tqdm(
|
||||
vicuna_model.generate(prompt, cli=cli),
|
||||
desc="generating response",
|
||||
):
|
||||
count += 1
|
||||
if "formatted" in msg:
|
||||
history[-1][1] = text
|
||||
end_time = time.time()
|
||||
tokens_per_sec = count / (end_time - start_time)
|
||||
yield history, str(
|
||||
format(tokens_per_sec, ".2f")
|
||||
) + " tokens/sec"
|
||||
else:
|
||||
partial_text += text + " "
|
||||
history[-1][1] = partial_text
|
||||
yield history, ""
|
||||
|
||||
return history, ""
|
||||
|
||||
# else Model is StableLM
|
||||
global sharkModel
|
||||
from apps.language_models.src.pipelines.stablelm_pipeline import (
|
||||
SharkStableLM,
|
||||
)
|
||||
|
||||
if new_model_vmfb_key != model_vmfb_key:
|
||||
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}"
|
||||
if vicuna_model is None or new_model_vmfb_key != model_vmfb_key:
|
||||
model_vmfb_key = new_model_vmfb_key
|
||||
# max_new_tokens=512
|
||||
shark_slm = SharkStableLM(
|
||||
model_name
|
||||
) # pass elements from UI as required
|
||||
max_toks = 128 if model_name == "codegen" else 512
|
||||
|
||||
# get iree flags that need to be overridden, from commandline args
|
||||
_extra_args = []
|
||||
# vulkan target triple
|
||||
vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
get_all_vulkan_devices,
|
||||
get_vulkan_target_triple,
|
||||
)
|
||||
|
||||
if device == "vulkan":
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
if vulkan_target_triple == "":
|
||||
# We already have the device_id extracted via WebUI, so we directly use
|
||||
# that to find the target triple.
|
||||
vulkan_target_triple = get_vulkan_target_triple(
|
||||
vulkaninfo_list[device_id]
|
||||
)
|
||||
_extra_args.append(
|
||||
f"-iree-vulkan-target-triple={vulkan_target_triple}"
|
||||
)
|
||||
if "rdna" in vulkan_target_triple:
|
||||
flags_to_add = [
|
||||
"--iree-spirv-index-bits=64",
|
||||
]
|
||||
_extra_args = _extra_args + flags_to_add
|
||||
|
||||
if device_id is None:
|
||||
id = 0
|
||||
for device in vulkaninfo_list:
|
||||
target_triple = get_vulkan_target_triple(
|
||||
vulkaninfo_list[id]
|
||||
)
|
||||
if target_triple == vulkan_target_triple:
|
||||
device_id = id
|
||||
break
|
||||
id += 1
|
||||
|
||||
assert (
|
||||
device_id
|
||||
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
|
||||
|
||||
print(f"Will use target triple : {vulkan_target_triple}")
|
||||
|
||||
if model_name == "vicuna4":
|
||||
vicuna_model = ShardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
compressed=True,
|
||||
extra_args_cmd=_extra_args,
|
||||
)
|
||||
else:
|
||||
# if config_file is None:
|
||||
vicuna_model = UnshardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
hf_auth_token=args.hf_auth_token,
|
||||
device=device,
|
||||
vulkan_target_triple=vulkan_target_triple,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
download_vmfb=download_vmfb,
|
||||
load_mlir_from_shark_tank=True,
|
||||
extra_args_cmd=_extra_args,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
if vicuna_model is None:
|
||||
sys.exit("Unable to instantiate the model object, exiting.")
|
||||
|
||||
# Construct the input message string for the model by concatenating the
|
||||
# current system message and conversation history
|
||||
if len(curr_system_message.split()) > 160:
|
||||
print("clearing context")
|
||||
prompt = create_prompt(model_name, history)
|
||||
generate_kwargs = dict(prompt=prompt)
|
||||
|
||||
words_list = shark_slm.generate(**generate_kwargs)
|
||||
|
||||
partial_text = ""
|
||||
for new_text in words_list:
|
||||
partial_text += new_text
|
||||
history[-1][1] = partial_text
|
||||
# Yield an empty string to clean up the message textbox and the updated
|
||||
# conversation history
|
||||
yield history
|
||||
return words_list
|
||||
token_count = 0
|
||||
total_time_ms = 0.001 # In order to avoid divide by zero error
|
||||
prefill_time = 0
|
||||
is_first = True
|
||||
for text, msg, exec_time in progress.tqdm(
|
||||
vicuna_model.generate(prompt, cli=cli),
|
||||
desc="generating response",
|
||||
):
|
||||
if msg is None:
|
||||
if is_first:
|
||||
prefill_time = exec_time
|
||||
is_first = False
|
||||
else:
|
||||
total_time_ms += exec_time
|
||||
token_count += 1
|
||||
partial_text += text + " "
|
||||
history[-1][1] = partial_text
|
||||
yield history, f"Prefill: {prefill_time:.2f}"
|
||||
elif "formatted" in msg:
|
||||
history[-1][1] = text
|
||||
tokens_per_sec = (token_count / total_time_ms) * 1000
|
||||
yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
|
||||
else:
|
||||
sys.exit(
|
||||
"unexpected message from the vicuna generate call, exiting."
|
||||
)
|
||||
|
||||
return history, ""
|
||||
|
||||
|
||||
def llm_chat_api(InputData: dict):
|
||||
@@ -330,6 +306,7 @@ def llm_chat_api(InputData: dict):
|
||||
UnshardedVicuna,
|
||||
)
|
||||
|
||||
device_id = None
|
||||
if vicuna_model == 0:
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
@@ -338,6 +315,7 @@ def llm_chat_api(InputData: dict):
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device_id = int(device.split("://")[1])
|
||||
device = "vulkan"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
@@ -348,6 +326,9 @@ def llm_chat_api(InputData: dict):
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
download_vmfb=True,
|
||||
load_mlir_from_shark_tank=True,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
# TODO: add role dict for different models
|
||||
@@ -410,7 +391,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
)
|
||||
model = gr.Dropdown(
|
||||
label="Select Model",
|
||||
value=model_choices[4],
|
||||
value=model_choices[0],
|
||||
choices=model_choices,
|
||||
)
|
||||
supported_devices = available_devices
|
||||
@@ -418,27 +399,31 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
# show cpu-task device first in list for chatbot
|
||||
supported_devices = supported_devices[-1:] + supported_devices[:-1]
|
||||
supported_devices = [x for x in supported_devices if "sync" not in x]
|
||||
# print(supported_devices)
|
||||
devices = gr.Dropdown(
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=supported_devices[0]
|
||||
if enabled
|
||||
else "Only CUDA Supported for now",
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
# multiselect=True,
|
||||
# multiselect=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="int8",
|
||||
value="int4",
|
||||
choices=[
|
||||
"int4",
|
||||
"int8",
|
||||
"fp16",
|
||||
],
|
||||
visible=True,
|
||||
visible=False,
|
||||
)
|
||||
tokens_time = gr.Textbox(label="Tokens generated per second")
|
||||
download_vmfb = gr.Checkbox(
|
||||
label="Download vmfb from Shark tank if available",
|
||||
value=True,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Row(visible=False):
|
||||
with gr.Group():
|
||||
@@ -470,19 +455,45 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
)
|
||||
|
||||
submit_event = msg.submit(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
fn=user,
|
||||
inputs=[msg, chatbot],
|
||||
outputs=[msg, chatbot],
|
||||
show_progress=False,
|
||||
queue=False,
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model, devices, precision, config_file],
|
||||
inputs=[
|
||||
system_msg,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
show_progress=False,
|
||||
queue=True,
|
||||
)
|
||||
submit_click_event = submit.click(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
fn=user,
|
||||
inputs=[msg, chatbot],
|
||||
outputs=[msg, chatbot],
|
||||
show_progress=False,
|
||||
queue=False,
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model, devices, precision, config_file],
|
||||
inputs=[
|
||||
system_msg,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
show_progress=False,
|
||||
queue=True,
|
||||
)
|
||||
stop.click(
|
||||
|
||||
@@ -1,192 +0,0 @@
|
||||
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cmake_minimum_required(VERSION 3.17)
|
||||
|
||||
project(sharkbackend LANGUAGES C CXX)
|
||||
|
||||
#
|
||||
# Options
|
||||
#
|
||||
|
||||
option(TRITON_ENABLE_GPU "Enable GPU support in backend" ON)
|
||||
option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON)
|
||||
|
||||
set(TRITON_COMMON_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/common repo")
|
||||
set(TRITON_CORE_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/core repo")
|
||||
set(TRITON_BACKEND_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/backend repo")
|
||||
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
set(CMAKE_BUILD_TYPE Release)
|
||||
endif()
|
||||
|
||||
#
|
||||
# Dependencies
|
||||
#
|
||||
# FetchContent requires us to include the transitive closure of all
|
||||
# repos that we depend on so that we can override the tags.
|
||||
#
|
||||
include(FetchContent)
|
||||
|
||||
FetchContent_Declare(
|
||||
repo-common
|
||||
GIT_REPOSITORY https://github.com/triton-inference-server/common.git
|
||||
GIT_TAG ${TRITON_COMMON_REPO_TAG}
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_Declare(
|
||||
repo-core
|
||||
GIT_REPOSITORY https://github.com/triton-inference-server/core.git
|
||||
GIT_TAG ${TRITON_CORE_REPO_TAG}
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_Declare(
|
||||
repo-backend
|
||||
GIT_REPOSITORY https://github.com/triton-inference-server/backend.git
|
||||
GIT_TAG ${TRITON_BACKEND_REPO_TAG}
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_MakeAvailable(repo-common repo-core repo-backend)
|
||||
|
||||
#
|
||||
# The backend must be built into a shared library. Use an ldscript to
|
||||
# hide all symbols except for the TRITONBACKEND API.
|
||||
#
|
||||
configure_file(src/libtriton_dshark.ldscript libtriton_dshark.ldscript COPYONLY)
|
||||
|
||||
add_library(
|
||||
triton-dshark-backend SHARED
|
||||
src/dshark.cc
|
||||
#src/dshark_driver_module.c
|
||||
)
|
||||
|
||||
add_library(
|
||||
SharkBackend::triton-dshark-backend ALIAS triton-dshark-backend
|
||||
)
|
||||
|
||||
target_include_directories(
|
||||
triton-dshark-backend
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/src
|
||||
)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${PROJECT_BINARY_DIR}/lib/cmake/mlir")
|
||||
|
||||
add_subdirectory(thirdparty/srt EXCLUDE_FROM_ALL)
|
||||
|
||||
target_link_libraries(triton-dshark-backend PRIVATE iree_base_base
|
||||
iree_hal_hal
|
||||
iree_hal_cuda_cuda
|
||||
iree_hal_cuda_registration_registration
|
||||
iree_hal_vmvx_registration_registration
|
||||
iree_hal_dylib_registration_registration
|
||||
iree_modules_hal_hal
|
||||
iree_vm_vm
|
||||
iree_vm_bytecode_module
|
||||
iree_hal_local_loaders_system_library_loader
|
||||
iree_hal_local_loaders_vmvx_module_loader
|
||||
)
|
||||
|
||||
target_compile_features(triton-dshark-backend PRIVATE cxx_std_11)
|
||||
|
||||
|
||||
target_link_libraries(
|
||||
triton-dshark-backend
|
||||
PRIVATE
|
||||
triton-core-serverapi # from repo-core
|
||||
triton-core-backendapi # from repo-core
|
||||
triton-core-serverstub # from repo-core
|
||||
triton-backend-utils # from repo-backend
|
||||
)
|
||||
|
||||
if(WIN32)
|
||||
set_target_properties(
|
||||
triton-dshark-backend PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
OUTPUT_NAME triton_dshark
|
||||
)
|
||||
else()
|
||||
set_target_properties(
|
||||
triton-dshark-backend PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
OUTPUT_NAME triton_dshark
|
||||
LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libtriton_dshark.ldscript
|
||||
LINK_FLAGS "-Wl,--version-script libtriton_dshark.ldscript"
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
#
|
||||
# Install
|
||||
#
|
||||
include(GNUInstallDirs)
|
||||
set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/SharkBackend)
|
||||
|
||||
install(
|
||||
TARGETS
|
||||
triton-dshark-backend
|
||||
EXPORT
|
||||
triton-dshark-backend-targets
|
||||
LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
|
||||
RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
|
||||
)
|
||||
|
||||
install(
|
||||
EXPORT
|
||||
triton-dshark-backend-targets
|
||||
FILE
|
||||
SharkBackendTargets.cmake
|
||||
NAMESPACE
|
||||
SharkBackend::
|
||||
DESTINATION
|
||||
${INSTALL_CONFIGDIR}
|
||||
)
|
||||
|
||||
include(CMakePackageConfigHelpers)
|
||||
configure_package_config_file(
|
||||
${CMAKE_CURRENT_LIST_DIR}/cmake/SharkBackendConfig.cmake.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
|
||||
INSTALL_DESTINATION ${INSTALL_CONFIGDIR}
|
||||
)
|
||||
|
||||
install(
|
||||
FILES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
|
||||
DESTINATION ${INSTALL_CONFIGDIR}
|
||||
)
|
||||
|
||||
#
|
||||
# Export from build tree
|
||||
#
|
||||
export(
|
||||
EXPORT triton-dshark-backend-targets
|
||||
FILE ${CMAKE_CURRENT_BINARY_DIR}/SharkBackendTargets.cmake
|
||||
NAMESPACE SharkBackend::
|
||||
)
|
||||
|
||||
export(PACKAGE SharkBackend)
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
# SHARK Triton Backend
|
||||
|
||||
The triton backend for shark.
|
||||
|
||||
# Build
|
||||
|
||||
Install SHARK
|
||||
|
||||
```
|
||||
git clone https://github.com/nod-ai/SHARK.git
|
||||
# skip above step if dshark is already installed
|
||||
cd SHARK/inference
|
||||
```
|
||||
|
||||
install dependancies
|
||||
|
||||
```
|
||||
apt-get install patchelf rapidjson-dev python3-dev
|
||||
git submodule update --init
|
||||
```
|
||||
|
||||
update the submodules of iree
|
||||
|
||||
```
|
||||
cd thirdparty/srt
|
||||
git submodule update --init
|
||||
```
|
||||
|
||||
Next, make the backend and install it
|
||||
|
||||
```
|
||||
cd ../..
|
||||
mkdir build && cd build
|
||||
cmake -DTRITON_ENABLE_GPU=ON \
|
||||
-DIREE_HAL_DRIVER_CUDA=ON \
|
||||
-DIREE_TARGET_BACKEND_CUDA=ON \
|
||||
-DMLIR_ENABLE_CUDA_RUNNER=ON \
|
||||
-DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install \
|
||||
-DTRITON_BACKEND_REPO_TAG=r22.02 \
|
||||
-DTRITON_CORE_REPO_TAG=r22.02 \
|
||||
-DTRITON_COMMON_REPO_TAG=r22.02 ..
|
||||
make install
|
||||
```
|
||||
|
||||
# Incorporating into Triton
|
||||
|
||||
There are much more in depth explenations for the following steps in triton's documentation:
|
||||
https://github.com/triton-inference-server/server/blob/main/docs/compose.md#triton-with-unsupported-and-custom-backends
|
||||
|
||||
There should be a file at /build/install/backends/dshark/libtriton_dshark.so. You will need to copy it into your triton server image.
|
||||
More documentation is in the link above, but to create the docker image, you need to run the compose.py command in the triton-backend server repo
|
||||
|
||||
|
||||
To first build your image, clone the tritonserver repo.
|
||||
|
||||
```
|
||||
git clone https://github.com/triton-inference-server/server.git
|
||||
```
|
||||
|
||||
then run `compose.py` to build a docker compose file
|
||||
```
|
||||
cd server
|
||||
python3 compose.py --repoagent checksum --dry-run
|
||||
```
|
||||
|
||||
Because dshark is a third party backend, you will need to manually modify the `Dockerfile.compose` to include the dshark backend. To do this, in the Dockerfile.compose file produced, copy this line.
|
||||
the dshark backend will be located in the build folder from earlier under `/build/install/backends`
|
||||
|
||||
```
|
||||
COPY /path/to/build/install/backends/dshark /opt/tritonserver/backends/dshark
|
||||
```
|
||||
|
||||
Next run
|
||||
```
|
||||
docker build -t tritonserver_custom -f Dockerfile.compose .
|
||||
docker run -it --gpus=1 --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
|
||||
```
|
||||
|
||||
where `path/to/model_repos` is where you are storing the models you want to run
|
||||
|
||||
if your not using gpus, omit `--gpus=1`
|
||||
|
||||
```
|
||||
docker run -it --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
|
||||
```
|
||||
|
||||
# Setting up a model
|
||||
|
||||
to include a model in your backend, add a directory with your model name to your model repository directory. examples of models can be seen here: https://github.com/triton-inference-server/backend/tree/main/examples/model_repos/minimal_models
|
||||
|
||||
make sure to adjust the input correctly in the config.pbtxt file, and save a vmfb file under 1/model.vmfb
|
||||
|
||||
# CUDA
|
||||
|
||||
if you're having issues with cuda, make sure your correct drivers are installed, and that `nvidia-smi` works, and also make sure that the nvcc compiler is on the path.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
include(CMakeFindDependencyMacro)
|
||||
|
||||
get_filename_component(
|
||||
SHARKBACKEND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH
|
||||
)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH ${SHARKBACKEND_CMAKE_DIR})
|
||||
|
||||
if(NOT TARGET SharkBackend::triton-dshark-backend)
|
||||
include("${SHARKBACKEND_CMAKE_DIR}/SharkBackendTargets.cmake")
|
||||
endif()
|
||||
|
||||
set(SHARKBACKEND_LIBRARIES SharkBackend::triton-dshark-backend)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,30 +0,0 @@
|
||||
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
{
|
||||
global:
|
||||
TRITONBACKEND_*;
|
||||
local: *;
|
||||
};
|
||||
1
inference/thirdparty/shark-runtime
vendored
1
inference/thirdparty/shark-runtime
vendored
Submodule inference/thirdparty/shark-runtime deleted from 7b82d90c72
@@ -7,22 +7,13 @@ import fileinput
|
||||
from pathlib import Path
|
||||
|
||||
# Temporary workaround for transformers/__init__.py.
|
||||
path_to_stdhooks = Path(
|
||||
get_python_lib() + "/_pyinstaller_hooks_contrib/hooks/stdhooks"
|
||||
)
|
||||
path_to_transformers_hook = Path(
|
||||
str(path_to_stdhooks) + "hook-transformers.py"
|
||||
get_python_lib()
|
||||
+ "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py"
|
||||
)
|
||||
if path_to_transformers_hook.is_file():
|
||||
pass
|
||||
else:
|
||||
if not path_to_stdhooks.is_dir():
|
||||
import os
|
||||
|
||||
print(
|
||||
f"Path to pyinstaller stdhooks not found. Please check your pyinstaller packages at {path_to_stdhooks}."
|
||||
)
|
||||
os.mkdir(path_to_stdhooks)
|
||||
with open(path_to_transformers_hook, "w") as f:
|
||||
f.write("module_collection_mode = 'pyz+py'")
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ iree-tools-tf
|
||||
# TensorFlow and JAX.
|
||||
gin-config
|
||||
tf-nightly
|
||||
keras
|
||||
keras-nightly
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
transformers
|
||||
|
||||
@@ -25,7 +25,7 @@ diffusers
|
||||
accelerate
|
||||
scipy
|
||||
ftfy
|
||||
gradio
|
||||
gradio==3.44.3
|
||||
altair
|
||||
omegaconf
|
||||
# 0.3.2 doesn't have binaries for arm64
|
||||
|
||||
@@ -130,14 +130,13 @@ fi
|
||||
|
||||
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/cpu/
|
||||
|
||||
if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
|
||||
if [[ $(uname -s) = 'Linux' && ! -z "${IMPORTER}" ]]; then
|
||||
T_VER=$($PYTHON -m pip show torch | grep Version)
|
||||
TORCH_VERSION=${T_VER:9:17}
|
||||
T_VER_MIN=${T_VER:14:12}
|
||||
TV_VER=$($PYTHON -m pip show torchvision | grep Version)
|
||||
TV_VERSION=${TV_VER:9:18}
|
||||
$PYTHON -m pip uninstall -y torch torchvision
|
||||
$PYTHON -m pip install -U --pre --no-warn-conflicts triton
|
||||
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu118/torch-${TORCH_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu118/torchvision-${TV_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl
|
||||
TV_VER_MAJ=${TV_VER:9:6}
|
||||
$PYTHON -m pip uninstall -y torchvision
|
||||
$PYTHON -m pip install torchvision==${TV_VER_MAJ}${T_VER_MIN} --no-deps -f https://download.pytorch.org/whl/nightly/cpu/torchvision/
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch + cu118."
|
||||
else
|
||||
|
||||
@@ -52,6 +52,8 @@ def iree_device_map(device):
|
||||
)
|
||||
if len(uri_parts) == 1:
|
||||
return iree_driver
|
||||
elif "rocm" in uri_parts:
|
||||
return "rocm"
|
||||
else:
|
||||
return f"{iree_driver}://{uri_parts[1]}"
|
||||
|
||||
@@ -63,7 +65,6 @@ def get_supported_device_list():
|
||||
_IREE_DEVICE_MAP = {
|
||||
"cpu": "local-task",
|
||||
"cpu-task": "local-task",
|
||||
"AMD-AIE": "local-task",
|
||||
"cpu-sync": "local-sync",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
@@ -82,7 +83,6 @@ def iree_target_map(device):
|
||||
_IREE_TARGET_MAP = {
|
||||
"cpu": "llvm-cpu",
|
||||
"cpu-task": "llvm-cpu",
|
||||
"AMD-AIE": "llvm-cpu",
|
||||
"cpu-sync": "llvm-cpu",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
@@ -121,7 +121,10 @@ def check_device_drivers(device):
|
||||
return False
|
||||
elif device == "rocm":
|
||||
try:
|
||||
subprocess.check_output("rocminfo")
|
||||
if sys.platform == "win32":
|
||||
subprocess.check_output("hipinfo")
|
||||
else:
|
||||
subprocess.check_output("rocminfo")
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import iree._runtime.scripts.iree_benchmark_module as benchmark_module
|
||||
from shark.iree_utils._common import run_cmd, iree_device_map
|
||||
from shark.iree_utils.cpu_utils import get_cpu_count
|
||||
import numpy as np
|
||||
@@ -102,15 +101,13 @@ def build_benchmark_args_non_tensor_input(
|
||||
and whether it is training or not.
|
||||
Outputs: string that execute benchmark-module on target model.
|
||||
"""
|
||||
path = benchmark_module.__path__[0]
|
||||
path = os.path.join(os.environ["VIRTUAL_ENV"], "bin")
|
||||
if platform.system() == "Windows":
|
||||
benchmarker_path = os.path.join(
|
||||
path, "..", "..", "iree-benchmark-module.exe"
|
||||
)
|
||||
benchmarker_path = os.path.join(path, "iree-benchmark-module.exe")
|
||||
time_extractor = None
|
||||
else:
|
||||
benchmarker_path = os.path.join(
|
||||
path, "..", "..", "iree-benchmark-module"
|
||||
)
|
||||
benchmarker_path = os.path.join(path, "iree-benchmark-module")
|
||||
time_extractor = "| awk 'END{{print $2 $3}}'"
|
||||
benchmark_cl = [benchmarker_path, f"--module={input_file}"]
|
||||
# TODO: The function named can be passed as one of the args.
|
||||
if function_name:
|
||||
@@ -135,7 +132,7 @@ def run_benchmark_module(benchmark_cl):
|
||||
benchmark_path = benchmark_cl[0]
|
||||
assert os.path.exists(
|
||||
benchmark_path
|
||||
), "Cannot find benchmark_module, Please contact SHARK maintainer on discord."
|
||||
), "Cannot find iree_benchmark_module, Please contact SHARK maintainer on discord."
|
||||
bench_stdout, bench_stderr = run_cmd(" ".join(benchmark_cl))
|
||||
try:
|
||||
regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)")
|
||||
|
||||
@@ -46,7 +46,7 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
if device_uri[0] == "cpu":
|
||||
from shark.iree_utils.cpu_utils import get_iree_cpu_args
|
||||
|
||||
data_tiling_flag = ["--iree-flow-enable-data-tiling"]
|
||||
data_tiling_flag = ["--iree-opt-data-tiling"]
|
||||
u_kernel_flag = ["--iree-llvmcpu-enable-microkernels"]
|
||||
stack_size_flag = ["--iree-llvmcpu-stack-allocation-limit=256000"]
|
||||
|
||||
@@ -84,7 +84,7 @@ def get_iree_frontend_args(frontend):
|
||||
elif frontend in ["tensorflow", "tf", "mhlo", "stablehlo"]:
|
||||
return [
|
||||
"--iree-llvmcpu-target-cpu-features=host",
|
||||
"--iree-flow-demote-i64-to-i32",
|
||||
"--iree-input-demote-i64-to-i32",
|
||||
]
|
||||
else:
|
||||
# Frontend not found.
|
||||
@@ -92,13 +92,27 @@ def get_iree_frontend_args(frontend):
|
||||
|
||||
|
||||
# Common args to be used given any frontend or device.
|
||||
def get_iree_common_args():
|
||||
return [
|
||||
def get_iree_common_args(debug=False):
|
||||
common_args = [
|
||||
"--iree-stream-resource-max-allocation-size=4294967295",
|
||||
"--iree-vm-bytecode-module-strip-source-map=true",
|
||||
"--iree-util-zero-fill-elided-attrs",
|
||||
"--iree-opt-strip-assertions=true",
|
||||
]
|
||||
if debug == True:
|
||||
common_args.extend(
|
||||
[
|
||||
"--iree-opt-strip-assertions=false",
|
||||
"--verify=true",
|
||||
]
|
||||
)
|
||||
else:
|
||||
common_args.extend(
|
||||
[
|
||||
"--iree-opt-strip-assertions=true",
|
||||
"--verify=false",
|
||||
]
|
||||
)
|
||||
return common_args
|
||||
|
||||
|
||||
# Args that are suitable only for certain models or groups of models.
|
||||
@@ -277,14 +291,16 @@ def compile_module_to_flatbuffer(
|
||||
model_config_path,
|
||||
extra_args,
|
||||
model_name="None",
|
||||
debug=False,
|
||||
):
|
||||
# Setup Compile arguments wrt to frontends.
|
||||
input_type = ""
|
||||
args = get_iree_frontend_args(frontend)
|
||||
args += get_iree_device_args(device, extra_args)
|
||||
args += get_iree_common_args()
|
||||
args += get_iree_common_args(debug=debug)
|
||||
args += get_model_specific_args()
|
||||
args += extra_args
|
||||
args += shark_args.additional_compile_args
|
||||
|
||||
if frontend in ["tensorflow", "tf"]:
|
||||
input_type = "auto"
|
||||
@@ -342,7 +358,8 @@ def load_vmfb_using_mmap(
|
||||
flatbuffer_blob_or_path, device: str, device_idx: int = None
|
||||
):
|
||||
print(f"Loading module {flatbuffer_blob_or_path}...")
|
||||
|
||||
if "rocm" in device:
|
||||
device = "rocm"
|
||||
with DetailLogger(timeout=2.5) as dl:
|
||||
# First get configs.
|
||||
if device_idx is not None:
|
||||
@@ -387,6 +404,11 @@ def load_vmfb_using_mmap(
|
||||
dl.log(f"mmap {flatbuffer_blob_or_path}")
|
||||
ctx = ireert.SystemContext(config=config)
|
||||
dl.log(f"ireert.SystemContext created")
|
||||
if "vulkan" in device:
|
||||
# Vulkan pipeline creation consumes significant amount of time.
|
||||
print(
|
||||
"\tCompiling Vulkan shaders. This may take a few minutes."
|
||||
)
|
||||
ctx.add_vm_module(mmaped_vmfb)
|
||||
dl.log(f"module initialized")
|
||||
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
|
||||
@@ -409,10 +431,11 @@ def get_iree_compiled_module(
|
||||
extra_args: list = [],
|
||||
device_idx: int = None,
|
||||
mmap: bool = False,
|
||||
debug: bool = False,
|
||||
):
|
||||
"""Given a module returns the compiled .vmfb and configs"""
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module, device, frontend, model_config_path, extra_args
|
||||
module, device, frontend, model_config_path, extra_args, debug
|
||||
)
|
||||
temp_file_to_unlink = None
|
||||
# TODO: Currently mmap=True control flow path has been switched off for mmap.
|
||||
@@ -468,10 +491,11 @@ def export_iree_module_to_vmfb(
|
||||
model_config_path: str = None,
|
||||
module_name: str = None,
|
||||
extra_args: list = [],
|
||||
debug: bool = False,
|
||||
):
|
||||
# Compiles the module given specs and saves it as .vmfb file.
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module, device, mlir_dialect, model_config_path, extra_args
|
||||
module, device, mlir_dialect, model_config_path, extra_args, debug
|
||||
)
|
||||
if module_name is None:
|
||||
device_name = (
|
||||
@@ -479,9 +503,9 @@ def export_iree_module_to_vmfb(
|
||||
)
|
||||
module_name = f"{mlir_dialect}_{device_name}"
|
||||
filename = os.path.join(directory, module_name + ".vmfb")
|
||||
print(f"Saved vmfb in {filename}.")
|
||||
with open(filename, "wb") as f:
|
||||
f.write(flatbuffer_blob)
|
||||
print(f"Saved vmfb in {filename}.")
|
||||
return filename
|
||||
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
import functools
|
||||
import iree.runtime as ireert
|
||||
import ctypes
|
||||
import sys
|
||||
from shark.parser import shark_args
|
||||
|
||||
|
||||
@@ -42,21 +43,51 @@ def get_iree_gpu_args():
|
||||
@functools.cache
|
||||
def get_iree_rocm_args():
|
||||
ireert.flags.FUNCTION_INPUT_VALIDATION = False
|
||||
# get arch from rocminfo.
|
||||
# get arch from hipinfo.
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
rocm_arch = re.match(
|
||||
r".*(gfx\w+)",
|
||||
subprocess.check_output(
|
||||
"rocminfo | grep -i 'gfx'", shell=True, text=True
|
||||
),
|
||||
).group(1)
|
||||
print(f"Found rocm arch {rocm_arch}...")
|
||||
if sys.platform == "win32":
|
||||
if "HIP_PATH" in os.environ:
|
||||
rocm_path = os.environ["HIP_PATH"]
|
||||
print(f"Found a ROCm installation at {rocm_path}.")
|
||||
else:
|
||||
print("Failed to find ROCM_PATH. Defaulting to C:\\AMD\\ROCM\\5.5")
|
||||
rocm_path = "C:\\AMD\\ROCM\\5.5"
|
||||
else:
|
||||
if "ROCM_PATH" in os.environ:
|
||||
rocm_path = os.environ["ROCM_PATH"]
|
||||
print(f"Found a ROCm installation at {rocm_path}.")
|
||||
else:
|
||||
print("Failed to find ROCM_PATH. Defaulting to /opt/rocm")
|
||||
rocm_path = "/opt/rocm/"
|
||||
|
||||
try:
|
||||
if sys.platform == "win32":
|
||||
rocm_arch = re.search(
|
||||
r"gfx\d{3,}",
|
||||
subprocess.check_output("hipinfo", shell=True, text=True),
|
||||
).group(0)
|
||||
else:
|
||||
rocm_arch = re.match(
|
||||
r".*(gfx\w+)",
|
||||
subprocess.check_output(
|
||||
"rocminfo | grep -i 'gfx'", shell=True, text=True
|
||||
),
|
||||
).group(1)
|
||||
print(f"Found rocm arch {rocm_arch}...")
|
||||
except:
|
||||
print(
|
||||
"Failed to find ROCm architecture from hipinfo / rocminfo. Defaulting to gfx1100."
|
||||
)
|
||||
rocm_arch = "gfx1100"
|
||||
|
||||
bc_path = os.path.join(rocm_path, "amdgcn", "bitcode")
|
||||
return [
|
||||
f"--iree-rocm-target-chip={rocm_arch}",
|
||||
"--iree-rocm-link-bc=true",
|
||||
"--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
|
||||
f"--iree-rocm-bc-dir={bc_path}",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -57,11 +57,8 @@ def get_version(triple):
|
||||
@functools.cache
|
||||
def get_extensions(triple):
|
||||
def make_ext_list(ext_list):
|
||||
res = ""
|
||||
for e in ext_list:
|
||||
res += e + ", "
|
||||
res = f"[{res[:-2]}]"
|
||||
return res
|
||||
res = ", ".join(ext_list)
|
||||
return f"[{res}]"
|
||||
|
||||
arch, product, os = triple
|
||||
if arch == "m1":
|
||||
@@ -119,7 +116,7 @@ def get_extensions(triple):
|
||||
]
|
||||
|
||||
if get_vendor(triple) == "NVIDIA" or arch == "rdna3":
|
||||
ext.append("VK_NV_cooperative_matrix")
|
||||
ext.append("VK_KHR_cooperative_matrix")
|
||||
if get_vendor(triple) == ["NVIDIA", "AMD", "Intel"]:
|
||||
ext.append("VK_KHR_shader_integer_dot_product")
|
||||
return make_ext_list(ext_list=ext)
|
||||
@@ -247,7 +244,7 @@ def get_vulkan_target_capabilities(triple):
|
||||
if arch == "rdna3":
|
||||
# TODO: Get scope value
|
||||
cap["coopmatCases"] = [
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>"
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope<Subgroup>"
|
||||
]
|
||||
|
||||
if product == "rx5700xt":
|
||||
@@ -468,9 +465,9 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
cap["coopmatCases"] = [
|
||||
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, accSat = false, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, accSat = false, scope = #vk.scope<Subgroup>",
|
||||
]
|
||||
|
||||
elif arch == "adreno":
|
||||
@@ -531,7 +528,7 @@ def get_vulkan_target_capabilities(triple):
|
||||
cmc = ""
|
||||
for case in v:
|
||||
cmc += f"#vk.coop_matrix_props<{case}>, "
|
||||
res += f"cooperativeMatrixPropertiesNV = [{cmc[:-2]}], "
|
||||
res += f"cooperativeMatrixPropertiesKHR = [{cmc[:-2]}], "
|
||||
else:
|
||||
res += f"{k} = {get_comma_sep_str(v)}, "
|
||||
else:
|
||||
|
||||
@@ -23,11 +23,19 @@ from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
|
||||
from shark.parser import shark_args
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_all_vulkan_devices():
|
||||
from iree.runtime import get_driver
|
||||
|
||||
driver = get_driver("vulkan")
|
||||
device_list_src = driver.query_available_devices()
|
||||
device_list_src.sort(key=lambda d: d["path"])
|
||||
return [d["name"] for d in device_list_src]
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_vulkan_device_name(device_num=0):
|
||||
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
|
||||
vulkaninfo_dump = vulkaninfo_dump.split(linesep)
|
||||
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
if len(vulkaninfo_list) == 0:
|
||||
raise ValueError("No device name found in VulkanInfo!")
|
||||
if len(vulkaninfo_list) > 1:
|
||||
@@ -111,6 +119,8 @@ def get_vulkan_target_triple(device_name):
|
||||
# Windows: AMD Radeon RX 7900 XTX
|
||||
elif all(x in device_name for x in ("RX", "7900")):
|
||||
triple = f"rdna3-7900-{system_os}"
|
||||
elif all(x in device_name for x in ("Radeon", "780M")):
|
||||
triple = f"rdna3-780m-{system_os}"
|
||||
elif all(x in device_name for x in ("AMD", "PRO", "W7900")):
|
||||
triple = f"rdna3-w7900-{system_os}"
|
||||
elif any(x in device_name for x in ("AMD", "Radeon")):
|
||||
@@ -178,9 +188,7 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
|
||||
@functools.cache
|
||||
def get_iree_vulkan_runtime_flags():
|
||||
vulkan_runtime_flags = [
|
||||
f"--vulkan_large_heap_block_size={shark_args.vulkan_large_heap_block_size}",
|
||||
f"--vulkan_validation_layers={'true' if shark_args.vulkan_validation_layers else 'false'}",
|
||||
f"--vulkan_vma_allocator={'true' if shark_args.vulkan_vma_allocator else 'false'}",
|
||||
]
|
||||
return vulkan_runtime_flags
|
||||
|
||||
|
||||
@@ -14,8 +14,21 @@
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shlex
|
||||
import subprocess
|
||||
|
||||
|
||||
class SplitStrToListAction(argparse.Action):
|
||||
def __init__(self, option_strings, dest, *args, **kwargs):
|
||||
super(SplitStrToListAction, self).__init__(
|
||||
option_strings=option_strings, dest=dest, *args, **kwargs
|
||||
)
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
del parser, option_string
|
||||
setattr(namespace, self.dest, shlex.split(values[0]))
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="SHARK runner.")
|
||||
|
||||
parser.add_argument(
|
||||
@@ -24,6 +37,13 @@ parser.add_argument(
|
||||
default="cpu",
|
||||
help="Device on which shark_runner runs. options are cpu, cuda, and vulkan",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--additional_compile_args",
|
||||
default=list(),
|
||||
nargs=1,
|
||||
action=SplitStrToListAction,
|
||||
help="Additional arguments to pass to the compiler. These are appended as the last arguments.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_tf32",
|
||||
type=bool,
|
||||
@@ -133,13 +153,6 @@ parser.add_argument(
|
||||
help="Profiles vulkan device and collects the .rdc info.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vulkan_large_heap_block_size",
|
||||
default="2073741824",
|
||||
help="Flag for setting VMA preferredLargeHeapBlockSize for "
|
||||
"vulkan device, default is 4G.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vulkan_validation_layers",
|
||||
default=False,
|
||||
@@ -147,11 +160,4 @@ parser.add_argument(
|
||||
help="Flag for disabling vulkan validation layers when benchmarking.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vulkan_vma_allocator",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for enabling / disabling Vulkan VMA Allocator.",
|
||||
)
|
||||
|
||||
shark_args, unknown = parser.parse_known_args()
|
||||
|
||||
@@ -115,7 +115,7 @@ def compile_int_precision(
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
mlir_module,
|
||||
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
@@ -451,6 +451,65 @@ def transform_fx(fx_g, quantized=False):
|
||||
fx_g.graph.lint()
|
||||
|
||||
|
||||
def gptq_transforms(fx_g):
|
||||
import torch
|
||||
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
torch.ops.aten.arange,
|
||||
torch.ops.aten.empty,
|
||||
torch.ops.aten.ones,
|
||||
torch.ops.aten._to_copy,
|
||||
]:
|
||||
if node.kwargs.get("device") == torch.device(device="cuda:0"):
|
||||
updated_kwargs = node.kwargs.copy()
|
||||
updated_kwargs["device"] = torch.device(device="cpu")
|
||||
node.kwargs = updated_kwargs
|
||||
|
||||
if node.target in [
|
||||
torch.ops.aten._to_copy,
|
||||
]:
|
||||
if node.kwargs.get("dtype") == torch.bfloat16:
|
||||
updated_kwargs = node.kwargs.copy()
|
||||
updated_kwargs["dtype"] = torch.float16
|
||||
node.kwargs = updated_kwargs
|
||||
|
||||
# Inputs of aten.native_layer_norm should be upcasted to fp32.
|
||||
if node.target in [torch.ops.aten.native_layer_norm]:
|
||||
with fx_g.graph.inserting_before(node):
|
||||
new_node_arg0 = fx_g.graph.call_function(
|
||||
torch.ops.prims.convert_element_type,
|
||||
args=(node.args[0], torch.float32),
|
||||
kwargs={},
|
||||
)
|
||||
node.args = (
|
||||
new_node_arg0,
|
||||
node.args[1],
|
||||
node.args[2],
|
||||
node.args[3],
|
||||
node.args[4],
|
||||
)
|
||||
|
||||
# Downcasting the result of native_layer_norm back to fp16.
|
||||
if node.name.startswith("getitem"):
|
||||
with fx_g.graph.inserting_before(node):
|
||||
if node.args[0].target in [
|
||||
torch.ops.aten.native_layer_norm
|
||||
]:
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten._to_copy,
|
||||
args=(node,),
|
||||
kwargs={"dtype": torch.float32},
|
||||
)
|
||||
node.append(new_node)
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
new_node.kwargs = {"dtype": torch.float32}
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
|
||||
# Doesn't replace the None type.
|
||||
def change_fx_graph_return_to_tuple(fx_g):
|
||||
for node in fx_g.graph.nodes:
|
||||
@@ -504,27 +563,12 @@ def import_with_fx(
|
||||
is_dynamic=False,
|
||||
tracing_required=False,
|
||||
precision="fp32",
|
||||
is_gptq=False,
|
||||
):
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from typing import List
|
||||
from brevitas_examples.llm.llm_quant.export import (
|
||||
block_quant_layer_level_manager,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.export import (
|
||||
brevitas_layer_export_mode,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
|
||||
LinearWeightBlockQuantHandlerFwd,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.export import replace_call_fn_target
|
||||
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
|
||||
matmul_rhs_group_quant_placeholder,
|
||||
)
|
||||
from brevitas.backport.fx.experimental.proxy_tensor import (
|
||||
make_fx as brevitas_make_fx,
|
||||
)
|
||||
|
||||
golden_values = None
|
||||
if debug:
|
||||
@@ -596,8 +640,30 @@ def import_with_fx(
|
||||
torch.ops.aten.native_layer_norm,
|
||||
torch.ops.aten.masked_fill.Tensor,
|
||||
torch.ops.aten.masked_fill.Scalar,
|
||||
torch.ops.aten._scaled_dot_product_flash_attention.default,
|
||||
torch.ops.aten.index_add,
|
||||
torch.ops.aten.index_add_,
|
||||
]
|
||||
if precision in ["int4", "int8"]:
|
||||
if precision in ["int4", "int8"] and not is_gptq:
|
||||
from brevitas_examples.llm.llm_quant.export import (
|
||||
block_quant_layer_level_manager,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.export import (
|
||||
brevitas_layer_export_mode,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
|
||||
LinearWeightBlockQuantHandlerFwd,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.export import (
|
||||
replace_call_fn_target,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
|
||||
matmul_rhs_group_quant_placeholder,
|
||||
)
|
||||
from brevitas.backport.fx.experimental.proxy_tensor import (
|
||||
make_fx as brevitas_make_fx,
|
||||
)
|
||||
|
||||
export_context_manager = brevitas_layer_export_mode
|
||||
export_class = block_quant_layer_level_manager(
|
||||
export_handlers=[LinearWeightBlockQuantHandlerFwd]
|
||||
@@ -647,6 +713,10 @@ def import_with_fx(
|
||||
add_upcast(fx_g)
|
||||
fx_g.recompile()
|
||||
|
||||
if is_gptq:
|
||||
gptq_transforms(fx_g)
|
||||
fx_g.recompile()
|
||||
|
||||
if mlir_type == "fx":
|
||||
return fx_g
|
||||
|
||||
@@ -677,5 +747,5 @@ def import_with_fx(
|
||||
)
|
||||
return mlir_module, func_name
|
||||
|
||||
mlir_module, func_name = mlir_importer.import_mlir()
|
||||
mlir_module, func_name = mlir_importer.import_mlir(mlir_type=mlir_type)
|
||||
return mlir_module, func_name
|
||||
|
||||
@@ -192,7 +192,9 @@ class SharkInference:
|
||||
|
||||
# TODO: Instead of passing directory and having names decided by the module
|
||||
# , user may want to save the module with manual names.
|
||||
def save_module(self, dir=os.getcwd(), module_name=None, extra_args=[]):
|
||||
def save_module(
|
||||
self, dir=os.getcwd(), module_name=None, extra_args=[], debug=False
|
||||
):
|
||||
return export_iree_module_to_vmfb(
|
||||
self.mlir_module,
|
||||
self.device,
|
||||
@@ -200,6 +202,7 @@ class SharkInference:
|
||||
self.mlir_dialect,
|
||||
module_name=module_name,
|
||||
extra_args=extra_args,
|
||||
debug=debug,
|
||||
)
|
||||
|
||||
# load and return the module.
|
||||
|
||||
@@ -69,7 +69,7 @@ class SharkTrainer:
|
||||
self.frontend = frontend
|
||||
|
||||
# Training function is needed in the case of torch_fn.
|
||||
def compile(self, training_fn=None, extra_args=[]):
|
||||
def compile(self, training_fn=None, mlir_type="linalg", extra_args=[]):
|
||||
if self.frontend in ["torch", "pytorch"]:
|
||||
packed_inputs = (
|
||||
dict(self.model.named_parameters()),
|
||||
@@ -77,7 +77,12 @@ class SharkTrainer:
|
||||
tuple(self.input),
|
||||
)
|
||||
mlir_module, func_name = import_with_fx(
|
||||
training_fn, packed_inputs, False, [], training=True
|
||||
training_fn,
|
||||
packed_inputs,
|
||||
False,
|
||||
[],
|
||||
training=True,
|
||||
mlir_type=mlir_type,
|
||||
)
|
||||
self.shark_runner = SharkRunner(
|
||||
mlir_module,
|
||||
|
||||
@@ -59,7 +59,7 @@ def create_module(model_name, tokenizer, device):
|
||||
)
|
||||
|
||||
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
|
||||
shark_module.save_module(module_name=vmfb_name)
|
||||
shark_module.save_module(module_name=vmfb_name, debug=False)
|
||||
vmfb_path = vmfb_name + ".vmfb"
|
||||
return vmfb_path
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ import collections
|
||||
import json
|
||||
import os
|
||||
import psutil
|
||||
import resource
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
@@ -168,7 +167,7 @@ def save_json(data, filename):
|
||||
|
||||
|
||||
def collect_huggingface_logits(
|
||||
model_name: str, max_seq_len: int, save_json: bool
|
||||
model_name: str, max_seq_len: int, to_save_json: bool
|
||||
) -> Tuple[float, float]:
|
||||
# Load
|
||||
t0 = time.time()
|
||||
@@ -194,11 +193,11 @@ def collect_huggingface_logits(
|
||||
for idx, tokens in enumerate(tokenized_prompts):
|
||||
print("prompt: {}".format(PROMPTS[idx]))
|
||||
logits = run_huggingface_model(model_wrapper, tokens)
|
||||
if save_json:
|
||||
if to_save_json:
|
||||
results.append([PROMPTS[idx], logits[0].tolist()])
|
||||
run_time = time.time() - t0
|
||||
print("--- Took {} seconds to run Huggingface.".format(run_time))
|
||||
if save_json:
|
||||
if to_save_json:
|
||||
save_json(results, "/tmp/huggingface.json")
|
||||
run_memory_info = get_memory_info()
|
||||
return {
|
||||
@@ -215,7 +214,10 @@ def collect_huggingface_logits(
|
||||
|
||||
|
||||
def collect_shark_logits(
|
||||
model_name: str, max_seq_len: int, recompile_shark: bool, save_json: bool
|
||||
model_name: str,
|
||||
max_seq_len: int,
|
||||
recompile_shark: bool,
|
||||
to_save_json: bool,
|
||||
) -> Tuple[float, float]:
|
||||
# Load
|
||||
t0 = time.time()
|
||||
@@ -246,11 +248,11 @@ def collect_shark_logits(
|
||||
print("prompt: {}".format(PROMPTS[idx]))
|
||||
logits = run_shark_model(model_wrapper, tokens)
|
||||
lst = [e.tolist() for e in logits]
|
||||
if save_json:
|
||||
if to_save_json:
|
||||
results.append([PROMPTS[idx], lst])
|
||||
run_time = time.time() - t0
|
||||
print("--- Took {} seconds to run Shark.".format(run_time))
|
||||
if save_json:
|
||||
if to_save_json:
|
||||
save_json(results, "/tmp/shark.json")
|
||||
platform_postfix = "-compile" if recompile_shark else "-precompiled"
|
||||
run_memory_info = get_memory_info()
|
||||
|
||||
@@ -287,6 +287,9 @@ class SharkModuleTester:
|
||||
repro_path = os.path.join("reproducers", self.tmp_prefix, "*")
|
||||
|
||||
bashCommand = f"gsutil cp -r {repro_path} gs://shark-public/builder/repro_artifacts/{self.ci_sha}/{self.tmp_prefix}/"
|
||||
print(
|
||||
f"Uploading reproducer {repro_path} to gs://shark-public/builder/repro_artifacts/{self.ci_sha}/{self.tmp_prefix}/"
|
||||
)
|
||||
process = subprocess.run(bashCommand.split())
|
||||
|
||||
def postprocess_outputs(self, golden_out, result):
|
||||
|
||||
Reference in New Issue
Block a user