Compare commits

69 Commits

Author SHA1 Message Date
Ean Garvey
841773fa32 Updates to opt_causallm example (#1905)
* Updates to opt_causallm example

* Fixup opt_perf_comparison.py

* Use same filenames across opt examples.
2023-10-24 10:54:39 -07:00
Stefan Kapusniak
0361db46f9 SD: Fix unet untuned opt_flags (#1912)
* correct my sloppy copy/paste for the untuned unet default compilation
flags that introduced an extra 'detach' into what should have been
'iree-global-opt-convert-1x1-filter-conv2d-to-matmul'
2023-10-24 12:47:33 -05:00
xzuyn
a012433ffd Save hiresfix info if used (#1914) 2023-10-24 12:45:10 -05:00
xzuyn
5061193da3 Move Generate, Randomize Seed, & Stop Batch to same positions as txt2img (#1915) 2023-10-24 12:44:39 -05:00
xzuyn
bff48924be LLaMa 2 Chat template fix (#1913) 2023-10-23 18:51:15 -05:00
Stefan Kapusniak
825b36cbdd Fix MLIR Textual PassPipeline Error (#1910) 2023-10-22 07:39:52 -07:00
Stefan Kapusniak
134441957d SD - Fix civitai download on Windows +improvements (#1907) 2023-10-21 11:17:41 -07:00
Stefan Kapusniak
7cd14fdc47 SD/UI: Use a single model selection box on UI tabs (#1906)
* Allow entry of a huggingface model id or civitai download url to be
done in the main model selection dropdown on SD tabs
* Remove separate textbox for entering huggingface model id or civitai
download url on SD Tabs
* Remove 'None' option from the model selection dropdown (no longer
needed) on SD tabs
* Update png metadata drop zone on txt2img tab to work with a single
argument for model selection
* Update UI generate functions on SD tabs to work with single argument
model selection
* Update API code for changes to the UI generate functions
* Move info about the custom model path to the logging textarea on SD
tabs
2023-10-21 10:06:05 -07:00
Ean Garvey
e6cb5cef57 Add --additional_runtime_args option and use in OPT example. (#1855)
* Add --additional_runtime_args option and use in OPT example.

Fix the func name. (#1838)

Co-authored-by: Sungsoon Cho <sungsoon.cho@gmail.com>
2023-10-19 13:29:39 -05:00
Huang Qi
66abee8e5b SharkInference: Fix various examples and README.md (#1903)
Following https://github.com/nod-ai/SHARK/pull/708, remove the 'func_name'
parameter from SharkInference.
2023-10-19 09:28:36 -05:00
Ean Garvey
4797bb89f5 Stringify path for ireec.compile_file (#1901)
* Stringify path for ireec.compile_file

* Update test-models.yml
2023-10-18 14:59:23 -05:00
Vivek Khandelwal
205e57683a Modify Falcon-180b-GPTQ sharded pipeline 2023-10-17 20:26:01 +05:30
Vivek Khandelwal
2866d665ee Fix Sharded Falcon-180b-GPTQ Pipeline 2023-10-17 20:26:01 +05:30
Stefan Kapusniak
71d25ec5d8 SD: Fix repeatable seeds when initial seed is random (#1893) 2023-10-14 22:50:42 -07:00
Vivek Khandelwal
202ffff67b Add support for sharded Falcon model 2023-10-13 22:05:10 +05:30
Ean Garvey
0b77059628 Add matmul reassociation flags (#1891) 2023-10-12 20:12:37 -05:00
Stefan Kapusniak
a208302bb9 Fix repeatable seeds consistency over batch counts (#1889)
* Set the input seed for the random number generator when
generating repeatable seeds to exclude any negative numbers
in the parsed seed input. This makes seeds generated for
different batch counts consistent where they have the same
input for the initial seed or set of seeds.
2023-10-12 17:15:19 -05:00
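As an illustration of the behaviour described above, a minimal sketch (not the actual SHARK implementation; names and ranges are assumed) of deriving per-image seeds from a single non-negative input seed, so that the resulting list does not depend on how the images are split into batches:

```
import random

def repeatable_seeds(initial_seed: int, total_images: int) -> list[int]:
    # Exclude negative sentinel values (e.g. -1 for "random") before seeding,
    # so the generator state depends only on the user-supplied seed.
    rng = random.Random(max(initial_seed, 0))
    # All per-image seeds are derived up front from one RNG, so splitting the
    # same total into different batch counts yields the same seed sequence.
    return [rng.randint(0, 2**32 - 1) for _ in range(total_images)]
```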
Vivek Khandelwal
b83d32fafe Fix Falcon GPTQ Pipeline 2023-10-11 20:09:32 +05:30
Vivek Khandelwal
0a618e1863 Add support for Falcon GPTQ 2023-10-11 10:47:48 +05:30
Phaneesh Barwaria
a731eb6ed4 Macos fixes (#1883)
* fix venv setup for MacOS

* allow stream fuse binding on mac

* clean iree metal args
2023-10-09 23:36:12 -07:00
Ean Garvey
2004d16945 Revert "[SDXL] Add SDXL pipeline to SHARK (#1731)" (#1882)
This reverts commit 9f0a421764.
2023-10-09 18:01:44 -07:00
Gaurav Shukla
6e409bfb77 fix else if syntax error
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-10 06:23:56 +05:30
Gaurav Shukla
77727d149c [warning] Fix dropdown warning
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-10 05:18:43 +05:30
Ean Garvey
66f6e79d68 Split CPU/GPU definitions conditionally outside of torch contexts. (#1879) 2023-10-09 16:46:41 -07:00
Ean Garvey
3b825579a7 (LLaMa-2) Point to int4 + f32 acc .mlir for cpu (#1878)
- fixes some issues with non-system prompt invocation

Co-authored-by: Gaurav Shukla <gauravshukla789@gmail.com>
2023-10-09 14:37:35 -05:00
Abhishek Varma
9f0a421764 [SDXL] Add SDXL pipeline to SHARK (#1731)
-- This commit adds SDXL pipeline to SHARK.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-10-09 13:01:37 -05:00
Gaurav Shukla
c28682110c [chatbot] Flag to add system prompt
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-09 22:17:39 +05:30
Ean Garvey
caf6cc5d8f Switch most compile flows to use ireec.compile_file. (#1863)
* Switch most compile flows to use ireec.compile_file.

* re-add input type to compile_str path.

* Check if mlir_module exists before checking if it's a path or pyobject.

* Fix some save_dir cases
2023-10-06 23:04:43 -05:00
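For reference, a minimal sketch of the iree.compiler compile_file path these flows switch to (the backend, input type, and file names here are illustrative, not the exact values SHARK uses):

```
from pathlib import Path
import iree.compiler as ireec

mlir_path = Path("model.mlir")  # hypothetical input

# compile_file takes a filesystem path (hence the str() conversion noted in
# the "Stringify path for ireec.compile_file" commit above) and returns the
# vmfb bytes when no output_file is given.
vmfb_bytes = ireec.compile_file(
    str(mlir_path),
    target_backends=["llvm-cpu"],  # e.g. "vulkan-spirv" for GPU targets
    input_type="tm_tensor",
)
Path("model.vmfb").write_bytes(vmfb_bytes)
```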
Ean Garvey
8614a18474 Remove tf dependencies from importer path. (#1874)
* Remove tf dependencies from import path.

* Fix formatting.
2023-10-06 12:27:12 -07:00
Jakub Kuderski
86c1c0c215 Add aggregate statistics to microbenchmark (#1871)
Print averaged results at the end of all iterations. Increase the
default number of iterations to 5.

Example:
```
Number of iterations: 5
Prefill: avg. 0.03 s, stddev 0.00
Decode: avg. 43.34 tokens/s, stdev 0.13
```

Also remove the -2 in the number of generated tokens -- I did not find
any evidence we need it.
2023-10-06 10:03:07 -07:00
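The aggregation itself is simple; a minimal sketch using Python's statistics module, in the spirit of the per-iteration averaging described above (the sample numbers are purely illustrative):

```
from statistics import mean, stdev

# Per-iteration measurements collected by the benchmark loop (illustrative).
prefill_times = [0.031, 0.029, 0.030, 0.032, 0.028]  # seconds
decode_speeds = [43.2, 43.5, 43.3, 43.4, 43.3]        # tokens/s

print(f"Number of iterations: {len(prefill_times)}")
print(f"Prefill: avg. {mean(prefill_times):.2f} s, stddev {stdev(prefill_times):.2f}")
print(f"Decode: avg. {mean(decode_speeds):.2f} tokens/s, stdev {stdev(decode_speeds):.2f}")
```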
Daniel Garvey
8bb364bcb8 enforce fp32 accumulates for cpu (#1873) 2023-10-06 11:34:49 -05:00
Daniel Garvey
7abddd01ec argmax inside model + brevitas pin (#1872) 2023-10-05 20:15:21 -07:00
Abhishek Varma
2a451fa0c7 [Llama2] Add a standalone utility for dynamic and combining IRs
-- This script adds a standalone utility for converting Llama IRs
   to dynamic and combining them as well.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-10-05 20:01:06 +05:30
Jakub Kuderski
9c4610b9da Add microbenchmark mode to vicuna CLI (#1864)
Add flags to enable a non-interactive mode for microbenchmarking llama
models. In this mode, the system and user prompts are specified with CLI
flags, and the number of generated tokens and iterations is fixed.

Also move the stats below the response and trim any whitespace from the response.
2023-10-05 00:12:08 -04:00
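A hedged example invocation using the flags this change introduces (flag names are taken from the vicuna.py diff further down; the script path is assumed):

```
python apps/language_models/scripts/vicuna.py \
  --enable_microbenchmark \
  --microbenchmark_iterations=5 \
  --microbenchmark_num_tokens=512 \
  --system_prompt="" \
  --user_prompt="Hi"
```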
powderluv
a38cc9d216 Update vulkan_utils.py for Radeon 780m igpu (#1866) 2023-10-04 20:33:07 -07:00
Jakub Kuderski
1c382449ec [vulkan] Print note about module load times. NFC. (#1862)
Print a note ahead of a potentially long period of inactivity to set the right expectations.

Separately, we should add progress to the UI and make this loading faster.
2023-10-03 17:27:27 -04:00
Gaurav Shukla
7cc9b3f8e8 [llama cli] Fix llama cli
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-03 20:39:53 +05:30
Gaurav Shukla
e54517e967 [UI] Disable config generator, lora train and model manager (#1858)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-02 22:34:40 -07:00
Ean Garvey
326327a799 Collect pipeline submodules for diffusers ckpt preprocessing. (#1859) 2023-10-03 00:29:28 -04:00
Ean Garvey
785b65c7b0 Add flag for specifying device-local caching allocator heap key. (#1856) 2023-10-03 00:28:39 -04:00
Sungsoon Cho
0d16c81687 Remove unused import. (#1857) 2023-10-02 11:36:08 -05:00
Vivek Khandelwal
8dd7850c69 Add Falcon-GPTQ support 2023-10-02 16:39:57 +05:30
Gaurav Shukla
e930ba85b4 [os] Remove os dependency from vmfb naming (#1854)
Also fixes a small ui issue for chatbot.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-29 12:38:17 -05:00
Gaurav Shukla
cd732e7a38 [chatbot] split execution time to prefill and decode
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-29 13:18:03 +05:30
Gaurav Shukla
8e0f8b3227 [ui] Update chatbot UI
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-29 13:18:03 +05:30
Gaurav Shukla
b8210ef796 [chatbot] Re-instantiate the chatbot object if device id changes
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-29 13:18:03 +05:30
PhaneeshB
94594542a9 remove use of vulkaninfo 2023-09-28 21:57:00 +05:30
Gaurav Shukla
82f833e87d [vulkan] Update vmfb naming
Update vmfb naming for vulkan devices in order to resolve naming
conflicts in the presence of multiple vulkan devices.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-28 14:52:11 +05:30
Vivek Khandelwal
c9d6870105 Modify falcon pipeline for 180b support 2023-09-28 12:39:35 +05:30
Jakub Kuderski
4fec03a6cc [vulkan] Switch from coop matrix NV to KHR (#1848) 2023-09-27 21:43:37 -04:00
harsh-nod
9a27f51378 Deprecate inference directory
This patch removes the inference directory that was no longer being used.
2023-09-27 14:29:00 -07:00
Abhishek Varma
ad1a0f35ff Fix misdirection while saving vmfb
-- Currently SHARK suggests that the vmfb has been saved, while
    that is not the case and no vmfb is generated.
    This is misleading for larger IRs/vmfbs.
-- This commit therefore fixes that misdirection.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-09-27 16:25:29 +05:30
Nelson Sharpe
6773278ec2 Fix checkpoint_path unexpected argument (#1832) 2023-09-24 14:17:52 -07:00
Abhishek Varma
9a0efffcca [Llama2] Fix wrong Vulkan device ID + Add Vulkan compile flags
-- This commit fixes the wrong Vulkan device being selected during
   runtime.
-- It also adds a couple of IREE compilation flags to target a specific
   Vulkan device.
-- It also changes the Vulkan device listing to be more in tune with
   lowering control flow.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-09-22 22:24:18 +05:30
gpetters94
61c6f153d9 Switch to keras-nightly to fix a Linux issue (#1835) 2023-09-21 12:33:45 -04:00
Phaneesh Barwaria
effd42e8f5 pin gradio to v3.44.3 2023-09-21 17:33:43 +05:30
Sungsoon Cho
b5fbb1a8a0 Rename the func arg save_json to avoid name collision. (#1837)
* Rename the func arg save_json to avoid name collision.

* black formatted.
2023-09-19 17:29:27 -05:00
Quinn Dawkins
ded74d09cd [vicuna.py] Keep past key values on device (#1836)
The past key values are only used within the models themselves and can
be kept on device. For vulkan int4, this gives 44 tok/s (for the first
prompt) and settles at around 26 tok/s on 7900xtx.
2023-09-19 18:17:41 -04:00
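Illustratively, this amounts to asking the runtime not to copy module outputs back to host memory between steps; a hedged sketch of the prefill call pattern (the send_to_host argument appears in the vicuna.py diff below, the surrounding names are assumed):

```
import torch

def prefill_keep_on_device(shark_model, input_ids: torch.Tensor):
    """Sketch: run the prefill pass and keep past key values as device buffers."""
    # send_to_host=False: outputs stay on the device instead of being copied back.
    output = shark_model("first_vicuna_forward", (input_ids,), send_to_host=False)
    # Only the sampled token is needed host-side; the remaining outputs
    # (the past key values) are fed straight back into the decode step.
    token = int(output[0].to_host())
    past_key_values = output[1:]
    return token, past_key_values
```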
Boian Petkantchin
79267931c1 Add argument --additional_compile_args (#1119)
This allows passing more arguments to the IREE compiler.
Example:
python my-app.py --additional_compile_args="--mlir-pretty-debuginfo --mlir-timing"

Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-09-19 11:26:03 -05:00
zjgarvey
9eceba69b7 local_tank_cache included into clear_all (#1833) 2023-09-18 00:27:23 -05:00
Ean Garvey
ca609afb6a Update README.md (#1830) 2023-09-14 10:33:57 -05:00
Gaurav Shukla
11bdce9790 [flags] Fix vulkan runtime flags as vma is dropped from iree (#1831) 2023-09-14 08:58:59 -05:00
Ean Garvey
684943a4a6 (SD) Fix tokenizers imports in pyinstaller builds. (#1828)
* Fix tokenizers metadata.

* (SD) Disable VAE lowering configs (rdna3) and add versioned tunings.

* Update sd_annotation.py

* (SD) Add cv2 to spec.

* Update stencil pipeline with the new img2img arg.
2023-09-12 12:23:48 -05:00
PhaneeshB
b817bb8455 add roles for llama2 2023-09-12 10:59:28 +05:30
Ean Garvey
780f520f02 Fix vk.target_env extensions and remove redundant SD imports. (#1826)
* Remove redundant IREE runtime imports.

* Fix vulkan target env extensions.
2023-09-11 13:42:52 -05:00
Dom
c61b6f8d65 Code refactoring (#1817)
* use join

* fix bug

* further code optimizations

---------

Co-authored-by: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com>
2023-09-11 11:30:56 -05:00
Abhishek Varma
c854208d49 [Llama2] Prefetch llama2 tokenizer configs (#1824)
-- This commit prefetches llama2 tokenizer configs from shark_tank.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-09-08 11:29:54 -07:00
Gaurav Shukla
c5dcfc1f13 [vicuna] Exit when mlir is not present in shark tank (#1825)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-08 10:30:29 -07:00
Abhishek Varma
bde63ee8ae Add logging feature in WebUI (#1821) 2023-09-08 05:48:05 -07:00
79 changed files with 3713 additions and 2768 deletions

@@ -137,7 +137,8 @@ jobs:
source shark.venv/bin/activate
echo $PATH
pip list | grep -E "torch|iree"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
# disabled due to a low-visibility memory issue with pytest on macos.
# pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os == 'a100'

.gitignore (vendored)

@@ -182,7 +182,7 @@ generated_imgs/
# Custom model related artefacts
variants.json
models/
/models/
# models folder
apps/stable_diffusion/web/models/
@@ -196,3 +196,6 @@ db_dir_UserData
# Embeded browser cache and other
apps/stable_diffusion/web/EBWebView/
# Llama2 tokenizer configs
llama2_tokenizer_configs/

@@ -10,7 +10,7 @@ High Performance Machine Learning Distribution
<summary>Prerequisites - Drivers </summary>
#### Install your Windows hardware drivers
* [AMD RDNA Users] Download the latest driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-2-1).
* [AMD RDNA Users] Download the latest driver (23.2.1 is the oldest supported) [here](https://www.amd.com/en/support).
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
@@ -254,7 +254,6 @@ if you want to instead incorporate this into a python script, you can pass the `
```
shark_module = SharkInference(
mlir_model,
func_name,
device=args.device,
mlir_dialect="tm_tensor",
dispatch_benchmarks="all",
@@ -297,7 +296,7 @@ torch_mlir, func_name = mlir_importer.import_mlir(tracing_required=True)
# SharkInference accepts mlir in linalg, mhlo, and tosa dialect.
from shark.shark_inference import SharkInference
shark_module = SharkInference(torch_mlir, func_name, device="cpu", mlir_dialect="linalg")
shark_module = SharkInference(torch_mlir, device="cpu", mlir_dialect="linalg")
shark_module.compile()
result = shark_module.forward((input))
@@ -320,7 +319,7 @@ mhlo_ir = r"""builtin.module {
arg0 = np.ones((1, 4)).astype(np.float32)
arg1 = np.ones((4, 1)).astype(np.float32)
shark_module = SharkInference(mhlo_ir, func_name="forward", device="cpu", mlir_dialect="mhlo")
shark_module = SharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo")
shark_module.compile()
result = shark_module.forward((arg0, arg1))
```

@@ -20,7 +20,7 @@ import gc
from pathlib import Path
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
from apps.stable_diffusion.src import args
# Brevitas
@@ -256,6 +256,11 @@ class H2OGPTSHARKModel(torch.nn.Module):
bytecode = bytecode_stream.getvalue()
del module
bytecode = save_mlir(
bytecode,
model_name=f"h2ogpt_{precision}",
frontend="torch",
)
return bytecode
def forward(self, input_ids, attention_mask):

@@ -0,0 +1,442 @@
from pathlib import Path
import argparse
from argparse import RawTextHelpFormatter
import re, gc
"""
This script can be used as a standalone utility to convert IRs to dynamic + combine them.
Following are the various ways this script can be used :-
a. To convert a single Linalg IR to dynamic IR:
--dynamic --first_ir_path=<PATH TO FIRST IR>
b. To convert two Linalg IRs to dynamic IR:
--dynamic --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
c. To combine two Linalg IRs into one:
--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
d. To convert both IRs into dynamic as well as combine the IRs:
--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
NOTE: For dynamic you'll also need to provide the following set of flags:-
i. For First Llama : --dynamic_input_size (DEFAULT: 19)
ii. For Second Llama: --model_name (DEFAULT: llama2_7b)
--precision (DEFAULT: 'int4')
You may use --save_dynamic to also save the dynamic IR in option d above.
Else for option a. and b. the dynamic IR(s) will get saved by default.
"""
def combine_mlir_scripts(
first_vicuna_mlir,
second_vicuna_mlir,
output_name,
return_ir=True,
):
print(f"[DEBUG] combining first and second mlir")
print(f"[DEBUG] output_name = {output_name}")
maps1 = []
maps2 = []
constants = set()
f1 = []
f2 = []
print(f"[DEBUG] processing first vicuna mlir")
first_vicuna_mlir = first_vicuna_mlir.splitlines()
while first_vicuna_mlir:
line = first_vicuna_mlir.pop(0)
if re.search("#map\d*\s*=", line):
maps1.append(line)
elif re.search("arith.constant", line):
constants.add(line)
elif not re.search("module", line):
line = re.sub("forward", "first_vicuna_forward", line)
f1.append(line)
f1 = f1[:-1]
del first_vicuna_mlir
gc.collect()
for i, map_line in enumerate(maps1):
map_var = map_line.split(" ")[0]
map_line = re.sub(f"{map_var}(?!\d)", map_var + "_0", map_line)
maps1[i] = map_line
f1 = [
re.sub(f"{map_var}(?!\d)", map_var + "_0", func_line)
for func_line in f1
]
print(f"[DEBUG] processing second vicuna mlir")
second_vicuna_mlir = second_vicuna_mlir.splitlines()
while second_vicuna_mlir:
line = second_vicuna_mlir.pop(0)
if re.search("#map\d*\s*=", line):
maps2.append(line)
elif "global_seed" in line:
continue
elif re.search("arith.constant", line):
constants.add(line)
elif not re.search("module", line):
line = re.sub("forward", "second_vicuna_forward", line)
f2.append(line)
f2 = f2[:-1]
del second_vicuna_mlir
gc.collect()
for i, map_line in enumerate(maps2):
map_var = map_line.split(" ")[0]
map_line = re.sub(f"{map_var}(?!\d)", map_var + "_1", map_line)
maps2[i] = map_line
f2 = [
re.sub(f"{map_var}(?!\d)", map_var + "_1", func_line)
for func_line in f2
]
module_start = 'module attributes {torch.debug_module_name = "_lambda"} {'
module_end = "}"
global_vars = []
vnames = []
global_var_loading1 = []
global_var_loading2 = []
print(f"[DEBUG] processing constants")
counter = 0
constants = list(constants)
while constants:
constant = constants.pop(0)
vname, vbody = constant.split("=")
vname = re.sub("%", "", vname)
vname = vname.strip()
vbody = re.sub("arith.constant", "", vbody)
vbody = vbody.strip()
if len(vbody.split(":")) < 2:
print(constant)
vdtype = vbody.split(":")[-1].strip()
fixed_vdtype = vdtype
if "c1_i64" in vname:
print(constant)
counter += 1
if counter == 2:
counter = 0
print("detected duplicate")
continue
vnames.append(vname)
if "true" not in vname:
global_vars.append(
f"ml_program.global private @{vname}({vbody}) : {fixed_vdtype}"
)
global_var_loading1.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
)
global_var_loading2.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
)
else:
global_vars.append(
f"ml_program.global private @{vname}({vbody}) : i1"
)
global_var_loading1.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
)
global_var_loading2.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
)
new_f1, new_f2 = [], []
print(f"[DEBUG] processing f1")
for line in f1:
if "func.func" in line:
new_f1.append(line)
for global_var in global_var_loading1:
new_f1.append(global_var)
else:
new_f1.append(line)
print(f"[DEBUG] processing f2")
for line in f2:
if "func.func" in line:
new_f2.append(line)
for global_var in global_var_loading2:
if (
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
in global_var
):
print(global_var)
new_f2.append(global_var)
else:
new_f2.append(line)
f1 = new_f1
f2 = new_f2
del new_f1
del new_f2
gc.collect()
print(
[
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in x
for x in [maps1, maps2, global_vars, f1, f2]
]
)
# doing it this way rather than assembling the whole string
# to prevent OOM with 64GiB RAM when encoding the file.
print(f"[DEBUG] Saving mlir to {output_name}")
with open(output_name, "w+") as f_:
f_.writelines(line + "\n" for line in maps1)
f_.writelines(line + "\n" for line in maps2)
f_.writelines(line + "\n" for line in [module_start])
f_.writelines(line + "\n" for line in global_vars)
f_.writelines(line + "\n" for line in f1)
f_.writelines(line + "\n" for line in f2)
f_.writelines(line + "\n" for line in [module_end])
del maps1
del maps2
del module_start
del global_vars
del f1
del f2
del module_end
gc.collect()
if return_ir:
print(f"[DEBUG] Reading combined mlir back in")
with open(output_name, "rb") as f:
return f.read()
def write_in_dynamic_inputs0(module, dynamic_input_size):
print("[DEBUG] writing dynamic inputs to first vicuna")
# Current solution for ensuring mlir files support dynamic inputs
# TODO: find a more elegant way to implement this
new_lines = []
module = module.splitlines()
while module:
line = module.pop(0)
line = re.sub(f"{dynamic_input_size}x", "?x", line)
if "?x" in line:
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
)
if "arith.cmpi" in line:
line = re.sub(f"c{dynamic_input_size}", "dim", line)
if "%0 = tensor.empty(%dim) : tensor<?xi64>" in line:
new_lines.append("%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>")
if "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" in line:
continue
new_lines.append(line)
return "\n".join(new_lines)
def write_in_dynamic_inputs1(module, model_name, precision):
print("[DEBUG] writing dynamic inputs to second vicuna")
def remove_constant_dim(line):
if "c19_i64" in line:
line = re.sub("c19_i64", "dim_i64", line)
if "19x" in line:
line = re.sub("19x", "?x", line)
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)",
"tensor.empty(%dim, %dim)",
line,
)
if "arith.cmpi" in line:
line = re.sub("c19", "dim", line)
if " 19," in line:
line = re.sub(" 19,", " %dim,", line)
if "x20x" in line or "<20x" in line:
line = re.sub("20x", "?x", line)
line = re.sub("tensor.empty\(\)", "tensor.empty(%dimp1)", line)
if " 20," in line:
line = re.sub(" 20,", " %dimp1,", line)
return line
module = module.splitlines()
new_lines = []
# Using a while loop and the pop method to avoid creating a copy of module
if "llama2_13b" in model_name:
pkv_tensor_shape = "tensor<1x40x?x128x"
elif "llama2_70b" in model_name:
pkv_tensor_shape = "tensor<1x8x?x128x"
else:
pkv_tensor_shape = "tensor<1x32x?x128x"
if precision in ["fp16", "int4", "int8"]:
pkv_tensor_shape += "f16>"
else:
pkv_tensor_shape += "f32>"
while module:
line = module.pop(0)
if "%c19_i64 = arith.constant 19 : i64" in line:
new_lines.append("%c2 = arith.constant 2 : index")
new_lines.append(
f"%dim_4_int = tensor.dim %arg1, %c2 : {pkv_tensor_shape}"
)
new_lines.append(
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
)
continue
if "%c2 = arith.constant 2 : index" in line:
continue
if "%c20_i64 = arith.constant 20 : i64" in line:
new_lines.append("%c1_i64 = arith.constant 1 : i64")
new_lines.append("%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64")
new_lines.append(
"%dimp1 = arith.index_cast %c20_i64 : i64 to index"
)
continue
line = remove_constant_dim(line)
new_lines.append(line)
return "\n".join(new_lines)
def save_dynamic_ir(ir_to_save, output_file):
if not ir_to_save:
return
# We only get string output from the dynamic conversion utility.
from contextlib import redirect_stdout
with open(output_file, "w") as f:
with redirect_stdout(f):
print(ir_to_save)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="llama ir utility",
description="\tThis script can be used as a standalone utility to convert IRs to dynamic + combine them.\n"
+ "\tFollowing are the various ways this script can be used :-\n"
+ "\t\ta. To convert a single Linalg IR to dynamic IR:\n"
+ "\t\t\t--dynamic --first_ir_path=<PATH TO FIRST IR>\n"
+ "\t\tb. To convert two Linalg IRs to dynamic IR:\n"
+ "\t\t\t--dynamic --first_ir_path=<PATH TO SECOND IR> --first_ir_path=<PATH TO SECOND IR>\n"
+ "\t\tc. To combine two Linalg IRs into one:\n"
+ "\t\t\t--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n"
+ "\t\td. To convert both IRs into dynamic as well as combine the IRs:\n"
+ "\t\t\t--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n\n"
+ "\tNOTE: For dynamic you'll also need to provide the following set of flags:-\n"
+ "\t\t i. For First Llama : --dynamic_input_size (DEFAULT: 19)\n"
+ "\t\tii. For Second Llama: --model_name (DEFAULT: llama2_7b)\n"
+ "\t\t\t--precision (DEFAULT: 'int4')\n"
+ "\t You may use --save_dynamic to also save the dynamic IR in option d above.\n"
+ "\t Else for option a. and b. the dynamic IR(s) will get saved by default.\n",
formatter_class=RawTextHelpFormatter,
)
parser.add_argument(
"--precision",
"-p",
default="int4",
choices=["fp32", "fp16", "int8", "int4"],
help="Precision of the concerned IR",
)
parser.add_argument(
"--model_name",
type=str,
default="llama2_7b",
choices=["vicuna", "llama2_7b", "llama2_13b", "llama2_70b"],
help="Specify which model to run.",
)
parser.add_argument(
"--first_ir_path",
default=None,
help="path to first llama mlir file",
)
parser.add_argument(
"--second_ir_path",
default=None,
help="path to second llama mlir file",
)
parser.add_argument(
"--dynamic_input_size",
type=int,
default=19,
help="Specify the static input size to replace with dynamic dim.",
)
parser.add_argument(
"--dynamic",
default=False,
action=argparse.BooleanOptionalAction,
help="Converts the IR(s) to dynamic",
)
parser.add_argument(
"--save_dynamic",
default=False,
action=argparse.BooleanOptionalAction,
help="Save the individual IR(s) after converting to dynamic",
)
parser.add_argument(
"--combine",
default=False,
action=argparse.BooleanOptionalAction,
help="Converts the IR(s) to dynamic",
)
args, unknown = parser.parse_known_args()
dynamic = args.dynamic
combine = args.combine
assert (
dynamic or combine
), "neither `dynamic` nor `combine` flag is turned on"
first_ir_path = args.first_ir_path
second_ir_path = args.second_ir_path
assert first_ir_path or second_ir_path, "no input ir has been provided"
if combine:
assert (
first_ir_path and second_ir_path
), "you will need to provide both IRs to combine"
precision = args.precision
model_name = args.model_name
dynamic_input_size = args.dynamic_input_size
save_dynamic = args.save_dynamic
print(f"Dynamic conversion utility is turned {'ON' if dynamic else 'OFF'}")
print(f"Combining IR utility is turned {'ON' if combine else 'OFF'}")
if dynamic and not combine:
save_dynamic = True
first_ir = None
first_dynamic_ir_name = None
second_ir = None
second_dynamic_ir_name = None
if first_ir_path:
first_dynamic_ir_name = f"{Path(first_ir_path).stem}_dynamic"
with open(first_ir_path, "r") as f:
first_ir = f.read()
if second_ir_path:
second_dynamic_ir_name = f"{Path(second_ir_path).stem}_dynamic"
with open(second_ir_path, "r") as f:
second_ir = f.read()
if dynamic:
first_ir = (
write_in_dynamic_inputs0(first_ir, dynamic_input_size)
if first_ir
else None
)
second_ir = (
write_in_dynamic_inputs1(second_ir, model_name, precision)
if second_ir
else None
)
if save_dynamic:
save_dynamic_ir(first_ir, f"{first_dynamic_ir_name}.mlir")
save_dynamic_ir(second_ir, f"{second_dynamic_ir_name}.mlir")
if combine:
combine_mlir_scripts(
first_ir,
second_ir,
f"{model_name}_{precision}.mlir",
return_ir=False,
)

@@ -4,9 +4,12 @@ import re
import gc
from io import BytesIO
from pathlib import Path
from statistics import mean, stdev
from tqdm import tqdm
from typing import List, Tuple
import subprocess
import sys
import time
import torch
import torch_mlir
@@ -41,12 +44,18 @@ from apps.language_models.src.model_wrappers.vicuna_model import (
SecondVicuna13B,
SecondVicuna70B,
)
from apps.language_models.src.model_wrappers.vicuna_model_gpu import (
FirstVicunaGPU,
SecondVicuna7BGPU,
SecondVicuna13BGPU,
SecondVicuna70BGPU,
)
from apps.language_models.utils import (
get_vmfb_from_path,
)
from shark.shark_downloader import download_public_file
from shark.shark_importer import get_f16_inputs
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
from shark.shark_inference import SharkInference
@@ -101,7 +110,7 @@ parser.add_argument(
"--download_vmfb",
default=False,
action=argparse.BooleanOptionalAction,
help="download vmfb from sharktank, system dependent, YMMV",
help="Download vmfb from sharktank, system dependent, YMMV",
)
parser.add_argument(
"--model_name",
@@ -129,6 +138,38 @@ parser.add_argument(
help="Specify target triple for vulkan.",
)
# Microbenchmarking options.
parser.add_argument(
"--enable_microbenchmark",
default=False,
action=argparse.BooleanOptionalAction,
help="Enables the microbenchmarking mode (non-interactive). Uses the system and the user prompt from args.",
)
parser.add_argument(
"--microbenchmark_iterations",
type=int,
default=5,
help="Number of microbenchmark iterations. Default: 5.",
)
parser.add_argument(
"--microbenchmark_num_tokens",
type=int,
default=512,
help="Generate an exact number of output tokens. Default: 512.",
)
parser.add_argument(
"--system_prompt",
type=str,
default="",
help="Specify the system prompt. This is only used with `--enable_microbenchmark`",
)
parser.add_argument(
"--user_prompt",
type=str,
default="Hi",
help="Specify the user prompt. This is only used with `--enable_microbenchmark`",
)
# fmt: off
def quantmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
if len(lhs) == 3 and len(rhs) == 2:
@@ -381,8 +422,7 @@ class VicunaBase(SharkLLMBase):
if sharded:
output = self.shark_model.forward(input_ids, is_first=is_first)
else:
output = self.shark_model("first_vicuna_forward", (input_ids,))
out_tensor = torch.tensor(output[1:])
output = self.shark_model("first_vicuna_forward", (input_ids,), send_to_host=False)
else:
token = params["token"]
@@ -398,28 +438,32 @@ class VicunaBase(SharkLLMBase):
is_first=is_first,
)
else:
token = token.to(torch.int64).reshape([1, 1])
token = torch.tensor(token).reshape([1, 1])
second_input = (token,) + tuple(past_key_values)
output = self.shark_model(
"second_vicuna_forward", second_input
"second_vicuna_forward", second_input, send_to_host=False
)
if sharded:
_logits = output["logits"]
_past_key_values = output["past_key_values"]
_token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
elif "cpu" in self.device:
_past_key_values = output[1:]
_token = int(output[0].to_host())
else:
_logits = torch.tensor(output[0])
_past_key_values = torch.tensor(output[1:])
_logits = torch.tensor(output[0].to_host())
_past_key_values = output[1:]
_token = torch.argmax(_logits[:, -1, :], dim=1)
_detok = self.tokenizer.decode(_token, skip_special_tokens=False)
ret_dict = {
"token": _token,
"detok": _detok,
"logits": _logits,
"past_key_values": _past_key_values,
}
if "cpu" not in self.device:
ret_dict["logits"] = _logits
if cli:
print(f" token : {_token} | detok : {_detok}")
@@ -640,9 +684,7 @@ class ShardedVicuna(VicunaBase):
mlir_path = Path(f"lmhead.mlir")
vmfb_path = Path(f"lmhead.vmfb")
if mlir_path.exists():
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
print(f"Found bytecode module at {mlir_path}.")
else:
hidden_states = torch_mlir.TensorPlaceholder.like(
hidden_states, dynamic_axes=[1]
@@ -667,12 +709,10 @@ class ShardedVicuna(VicunaBase):
filepath.absolute(),
single_file=True,
)
f_ = open(f"lmhead.mlir", "rb")
bytecode = f_.read()
f_.close()
mlir_path = filepath
shark_module = SharkInference(
bytecode,
mlir_path,
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
@@ -692,9 +732,7 @@ class ShardedVicuna(VicunaBase):
mlir_path = Path(f"norm.mlir")
vmfb_path = Path(f"norm.vmfb")
if mlir_path.exists():
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
print(f"Found bytecode module at {mlir_path}.")
else:
hidden_states = torch_mlir.TensorPlaceholder.like(
hidden_states, dynamic_axes=[1]
@@ -713,12 +751,10 @@ class ShardedVicuna(VicunaBase):
filepath.absolute(),
single_file=True,
)
f_ = open(f"norm.mlir", "rb")
bytecode = f_.read()
f_.close()
mlir_path = filepath
shark_module = SharkInference(
bytecode,
mlir_path,
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
@@ -738,9 +774,7 @@ class ShardedVicuna(VicunaBase):
mlir_path = Path(f"embedding.mlir")
vmfb_path = Path(f"embedding.vmfb")
if mlir_path.exists():
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
print(f"Found bytecode module at {mlir_path}.")
else:
input_ids = torch_mlir.TensorPlaceholder.like(
input_ids, dynamic_axes=[1]
@@ -764,12 +798,10 @@ class ShardedVicuna(VicunaBase):
filepath.absolute(),
single_file=True,
)
f_ = open(f"embedding.mlir", "rb")
bytecode = f_.read()
f_.close()
mlir_path = filepath
shark_module = SharkInference(
bytecode,
mlir_path,
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
@@ -1219,7 +1251,9 @@ class UnshardedVicuna(VicunaBase):
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
hf_auth_token: str = None,
max_num_tokens=512,
min_num_tokens=0,
device="cpu",
vulkan_target_triple="",
precision="int8",
vicuna_mlir_path=None,
vicuna_vmfb_path=None,
@@ -1229,6 +1263,7 @@ class UnshardedVicuna(VicunaBase):
download_vmfb=False,
cache_vicunas=False,
extra_args_cmd=[],
device_id=None,
debug=False,
) -> None:
super().__init__(
@@ -1237,10 +1272,6 @@ class UnshardedVicuna(VicunaBase):
max_num_tokens,
extra_args_cmd=extra_args_cmd,
)
if "llama2" in self.model_name and hf_auth_token == None:
raise ValueError(
"HF auth token required. Pass it using --hf_auth_token flag."
)
self.hf_auth_token = hf_auth_token
if self.model_name == "llama2_7b":
self.hf_model_path = "meta-llama/Llama-2-7b-chat-hf"
@@ -1250,7 +1281,10 @@ class UnshardedVicuna(VicunaBase):
self.hf_model_path = "meta-llama/Llama-2-70b-chat-hf"
print(f"[DEBUG] hf model name: {self.hf_model_path}")
self.max_sequence_length = 256
self.min_num_tokens = min_num_tokens
self.device = device
self.vulkan_target_triple = vulkan_target_triple
self.device_id = device_id
self.precision = precision
self.download_vmfb = download_vmfb
self.vicuna_vmfb_path = vicuna_vmfb_path
@@ -1271,17 +1305,32 @@ class UnshardedVicuna(VicunaBase):
safe_device = self.device.split("-")[0]
if suffix in ["mlirbc", "mlir"]:
return Path(f"{self.model_name}_{self.precision}.{suffix}")
target_triple = ""
if self.vulkan_target_triple != "":
target_triple = "_"
target_triple += "_".join(self.vulkan_target_triple.split("-")[:-1])
return Path(
f"{self.model_name}_{self.precision}_{safe_device}.{suffix}"
f"{self.model_name}_{self.precision}_{safe_device}{target_triple}.{suffix}"
)
def get_tokenizer(self):
kwargs = {"use_auth_token": self.hf_auth_token}
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path,
use_fast=False,
**kwargs,
)
local_tokenizer_path = Path(Path.cwd(), "llama2_tokenizer_configs")
local_tokenizer_path.mkdir(parents=True, exist_ok=True)
tokenizer_files_to_download = [
"config.json",
"special_tokens_map.json",
"tokenizer.model",
"tokenizer_config.json",
]
for tokenizer_file in tokenizer_files_to_download:
download_public_file(
f"gs://shark_tank/llama2_tokenizer/{tokenizer_file}",
Path(local_tokenizer_path, tokenizer_file),
single_file=True,
)
tokenizer = AutoTokenizer.from_pretrained(str(local_tokenizer_path))
return tokenizer
def get_src_model(self):
@@ -1404,16 +1453,18 @@ class UnshardedVicuna(VicunaBase):
single_file=True,
)
self.shark_model = get_vmfb_from_path(
self.vicuna_vmfb_path, self.device, "tm_tensor"
self.vicuna_vmfb_path, self.device, "tm_tensor", self.device_id
)
if self.shark_model is not None:
print(f"[DEBUG] vmfb found at {self.vicuna_vmfb_path.absolute()}")
return
print(f"[DEBUG] vmfb not found")
print(f"[DEBUG] vmfb not found (search path: {self.vicuna_vmfb_path})")
mlir_generated = False
for suffix in ["mlirbc", "mlir"]:
self.vicuna_mlir_path = self.get_model_path(suffix)
if "cpu" in self.device and "llama2_7b" in self.vicuna_mlir_path.name:
self.vicuna_mlir_path = Path("llama2_7b_int4_f32.mlir")
if not self.vicuna_mlir_path.exists() and self.load_mlir_from_shark_tank:
print(
f"Looking into gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}"
@@ -1425,16 +1476,12 @@ class UnshardedVicuna(VicunaBase):
)
if self.vicuna_mlir_path.exists():
print(f"[DEBUG] mlir found at {self.vicuna_mlir_path.absolute()}")
with open(self.vicuna_mlir_path, "rb") as f:
combined_module = f.read()
combined_module = self.vicuna_mlir_path.absolute()
mlir_generated = True
break
if not mlir_generated:
print(f"[DEBUG] mlir not found")
# Disabling this path of IR generation for now as it is broken.
print("Please check if the mlir file is present at the shark tank. Exiting.")
return
print("[DEBUG] generating mlir on device")
# Select a compilation prompt such that the resulting input_ids
@@ -1456,13 +1503,24 @@ class UnshardedVicuna(VicunaBase):
compilation_input_ids
).reshape([1, 19])
firstVicunaCompileInput = (compilation_input_ids,)
model = FirstVicuna(
self.hf_model_path,
self.precision,
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
if "cpu" in self.device:
model = FirstVicuna(
self.hf_model_path,
self.precision,
"fp32" if self.device=="cpu" else "fp16",
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
else:
model = FirstVicunaGPU(
self.hf_model_path,
self.precision,
"fp32" if self.device=="cpu" else "fp16",
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
print(f"[DEBUG] generating torchscript graph")
is_f16 = self.precision in ["fp16", "int4"]
ts_graph = import_with_fx(
@@ -1494,6 +1552,9 @@ class UnshardedVicuna(VicunaBase):
use_tracing=False,
verbose=False,
)
if self.cache_vicunas:
with open(first_model_path[:-5]+"_torch.mlir", "w+") as f:
f.write(str(first_module))
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
first_module,
@@ -1548,30 +1609,62 @@ class UnshardedVicuna(VicunaBase):
for _ in range(total_tuple)
)
secondVicunaCompileInput = (compilation_input_ids,) + pkv
if self.model_name == "llama2_13b":
model = SecondVicuna13B(
self.hf_model_path,
self.precision,
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
elif self.model_name == "llama2_70b":
model = SecondVicuna70B(
self.hf_model_path,
self.precision,
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
if "cpu" in self.device:
if self.model_name == "llama2_13b":
model = SecondVicuna13B(
self.hf_model_path,
self.precision,
"fp32",
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
elif self.model_name == "llama2_70b":
model = SecondVicuna70B(
self.hf_model_path,
self.precision,
"fp32",
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
else:
model = SecondVicuna7B(
self.hf_model_path,
self.precision,
"fp32",
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
else:
model = SecondVicuna7B(
self.hf_model_path,
self.precision,
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
if self.model_name == "llama2_13b":
model = SecondVicuna13BGPU(
self.hf_model_path,
self.precision,
"fp16",
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
elif self.model_name == "llama2_70b":
model = SecondVicuna70BGPU(
self.hf_model_path,
self.precision,
"fp16",
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
else:
model = SecondVicuna7BGPU(
self.hf_model_path,
self.precision,
"fp16",
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
print(f"[DEBUG] generating torchscript graph")
is_f16 = self.precision in ["fp16", "int4"]
ts_graph = import_with_fx(
@@ -1608,6 +1701,9 @@ class UnshardedVicuna(VicunaBase):
verbose=False,
)
print(f"[DEBUG] converting torch to linalg")
if self.cache_vicunas:
with open(second_model_path[:-5]+"_torch.mlir", "w+") as f:
f.write(str(second_module))
run_pipeline_with_repro_report(
second_module,
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
@@ -1641,6 +1737,12 @@ class UnshardedVicuna(VicunaBase):
second_module,
self.vicuna_mlir_path,
)
combined_module = save_mlir(
combined_module,
model_name="combined_llama",
mlir_dialect="tm_tensor",
dir=self.vicuna_mlir_path,
)
del first_module, second_module
print(self.device)
@@ -1650,6 +1752,7 @@ class UnshardedVicuna(VicunaBase):
mlir_module=combined_module,
device=self.device,
mlir_dialect="tm_tensor",
device_idx=self.device_id
)
path = shark_module.save_module(
self.vicuna_vmfb_path.parent.absolute(),
@@ -1683,39 +1786,46 @@ class UnshardedVicuna(VicunaBase):
res_tokens = []
params = {"prompt": prompt, "is_first": True, "fv": self.shark_model}
prefill_st_time = time.time()
generated_token_op = self.generate_new_token(
params=params, sharded=False, cli=cli
)
prefill_time = time.time() - prefill_st_time
token = generated_token_op["token"]
logits = generated_token_op["logits"]
if "cpu" not in self.device:
logits = generated_token_op["logits"]
pkv = generated_token_op["past_key_values"]
detok = generated_token_op["detok"]
yield detok, ""
yield detok, None, prefill_time
res_tokens.append(token)
if cli:
print(f"Assistant: {detok}", end=" ", flush=True)
for _ in range(self.max_num_tokens - 2):
for idx in range(self.max_num_tokens):
params = {
"token": token,
"is_first": False,
"logits": logits,
"past_key_values": pkv,
"sv": self.shark_model,
}
if "cpu" not in self.device:
params["logits"] = logits
decode_st_time = time.time()
generated_token_op = self.generate_new_token(
params=params, sharded=False, cli=cli
)
decode_time_ms = (time.time() - decode_st_time)*1000
token = generated_token_op["token"]
logits = generated_token_op["logits"]
if "cpu" not in self.device:
logits = generated_token_op["logits"]
pkv = generated_token_op["past_key_values"]
detok = generated_token_op["detok"]
if token == 2:
if token == 2 and idx >= self.min_num_tokens:
break
res_tokens.append(token)
if detok == "<0x0A>":
@@ -1724,10 +1834,10 @@ class UnshardedVicuna(VicunaBase):
else:
if cli:
print(f"{detok}", end=" ", flush=True)
yield detok, ""
yield detok, None, decode_time_ms
res_str = self.decode_tokens(res_tokens)
yield res_str, "formatted"
yield res_str, "formatted", None
def autocomplete(self, prompt):
# use First vic alone to complete a story / prompt / sentence.
@@ -1774,14 +1884,26 @@ start_message = {
def create_prompt(model_name, history):
global start_message
system_message = start_message[model_name]
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
msg = system_message + conversation
msg = msg.strip()
if "llama2" in model_name:
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
conversation = "".join(
[
f"{B_INST} {item[0].strip()} {E_INST} {item[1].strip()} "
for item in history[1:]
]
)
msg = f"{B_INST} {B_SYS} {system_message} {E_SYS} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
else:
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
msg = system_message + conversation
msg = msg.strip()
return msg
@@ -1789,11 +1911,37 @@ if __name__ == "__main__":
args, unknown = parser.parse_known_args()
_extra_args = []
# vulkan target triple
if args.iree_vulkan_target_triple != "":
device_id = None
# Process vulkan target triple.
# TODO: This feature should just be in a common utils for other LLMs and in general
# any model run via SHARK for Vulkan backend.
vulkan_target_triple = args.iree_vulkan_target_triple
if vulkan_target_triple != "":
_extra_args.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Step 1. Fetch the device ID.
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
get_vulkan_target_triple
)
vulkaninfo_list = get_all_vulkan_devices()
id = 0
for device in vulkaninfo_list:
target_triple = get_vulkan_target_triple(vulkaninfo_list[id])
if target_triple == vulkan_target_triple:
device_id = id
break
id += 1
assert device_id, f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
# Step 2. Add a few flags targeting specific hardware.
if "rdna" in vulkan_target_triple:
flags_to_add = [
"--iree-spirv-index-bits=64",
]
_extra_args = _extra_args + flags_to_add
vic = None
if not args.sharded:
@@ -1807,9 +1955,16 @@ if __name__ == "__main__":
if args.vicuna_vmfb_path is None
else Path(args.vicuna_vmfb_path)
)
min_tokens = 0
max_tokens = 512
if args.enable_microbenchmark:
min_tokens = max_tokens = args.microbenchmark_num_tokens
vic = UnshardedVicuna(
model_name=args.model_name,
hf_auth_token=args.hf_auth_token,
max_num_tokens=max_tokens,
min_num_tokens=min_tokens,
device=args.device,
precision=args.precision,
vicuna_mlir_path=vic_mlir_path,
@@ -1819,6 +1974,7 @@ if __name__ == "__main__":
download_vmfb=args.download_vmfb,
cache_vicunas=args.cache_vicunas,
extra_args_cmd=_extra_args,
device_id=device_id
)
else:
if args.config is not None:
@@ -1835,17 +1991,6 @@ if __name__ == "__main__":
weight_group_size=args.weight_group_size,
extra_args_cmd=_extra_args,
)
if args.model_name == "vicuna":
system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
else:
system_message = """System: You are a helpful, respectful and honest assistant. Always answer "
as helpfully as possible, while being safe. Your answers should not
include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal
content. Please ensure that your responses are socially unbiased and positive
in nature. If a question does not make any sense, or is not factually coherent,
explain why instead of answering something not correct. If you don't know the
answer to a question, please don't share false information."""
prologue_prompt = "ASSISTANT:\n"
history = []
@@ -1855,12 +2000,55 @@ if __name__ == "__main__":
"llama2_13b": "llama2_13b=>meta-llama/Llama-2-13b-chat-hf",
"llama2_70b": "llama2_70b=>meta-llama/Llama-2-70b-chat-hf",
}
iteration = 0
prefill_times = []
avg_decode_speed = []
while True:
# TODO: Add break condition from user input
user_prompt = input("User: ")
history.append([user_prompt, ""])
prompt = create_prompt(args.model_name, history)
for text, msg in vic.generate(prompt, cli=True):
if "formatted" in msg:
print("Response:", text)
iteration += 1
if not args.enable_microbenchmark:
user_prompt = input("User prompt: ")
history.append([user_prompt, ""])
prompt = create_prompt(args.model_name, history)
else:
if iteration > args.microbenchmark_iterations:
break
user_prompt = args.user_prompt
prompt = args.system_prompt + user_prompt
history = [[user_prompt, ""]]
token_count = 0
total_time_ms = 0.001 # In order to avoid divide by zero error
prefill_time = 0
is_first = True
for text, msg, exec_time in vic.generate(prompt, cli=True):
if msg is None:
if is_first:
prefill_time = exec_time
is_first = False
else:
total_time_ms += exec_time
token_count += 1
elif "formatted" in msg:
history[-1][1] = text
tokens_per_sec = (token_count / total_time_ms) * 1000
prefill_times.append(prefill_time)
avg_decode_speed.append(tokens_per_sec)
print("\nResponse:", text.strip())
print(f"\nNum tokens: {token_count}")
print(f"Prefill: {prefill_time:.2f} seconds")
print(f"Decode: {tokens_per_sec:.2f} tokens/s")
else:
sys.exit(
"unexpected message from the vicuna generate call, exiting."
)
if args.enable_microbenchmark:
print("\n### Final Statistics ###")
print("Number of iterations:", iteration - 1)
print(f"Prefill: avg. {mean(prefill_times):.2f} s, stdev {stdev(prefill_times):.2f}")
print(f"Decode: avg. {mean(avg_decode_speed):.2f} tokens/s, stdev {stdev(avg_decode_speed):.2f}")

@@ -0,0 +1,147 @@
import torch
from typing import Optional, Tuple
class WordEmbeddingsLayer(torch.nn.Module):
def __init__(self, word_embedding_layer):
super().__init__()
self.model = word_embedding_layer
def forward(self, input_ids):
output = self.model.forward(input=input_ids)
return output
class CompiledWordEmbeddingsLayer(torch.nn.Module):
def __init__(self, compiled_word_embedding_layer):
super().__init__()
self.model = compiled_word_embedding_layer
def forward(self, input_ids):
input_ids = input_ids.detach().numpy()
new_input_ids = self.model("forward", input_ids)
new_input_ids = new_input_ids.reshape(
[1, new_input_ids.shape[0], new_input_ids.shape[1]]
)
return torch.tensor(new_input_ids)
class LNFEmbeddingLayer(torch.nn.Module):
def __init__(self, ln_f):
super().__init__()
self.model = ln_f
def forward(self, hidden_states):
output = self.model.forward(input=hidden_states)
return output
class CompiledLNFEmbeddingLayer(torch.nn.Module):
def __init__(self, ln_f):
super().__init__()
self.model = ln_f
def forward(self, hidden_states):
hidden_states = hidden_states.detach().numpy()
new_hidden_states = self.model("forward", (hidden_states,))
return torch.tensor(new_hidden_states)
class LMHeadEmbeddingLayer(torch.nn.Module):
def __init__(self, embedding_layer):
super().__init__()
self.model = embedding_layer
def forward(self, hidden_states):
output = self.model.forward(input=hidden_states)
return output
class CompiledLMHeadEmbeddingLayer(torch.nn.Module):
def __init__(self, lm_head):
super().__init__()
self.model = lm_head
def forward(self, hidden_states):
hidden_states = hidden_states.detach().numpy()
new_hidden_states = self.model("forward", (hidden_states,))
return torch.tensor(new_hidden_states)
class DecoderLayer(torch.nn.Module):
def __init__(self, decoder_layer_model):
super().__init__()
self.model = decoder_layer_model
def forward(self, hidden_states, attention_mask):
output = self.model.forward(
hidden_states=hidden_states,
alibi=None,
attention_mask=attention_mask,
use_cache=True,
)
return (output[0], output[1][0], output[1][1])
class CompiledDecoderLayer(torch.nn.Module):
def __init__(self, shark_decoder_layer_module):
super().__init__()
self.model = shark_decoder_layer_module
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
alibi: torch.Tensor = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
hidden_states = hidden_states.to(torch.float32).detach().numpy()
attention_mask = attention_mask.detach().numpy()
if alibi is not None or layer_past is not None:
raise ValueError("Past Key Values and alibi should be None")
else:
new_hidden_states, pkv1, pkv2 = self.model(
"forward",
(
hidden_states,
attention_mask,
),
)
return tuple(
[
torch.tensor(new_hidden_states),
tuple(
[
torch.tensor(pkv1),
torch.tensor(pkv2),
]
),
]
)
class ShardedFalconModel:
def __init__(self, model, layers, word_embeddings, ln_f, lm_head):
super().__init__()
self.model = model
self.model.transformer.h = torch.nn.modules.container.ModuleList(
layers
)
self.model.transformer.word_embeddings = word_embeddings
self.model.transformer.ln_f = ln_f
self.model.lm_head = lm_head
def forward(
self,
input_ids,
attention_mask=None,
):
return self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
).logits[:, -1, :]

@@ -54,7 +54,6 @@ from apps.language_models.utils import (
)
from shark.shark_downloader import download_public_file
from shark.shark_importer import get_f16_inputs
from shark.shark_importer import import_with_fx
from shark.shark_inference import SharkInference
from transformers.models.llama.configuration_llama import LlamaConfig

@@ -7,6 +7,7 @@ class FirstVicuna(torch.nn.Module):
self,
model_path,
precision="fp32",
accumulates="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
@@ -15,6 +16,9 @@ class FirstVicuna(torch.nn.Module):
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.accumulates = (
torch.float32 if accumulates == "fp32" else torch.float16
)
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
@@ -29,7 +33,7 @@ class FirstVicuna(torch.nn.Module):
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float16 if precision == "int4" else torch.float32,
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
@@ -43,7 +47,9 @@ class FirstVicuna(torch.nn.Module):
def forward(self, input_ids):
op = self.model(input_ids=input_ids, use_cache=True)
return_vals = []
return_vals.append(op.logits)
token = torch.argmax(op.logits[:, -1, :], dim=1)
return_vals.append(token)
temp_past_key_values = op.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0])
@@ -56,6 +62,7 @@ class SecondVicuna7B(torch.nn.Module):
self,
model_path,
precision="fp32",
accumulates="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
@@ -67,6 +74,9 @@ class SecondVicuna7B(torch.nn.Module):
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
self.accumulates = (
torch.float32 if accumulates == "fp32" else torch.float16
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
@@ -78,7 +88,7 @@ class SecondVicuna7B(torch.nn.Module):
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float16 if precision == "int4" else torch.float32,
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
@@ -289,7 +299,8 @@ class SecondVicuna7B(torch.nn.Module):
input_ids=token, use_cache=True, past_key_values=past_key_values
)
return_vals = []
return_vals.append(op.logits)
token = torch.argmax(op.logits[:, -1, :], dim=1)
return_vals.append(token)
temp_past_key_values = op.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0])
@@ -302,6 +313,7 @@ class SecondVicuna13B(torch.nn.Module):
self,
model_path,
precision="int8",
accumulates="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
@@ -313,6 +325,9 @@ class SecondVicuna13B(torch.nn.Module):
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
self.accumulates = (
torch.float32 if accumulates == "fp32" else torch.float16
)
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import (
@@ -323,7 +338,7 @@ class SecondVicuna13B(torch.nn.Module):
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float16 if precision == "int4" else torch.float32,
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
@@ -595,6 +610,7 @@ class SecondVicuna70B(torch.nn.Module):
self,
model_path,
precision="fp32",
accumulates="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
@@ -606,6 +622,9 @@ class SecondVicuna70B(torch.nn.Module):
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
self.accumulates = (
torch.float32 if accumulates == "fp32" else torch.float16
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
@@ -617,7 +636,7 @@ class SecondVicuna70B(torch.nn.Module):
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float16,
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",

File diff suppressed because it is too large.

@@ -1,4 +1,15 @@
from apps.language_models.src.model_wrappers.falcon_model import FalconModel
from apps.language_models.src.model_wrappers.falcon_sharded_model import (
WordEmbeddingsLayer,
CompiledWordEmbeddingsLayer,
LNFEmbeddingLayer,
CompiledLNFEmbeddingLayer,
LMHeadEmbeddingLayer,
CompiledLMHeadEmbeddingLayer,
DecoderLayer,
CompiledDecoderLayer,
ShardedFalconModel,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.utils import (
get_vmfb_from_path,
@@ -7,9 +18,9 @@ from io import BytesIO
from pathlib import Path
from contextlib import redirect_stdout
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM, GPTQConfig
from transformers.generation import (
GenerationConfig,
LogitsProcessorList,
@@ -28,9 +39,11 @@ parser = argparse.ArgumentParser(
description="runs a falcon model",
)
parser.add_argument("--falcon_variant_to_use", default="7b", help="7b, 40b")
parser.add_argument(
"--precision", "-p", default="fp16", help="fp32, fp16, int8, int4"
"--falcon_variant_to_use", default="7b", help="7b, 40b, 180b"
)
parser.add_argument(
"--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"]
)
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
parser.add_argument(
@@ -49,7 +62,7 @@ parser.add_argument(
)
parser.add_argument(
"--load_mlir_from_shark_tank",
default=False,
default=True,
action=argparse.BooleanOptionalAction,
help="download precompile mlir from shark tank",
)
@@ -59,13 +72,27 @@ parser.add_argument(
action=argparse.BooleanOptionalAction,
help="Run model in cli mode",
)
parser.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication token for falcon-180B model.",
)
parser.add_argument(
"-s",
"--sharded",
default=False,
action=argparse.BooleanOptionalAction,
help="Run model as sharded",
)
class Falcon(SharkLLMBase):
class ShardedFalcon(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path,
hf_model_path="tiiuae/falcon-7b-instruct",
hf_auth_token: str = None,
max_num_tokens=150,
device="cuda",
precision="fp32",
@@ -74,6 +101,24 @@ class Falcon(SharkLLMBase):
debug=False,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
print("hf_model_path: ", self.hf_model_path)
if "40b" in self.model_name:
raise NotImplementedError(
"Sharded Falcon not supported for 40b variant"
)
if (
"180b" in self.model_name
and precision != "int4"
and hf_auth_token == None
):
raise ValueError(
""" HF auth token required for falcon-180b. Pass it using
--hf_auth_token flag. You can ask for the access to the model
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
)
self.hf_auth_token = hf_auth_token
self.max_padding_length = 100
self.device = device
self.precision = precision
@@ -81,12 +126,14 @@ class Falcon(SharkLLMBase):
self.falcon_mlir_path = falcon_mlir_path
self.debug = debug
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
self.src_model = self.get_src_model()
self.shark_model = self.compile()
def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path, trust_remote_code=True
self.hf_model_path,
trust_remote_code=True,
token=self.hf_auth_token,
)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = 11
@@ -94,13 +141,484 @@ class Falcon(SharkLLMBase):
def get_src_model(self):
print("Loading src model: ", self.model_name)
kwargs = {"torch_dtype": torch.float, "trust_remote_code": True}
kwargs = {
"torch_dtype": torch.float,
"trust_remote_code": True,
"token": self.hf_auth_token,
}
if self.precision == "int4":
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
kwargs["quantization_config"] = quantization_config
kwargs["load_gptq_on_cpu"] = True
kwargs["device_map"] = "cpu" if self.device == "cpu" else "cuda:0"
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
if self.precision == "int4":
falcon_model = falcon_model.to(torch.float32)
return falcon_model
def compile_falcon(self):
def compile_layer(self, layer, falconCompileInput, layer_id):
self.falcon_mlir_path = Path(
f"falcon_{args.falcon_variant_to_use}_layer_{layer_id}_{self.precision}.mlir"
)
self.falcon_vmfb_path = Path(
f"falcon_{args.falcon_variant_to_use}_layer_{layer_id}_{self.precision}_{self.device}.vmfb"
)
if args.use_precompiled_model:
if not self.falcon_vmfb_path.exists():
# Downloading VMFB from shark_tank
print(f"[DEBUG] Trying to download vmfb from shark_tank")
download_public_file(
f"gs://shark_tank/falcon/sharded/falcon_{args.falcon_variant_to_use}/vmfb/"
+ str(self.falcon_vmfb_path),
self.falcon_vmfb_path.absolute(),
single_file=True,
)
vmfb = get_vmfb_from_path(
self.falcon_vmfb_path, self.device, "linalg"
)
if vmfb is not None:
return vmfb
print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
if self.falcon_mlir_path.exists():
print(f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}")
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
else:
mlir_generated = False
print(
f"[DEBUG] mlir not found at {self.falcon_mlir_path.absolute()}"
)
if args.load_mlir_from_shark_tank:
# Downloading MLIR from shark_tank
print(f"[DEBUG] Trying to download mlir from shark_tank")
download_public_file(
f"gs://shark_tank/falcon/sharded/falcon_{args.falcon_variant_to_use}/mlir/"
+ str(self.falcon_mlir_path),
self.falcon_mlir_path.absolute(),
single_file=True,
)
if self.falcon_mlir_path.exists():
print(
f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}"
)
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
if not mlir_generated:
print(f"[DEBUG] generating MLIR locally")
if layer_id == "word_embeddings":
f16_input_mask = [False]
elif layer_id in ["ln_f", "lm_head"]:
f16_input_mask = [True]
elif type(layer_id) == int:
f16_input_mask = [True, False]
else:
raise ValueError("Unsupported layer: ", layer_id)
print(f"[DEBUG] generating torchscript graph")
ts_graph = import_with_fx(
layer,
falconCompileInput,
is_f16=True,
f16_input_mask=f16_input_mask,
mlir_type="torchscript",
is_gptq=True,
)
del layer
print(f"[DEBUG] generating torch mlir")
module = torch_mlir.compile(
ts_graph,
falconCompileInput,
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph
print(f"[DEBUG] converting to bytecode")
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
del module
f_ = open(self.falcon_mlir_path, "wb")
f_.write(bytecode)
print("Saved falcon mlir at ", str(self.falcon_mlir_path))
f_.close()
del bytecode
shark_module = SharkInference(
mlir_module=self.falcon_mlir_path,
device=self.device,
mlir_dialect="linalg",
)
path = shark_module.save_module(
self.falcon_vmfb_path.parent.absolute(),
self.falcon_vmfb_path.stem,
extra_args=[
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
]
+ [
"--iree-llvmcpu-use-fast-min-max-ops",
]
if self.precision == "int4"
else [],
debug=self.debug,
)
print("Saved falcon vmfb at ", str(path))
shark_module.load_module(path)
return shark_module
def compile(self):
sample_input_ids = torch.zeros([100], dtype=torch.int64)
sample_attention_mask = torch.zeros(
[1, 1, 100, 100], dtype=torch.float32
)
if "7b" in self.model_name:
num_in_features = 4544
else:
num_in_features = 14848
sample_attention_mask = sample_attention_mask.to(dtype=torch.bool)
sample_hidden_states = torch.zeros(
[1, 100, num_in_features], dtype=torch.float32
)
lm_head = LMHeadEmbeddingLayer(self.src_model.lm_head)
print("Compiling Layer lm_head")
shark_lm_head = self.compile_layer(
lm_head, [sample_hidden_states], "lm_head"
)
shark_lm_head = CompiledLMHeadEmbeddingLayer(shark_lm_head)
word_embedding = WordEmbeddingsLayer(
self.src_model.transformer.word_embeddings
)
print("Compiling Layer word_embeddings")
shark_word_embedding = self.compile_layer(
word_embedding, [sample_input_ids], "word_embeddings"
)
shark_word_embedding = CompiledWordEmbeddingsLayer(
shark_word_embedding
)
ln_f = LNFEmbeddingLayer(self.src_model.transformer.ln_f)
print("Compiling Layer ln_f")
shark_ln_f = self.compile_layer(ln_f, [sample_hidden_states], "ln_f")
shark_ln_f = CompiledLNFEmbeddingLayer(shark_ln_f)
shark_layers = []
for i in range(len(self.src_model.transformer.h)):
print("Compiling Layer {}".format(i))
layer_i = self.src_model.transformer.h[i]
pytorch_layer_i = DecoderLayer(layer_i)
shark_module = self.compile_layer(
pytorch_layer_i,
[sample_hidden_states, sample_attention_mask],
i,
)
shark_layer_i = CompiledDecoderLayer(shark_module)
shark_layers.append(shark_layer_i)
sharded_model = ShardedFalconModel(
self.src_model,
shark_layers,
shark_word_embedding,
shark_ln_f,
shark_lm_head,
)
return sharded_model
def generate(self, prompt):
model_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.max_padding_length,
add_special_tokens=False,
return_tensors="pt",
)
model_inputs["prompt_text"] = prompt
input_ids = model_inputs["input_ids"]
attention_mask = model_inputs.get("attention_mask", None)
# Allow empty prompts
if input_ids.shape[1] == 0:
input_ids = None
attention_mask = None
in_b = 1
else:
in_b = input_ids.shape[0]
generate_kwargs = {
"max_length": self.max_num_tokens,
"do_sample": True,
"top_k": 10,
"num_return_sequences": 1,
"eos_token_id": 11,
}
generate_kwargs["input_ids"] = input_ids
generate_kwargs["attention_mask"] = attention_mask
generation_config_ = GenerationConfig.from_model_config(
self.src_model.config
)
generation_config = copy.deepcopy(generation_config_)
model_kwargs = generation_config.update(**generate_kwargs)
logits_processor = LogitsProcessorList()
stopping_criteria = StoppingCriteriaList()
eos_token_id = generation_config.eos_token_id
generation_config.pad_token_id = eos_token_id
(
inputs_tensor,
model_input_name,
model_kwargs,
) = self.src_model._prepare_model_inputs(
None, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs[
"output_hidden_states"
] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache
input_ids = (
inputs_tensor
if model_input_name == "input_ids"
else model_kwargs.pop("input_ids")
)
self.logits_processor = self.src_model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids.shape[-1],
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
)
self.stopping_criteria = self.src_model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria,
)
self.logits_warper = self.src_model._get_logits_warper(
generation_config
)
(
self.input_ids,
self.model_kwargs,
) = self.src_model._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_return_sequences, # 1
is_encoder_decoder=self.src_model.config.is_encoder_decoder, # False
**model_kwargs,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id_tensor = (
torch.tensor(eos_token_id) if eos_token_id is not None else None
)
self.pad_token_id = generation_config.pad_token_id
self.eos_token_id = eos_token_id
output_scores = generation_config.output_scores # False
output_attentions = generation_config.output_attentions # False
output_hidden_states = generation_config.output_hidden_states # False
return_dict_in_generate = (
generation_config.return_dict_in_generate # False
)
# init attention / hidden states / scores tuples
self.scores = (
() if (return_dict_in_generate and output_scores) else None
)
decoder_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
cross_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
decoder_hidden_states = (
() if (return_dict_in_generate and output_hidden_states) else None
)
# keep track of which sequences are already finished
self.unfinished_sequences = torch.ones(
input_ids.shape[0], dtype=torch.long, device=input_ids.device
)
all_text = prompt
for i in range(self.max_num_tokens - 1):
next_token = self.generate_new_token()
new_word = self.tokenizer.decode(
next_token.cpu().numpy(),
add_special_tokens=False,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
all_text = all_text + new_word
print(f"{new_word}", end="", flush=True)
# if eos_token was found in one sentence, set sentence to finished
if self.eos_token_id_tensor is not None:
self.unfinished_sequences = self.unfinished_sequences.mul(
next_token.tile(self.eos_token_id_tensor.shape[0], 1)
.ne(self.eos_token_id_tensor.unsqueeze(1))
.prod(dim=0)
)
# stop when each sentence is finished
if (
self.unfinished_sequences.max() == 0
or self.stopping_criteria(input_ids, self.scores)
):
break
torch.cuda.empty_cache()
gc.collect()
return all_text
def generate_new_token(self):
model_inputs = self.src_model.prepare_inputs_for_generation(
self.input_ids, **self.model_kwargs
)
outputs = self.shark_model.forward(
input_ids=model_inputs["input_ids"],
attention_mask=model_inputs["attention_mask"],
)
if self.precision in ["fp16", "int4"]:
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
# pre-process distribution
next_token_scores = self.logits_processor(
self.input_ids, next_token_logits
)
next_token_scores = self.logits_warper(
self.input_ids, next_token_scores
)
# sample
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
if self.eos_token_id is not None:
if self.pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_token = (
next_token * self.unfinished_sequences
+ self.pad_token_id * (1 - self.unfinished_sequences)
)
self.input_ids = torch.cat(
[self.input_ids, next_token[:, None]], dim=-1
)
self.model_kwargs["past_key_values"] = None
if "attention_mask" in self.model_kwargs:
attention_mask = self.model_kwargs["attention_mask"]
self.model_kwargs["attention_mask"] = torch.cat(
[
attention_mask,
attention_mask.new_ones((attention_mask.shape[0], 1)),
],
dim=-1,
)
self.input_ids = self.input_ids[:, 1:]
self.model_kwargs["attention_mask"] = self.model_kwargs[
"attention_mask"
][:, 1:]
return next_token
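The token update above keeps input_ids and attention_mask at a fixed padded length: the new token is appended and the oldest position is dropped. A self-contained sketch of that sliding-window update on plain tensors (illustrative values only, no model involved):

import torch

input_ids = torch.tensor([[11, 11, 5, 9, 42]])        # padded prompt window (pad_token_id == 11)
attention_mask = torch.tensor([[0, 0, 1, 1, 1]])
next_token = torch.tensor([7])

# Append the sampled token, then drop the oldest position to keep a fixed window.
input_ids = torch.cat([input_ids, next_token[:, None]], dim=-1)[:, 1:]
attention_mask = torch.cat(
    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)[:, 1:]

print(input_ids)       # tensor([[11,  5,  9, 42,  7]])
print(attention_mask)  # tensor([[0, 1, 1, 1, 1]])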
class UnshardedFalcon(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path="tiiuae/falcon-7b-instruct",
hf_auth_token: str = "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",
max_num_tokens=150,
device="cuda",
precision="fp32",
falcon_mlir_path=None,
falcon_vmfb_path=None,
debug=False,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
print("hf_model_path: ", self.hf_model_path)
if "180b" in self.model_name and hf_auth_token == None:
raise ValueError(
""" HF auth token required for falcon-180b. Pass it using
--hf_auth_token flag. You can ask for the access to the model
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
)
self.hf_auth_token = hf_auth_token
self.max_padding_length = 100
self.device = device
self.precision = precision
self.falcon_vmfb_path = falcon_vmfb_path
self.falcon_mlir_path = falcon_mlir_path
self.debug = debug
self.tokenizer = self.get_tokenizer()
self.src_model = self.get_src_model()
self.shark_model = self.compile()
def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path,
trust_remote_code=True,
token=self.hf_auth_token,
)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = 11
return tokenizer
def get_src_model(self):
print("Loading src model: ", self.model_name)
kwargs = {
"torch_dtype": torch.float,
"trust_remote_code": True,
"token": self.hf_auth_token,
}
if self.precision == "int4":
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
kwargs["quantization_config"] = quantization_config
kwargs["load_gptq_on_cpu"] = True
kwargs["device_map"] = "cpu" if self.device == "cpu" else "cuda:0"
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
if self.precision == "int4":
falcon_model = falcon_model.to(torch.float32)
return falcon_model
def compile(self):
if args.use_precompiled_model:
if not self.falcon_vmfb_path.exists():
# Downloading VMFB from shark_tank
@@ -122,37 +640,37 @@ class Falcon(SharkLLMBase):
if vmfb is not None:
return vmfb
print(
f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}. Trying to work with"
f"[DEBUG] mlir path { self.falcon_mlir_path} {'exists' if self.falcon_mlir_path.exists() else 'does not exist'}"
)
print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
if self.falcon_mlir_path.exists():
print(f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}")
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
else:
mlir_generated = False
# Downloading MLIR from shark_tank
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ ".mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
print(
f"[DEBUG] mlir not found at {self.falcon_mlir_path.absolute()}"
)
if self.falcon_mlir_path.exists():
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.falcon_mlir_path.absolute()}"
" after downloading! Please check path and try again"
if args.load_mlir_from_shark_tank:
# Downloading MLIR from shark_tank
print(f"[DEBUG] Trying to download mlir from shark_tank")
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ ".mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
)
if self.falcon_mlir_path.exists():
print(
f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}"
)
mlir_generated = True
if not mlir_generated:
print(f"[DEBUG] generating MLIR locally")
compilation_input_ids = torch.randint(
low=1, high=10000, size=(1, 100)
)
@@ -169,9 +687,10 @@ class Falcon(SharkLLMBase):
ts_graph = import_with_fx(
model,
falconCompileInput,
is_f16=self.precision == "fp16",
is_f16=self.precision in ["fp16", "int4"],
f16_input_mask=[False, False],
mlir_type="torchscript",
is_gptq=self.precision == "int4",
)
del model
print(f"[DEBUG] generating torch mlir")
@@ -191,25 +710,30 @@ class Falcon(SharkLLMBase):
bytecode = bytecode_stream.getvalue()
del module
print(f"[DEBUG] writing mlir to file")
with open(f"{self.model_name}.mlir", "wb") as f_:
with redirect_stdout(f_):
print(module.operation.get_asm())
f_ = open(self.falcon_mlir_path, "wb")
f_.write(bytecode)
print("Saved falcon mlir at ", str(self.falcon_mlir_path))
f_.close()
del bytecode
shark_module = SharkInference(
mlir_module=bytecode, device=self.device, mlir_dialect="linalg"
mlir_module=self.falcon_mlir_path,
device=self.device,
mlir_dialect="linalg",
)
path = shark_module.save_module(
self.falcon_vmfb_path.parent.absolute(),
self.falcon_vmfb_path.stem,
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--iree-spirv-index-bits=64",
],
]
+ [
"--iree-llvmcpu-use-fast-min-max-ops",
]
if self.precision == "int4"
else [],
debug=self.debug,
)
print("Saved falcon vmfb at ", str(path))
@@ -217,10 +741,6 @@ class Falcon(SharkLLMBase):
return shark_module
def compile(self):
falcon_shark_model = self.compile_falcon()
return falcon_shark_model
def generate(self, prompt):
model_inputs = self.tokenizer(
prompt,
@@ -390,7 +910,7 @@ class Falcon(SharkLLMBase):
(model_inputs["input_ids"], model_inputs["attention_mask"]),
)
)
if self.precision == "fp16":
if self.precision in ["fp16", "int4"]:
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
@@ -469,16 +989,39 @@ if __name__ == "__main__":
else Path(args.falcon_vmfb_path)
)
falcon = Falcon(
"falcon_" + args.falcon_variant_to_use,
hf_model_path="tiiuae/falcon-"
+ args.falcon_variant_to_use
+ "-instruct",
device=args.device,
precision=args.precision,
falcon_mlir_path=falcon_mlir_path,
falcon_vmfb_path=falcon_vmfb_path,
)
if args.precision == "int4":
if args.falcon_variant_to_use == "180b":
hf_model_path_value = "TheBloke/Falcon-180B-Chat-GPTQ"
else:
hf_model_path_value = (
"TheBloke/falcon-"
+ args.falcon_variant_to_use
+ "-instruct-GPTQ"
)
else:
if args.falcon_variant_to_use == "180b":
hf_model_path_value = "tiiuae/falcon-180B-chat"
else:
hf_model_path_value = (
"tiiuae/falcon-" + args.falcon_variant_to_use + "-instruct"
)
if not args.sharded:
falcon = UnshardedFalcon(
model_name="falcon_" + args.falcon_variant_to_use,
hf_model_path=hf_model_path_value,
device=args.device,
precision=args.precision,
falcon_mlir_path=falcon_mlir_path,
falcon_vmfb_path=falcon_vmfb_path,
)
else:
falcon = ShardedFalcon(
model_name="falcon_" + args.falcon_variant_to_use,
hf_model_path=hf_model_path_value,
device=args.device,
precision=args.precision,
)
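The repo selection added above can be restated compactly (illustrative helper, not part of the diff): GPTQ repos are used for int4, the tiiuae instruct/chat repos otherwise.

def falcon_hf_repo(variant: str, precision: str) -> str:
    # Mirrors the branching in __main__ above.
    if precision == "int4":
        return ("TheBloke/Falcon-180B-Chat-GPTQ" if variant == "180b"
                else f"TheBloke/falcon-{variant}-instruct-GPTQ")
    return ("tiiuae/falcon-180B-chat" if variant == "180b"
            else f"tiiuae/falcon-{variant}-instruct")

assert falcon_hf_repo("7b", "int4") == "TheBloke/falcon-7b-instruct-GPTQ"
assert falcon_hf_repo("180b", "fp16") == "tiiuae/falcon-180B-chat"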
import gc
@@ -500,7 +1043,11 @@ if __name__ == "__main__":
prompt = input("Please enter the prompt text: ")
print("\nPrompt Text: ", prompt)
res_str = falcon.generate(prompt)
prompt_template = f"""A helpful assistant who helps the user with any questions asked.
User: {prompt}
Assistant:"""
res_str = falcon.generate(prompt_template)
torch.cuda.empty_cache()
gc.collect()
print(

View File

@@ -126,7 +126,7 @@ def is_url(input_url):
import os
import tempfile
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
import torch
import torch_mlir
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
@@ -235,6 +235,12 @@ def compile_int_precision(
mlir_module = BytesIO(mlir_module)
bytecode = mlir_module.read()
print(f"Elided IR written for {extended_model_name}")
bytecode = save_mlir(
bytecode,
model_name=extended_model_name,
frontend="torch",
dir=os.getcwd(),
)
return bytecode
shark_module = SharkInference(
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"

View File

@@ -8,7 +8,7 @@ from shark.shark_downloader import download_public_file
# expects a Path / str as arg
# returns None if path not found or SharkInference module
def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
def get_vmfb_from_path(vmfb_path, device, mlir_dialect, device_id=None):
if not isinstance(vmfb_path, Path):
vmfb_path = Path(vmfb_path)
@@ -20,7 +20,7 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
print("Loading vmfb from: ", vmfb_path)
print("Device from get_vmfb_from_path - ", device)
shark_module = SharkInference(
None, device=device, mlir_dialect=mlir_dialect
None, device=device, mlir_dialect=mlir_dialect, device_idx=device_id
)
shark_module.load_module(vmfb_path)
print("Successfully loaded vmfb")
@@ -28,7 +28,13 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
def get_vmfb_from_config(
shark_container, model, precision, device, vmfb_path, padding=None
shark_container,
model,
precision,
device,
vmfb_path,
padding=None,
device_id=None,
):
vmfb_url = (
f"gs://shark_tank/{shark_container}/{model}_{precision}_{device}"
@@ -37,4 +43,6 @@ def get_vmfb_from_config(
vmfb_url = vmfb_url + f"_{padding}"
vmfb_url = vmfb_url + ".vmfb"
download_public_file(vmfb_url, vmfb_path.absolute(), single_file=True)
return get_vmfb_from_path(vmfb_path, device, "tm_tensor")
return get_vmfb_from_path(
vmfb_path, device, "tm_tensor", device_id=device_id
)
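A short usage sketch of the extended helpers, which now thread a device index through to SharkInference (the file names and index below are placeholders):

from pathlib import Path
from apps.language_models.utils import get_vmfb_from_path, get_vmfb_from_config

# Load an already-compiled module onto the second device; returns None if the file is missing.
module = get_vmfb_from_path(
    Path("falcon_7b_fp16_vulkan.vmfb"), "vulkan", "linalg", device_id=1
)

# Or download a known artifact from shark_tank and load it the same way.
module = get_vmfb_from_config(
    "falcon", "falcon_7b", "fp16", "vulkan",
    Path("falcon_7b_fp16_vulkan.vmfb"), device_id=1,
)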

View File

@@ -15,8 +15,8 @@ pathex = [
# datafiles for pyinstaller
datas = []
datas += collect_data_files("torch")
datas += copy_metadata("torch")
datas += copy_metadata("tokenizers")
datas += copy_metadata("tqdm")
datas += copy_metadata("regex")
datas += copy_metadata("requests")
@@ -31,18 +31,17 @@ datas += copy_metadata("Pillow")
datas += copy_metadata("sentencepiece")
datas += copy_metadata("pyyaml")
datas += copy_metadata("huggingface-hub")
datas += collect_data_files("torch")
datas += collect_data_files("tokenizers")
datas += collect_data_files("tiktoken")
datas += collect_data_files("accelerate")
datas += collect_data_files("diffusers")
datas += collect_data_files("transformers")
datas += collect_data_files("pytorch_lightning")
datas += collect_data_files("opencv_python")
datas += collect_data_files("skimage")
datas += collect_data_files("gradio")
datas += collect_data_files("gradio_client")
datas += collect_data_files("iree")
datas += collect_data_files("google_cloud_storage")
datas += collect_data_files("shark", include_py_files=True)
datas += collect_data_files("timm", include_py_files=True)
datas += collect_data_files("tqdm")
@@ -53,6 +52,7 @@ datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += collect_data_files("langchain")
datas += collect_data_files("cv2")
datas += [
("src/utils/resources/prompts.json", "resources"),
("src/utils/resources/model_db.json", "resources"),
@@ -74,6 +74,9 @@ datas += [
# hidden imports for pyinstaller
hiddenimports = ["shark", "shark.shark_inference", "apps"]
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [
x for x in collect_submodules("diffusers") if "tests" not in x
]
blacklist = ["tests", "convert"]
hiddenimports += [
x
@@ -81,4 +84,4 @@ hiddenimports += [
if not any(kw in x for kw in blacklist)
]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
hiddenimports += ["iree._runtime", "iree._runtime_libs"]
hiddenimports += ["iree._runtime", "iree.compiler._mlir_libs._mlir.ir"]

View File

@@ -8,6 +8,7 @@ import traceback
import subprocess
import sys
import os
import requests
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_opt_flags,
@@ -16,6 +17,7 @@ from apps.stable_diffusion.src.utils import (
preprocessCKPT,
convert_original_vae,
get_path_to_diffusers_checkpoint,
get_civitai_checkpoint,
fetch_and_update_base_model_id,
get_path_stem,
get_extended_name,
@@ -94,21 +96,19 @@ class SharkifyStableDiffusionModel:
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
self.custom_weights = custom_weights
self.custom_weights = custom_weights.strip()
self.use_quantize = use_quantize
if custom_weights != "":
if "civitai" in custom_weights:
weights_id = custom_weights.split("/")[-1]
# TODO: use model name and identify file type by civitai rest api
weights_path = (
str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
)
if not os.path.isfile(weights_path):
subprocess.run(
["wget", custom_weights, "-O", weights_path]
)
if custom_weights.startswith("https://civitai.com/api/"):
# download the checkpoint from civitai if we don't already have it
weights_path = get_civitai_checkpoint(custom_weights)
# act as if we were given the local file as custom_weights originally
custom_weights = get_path_to_diffusers_checkpoint(weights_path)
self.custom_weights = weights_path
# needed to ensure webui sets the correct model name metadata
args.ckpt_loc = weights_path
else:
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
@@ -116,6 +116,7 @@ class SharkifyStableDiffusionModel:
custom_weights = get_path_to_diffusers_checkpoint(
custom_weights
)
self.model_id = model_id if custom_weights == "" else custom_weights
# TODO: remove the following line when stable-diffusion-2-1 works
if self.model_id == "stabilityai/stable-diffusion-2-1":
@@ -710,8 +711,11 @@ class SharkifyStableDiffusionModel:
return self.text_encoder(input)[0]
clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage)
save_dir = os.path.join(self.sharktank_dir, self.model_name["clip"])
save_dir = ""
if self.debug:
save_dir = os.path.join(
self.sharktank_dir, self.model_name["clip"]
)
os.makedirs(
save_dir,
exist_ok=True,

View File

@@ -273,6 +273,7 @@ class StencilPipeline(StableDiffusionPipeline):
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
resample_type,
):
# Control Embedding check & conversion
# TODO: 1. Change `num_images_per_prompt`.

View File

@@ -84,9 +84,6 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
def _import(self):
scaling_model = ScalingModel()

View File

@@ -41,3 +41,4 @@ from apps.stable_diffusion.src.utils.utils import (
resize_stencil,
_compile_module,
)
from apps.stable_diffusion.src.utils.civitai import get_civitai_checkpoint

View File

@@ -0,0 +1,42 @@
import re
import requests
from apps.stable_diffusion.src.utils.stable_args import args
from pathlib import Path
from tqdm import tqdm
def get_civitai_checkpoint(url: str):
with requests.get(url, allow_redirects=True, stream=True) as response:
response.raise_for_status()
# civitai api returns the filename in the content disposition
base_filename = re.findall(
'"([^"]*)"', response.headers["Content-Disposition"]
)[0]
destination_path = (
Path.cwd() / (args.ckpt_dir or "models") / base_filename
)
# we don't have this model downloaded yet
if not destination_path.is_file():
print(
f"downloading civitai model from {url} to {destination_path}"
)
size = int(response.headers["content-length"], 0)
progress_bar = tqdm(total=size, unit="iB", unit_scale=True)
with open(destination_path, "wb") as f:
for chunk in response.iter_content(chunk_size=65536):
f.write(chunk)
progress_bar.update(len(chunk))
progress_bar.close()
# we already have this model downloaded
else:
print(f"civitai model already downloaded to {destination_path}")
response.close()
return destination_path.as_posix()
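A minimal usage sketch of the new helper, assuming the re-export added to the utils package above (the URL is a placeholder; any civitai.com/api/download/models/... link is handled the same way):

from apps.stable_diffusion.src.utils import get_civitai_checkpoint

# Downloads the checkpoint into args.ckpt_dir (or ./models) on first use,
# then returns the cached local path on subsequent calls.
ckpt_path = get_civitai_checkpoint("https://civitai.com/api/download/models/15236")
print(ckpt_path)  # .../models/<filename taken from the Content-Disposition header>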

View File

@@ -11,12 +11,12 @@
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}
@@ -28,7 +28,7 @@
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
},
@@ -37,7 +37,7 @@
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
}
@@ -45,12 +45,12 @@
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
}

View File

@@ -158,9 +158,9 @@ def load_lower_configs(base_model_id=None):
f"{spec}.json"
)
full_gs_url = config_bucket + config_name
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
print("Loading lowering config file from ", lowering_config_dir)
full_gs_url = config_bucket + config_name
download_public_file(full_gs_url, lowering_config_dir, True)
return lowering_config_dir
@@ -203,8 +203,8 @@ def dump_after_mlir(input_mlir, use_winograd):
if use_winograd:
preprocess_flag = (
"--iree-preprocessing-pass-pipeline=builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"(func.func(iree-global-opt-detach-elementwise-from-named-ops,"
"iree-global-opt-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32},"
"iree-linalg-ext-convert-conv2d-to-winograd))"
@@ -212,8 +212,8 @@ def dump_after_mlir(input_mlir, use_winograd):
else:
preprocess_flag = (
"--iree-preprocessing-pass-pipeline=builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"(func.func(iree-global-opt-detach-elementwise-from-named-ops,"
"iree-global-opt-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32}))"
)
@@ -281,13 +281,9 @@ def sd_model_annotation(mlir_model, model_name, base_model_id=None):
if "rdna2" not in args.iree_vulkan_target_triple.split("-")[0]:
use_winograd = True
winograd_config_dir = load_winograd_configs()
winograd_model = annotate_with_winograd(
tuned_model = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
lowering_config_dir = load_lower_configs(base_model_id)
tuned_model = annotate_with_lower_configs(
winograd_model, lowering_config_dir, model_name, use_winograd
)
else:
tuned_model = mlir_model
else:

View File

@@ -458,6 +458,14 @@ p.add_argument(
help="Specify your own huggingface authentication tokens for models like Llama2.",
)
p.add_argument(
"--device_allocator_heap_key",
type=str,
default="",
help="Specify heap key for device caching allocator."
"Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count"
"Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
)
##############################################################################
# IREE - Vulkan supported flags
##############################################################################
@@ -633,6 +641,13 @@ p.add_argument(
help="Flag for enabling rest API.",
)
p.add_argument(
"--debug",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for enabling debugging log in WebUI.",
)
p.add_argument(
"--output_gallery",
default=True,

View File

@@ -18,7 +18,7 @@ import tempfile
import torch
from safetensors.torch import load_file
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
@@ -154,8 +154,8 @@ def compile_through_fx(
f16_input_mask=f16_input_mask,
debug=debug,
model_name=extended_model_name,
save_dir=save_dir,
)
if use_tuned:
if "vae" in extended_model_name.split("_")[0]:
args.annotation_model = "vae"
@@ -168,6 +168,14 @@ def compile_through_fx(
mlir_module, extended_model_name, base_model_id
)
if not os.path.isdir(save_dir):
save_dir = ""
mlir_module = save_mlir(
mlir_module,
model_name=extended_model_name,
dir=save_dir,
)
shark_module = SharkInference(
mlir_module,
device=args.device if device is None else device,
@@ -179,17 +187,22 @@ def compile_through_fx(
mlir_module,
)
del mlir_module
gc.collect()
def set_iree_runtime_flags():
# TODO: This function should be device-agnostic and piped properly
# to general runtime driver init.
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
if args.enable_rgp:
vulkan_runtime_flags += [
f"--enable_rgp=true",
f"--vulkan_debug_utils=true",
]
if args.device_allocator_heap_key:
vulkan_runtime_flags += [
f"--device_allocator=caching:device_local={args.device_allocator_heap_key}",
]
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
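With the new flag set, for example to '*;1gib' as in its help text, the runtime flag list gains one extra entry before being handed to set_iree_vulkan_runtime_flags; a minimal illustration:

device_allocator_heap_key = "*;1gib"  # illustrative value taken from the flag's help text
extra_flag = f"--device_allocator=caching:device_local={device_allocator_heap_key}"
# extra_flag == "--device_allocator=caching:device_local=*;1gib",
# i.e. cached device allocations are capped at 1 GiB, matching the help text above.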
@@ -470,7 +483,18 @@ def get_available_devices():
set_iree_runtime_flags()
available_devices = []
vulkan_devices = get_devices_by_name("vulkan")
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
)
vulkaninfo_list = get_all_vulkan_devices()
vulkan_devices = []
id = 0
for device in vulkaninfo_list:
vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
id += 1
if id != 0:
print(f"vulkan devices are available.")
available_devices.extend(vulkan_devices)
metal_devices = get_devices_by_name("metal")
available_devices.extend(metal_devices)
@@ -511,10 +535,6 @@ def get_opt_flags(model, precision="fp16"):
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
iree_flags += opt_flags[model][is_tuned][precision][
"default_compilation_flags"
@@ -577,7 +597,7 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
)
num_in_channels = 9 if is_inpaint else 4
pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path=custom_weights,
checkpoint_path_or_dict=custom_weights,
extract_ema=extract_ema,
from_safetensors=from_safetensors,
num_in_channels=num_in_channels,
@@ -784,11 +804,12 @@ def batch_seeds(
seeds = seeds[:batch_count] + [-1] * (batch_count - len(seeds))
if repeatable:
# set seed for the rng based on what we have so far
saved_random_state = random_getstate()
if all(seed < 0 for seed in seeds):
seeds[0] = sanitize_seed(seeds[0])
seed_random(str(seeds))
# set seed for the rng based on what we have so far
saved_random_state = random_getstate()
seed_random(str([n for n in seeds if n > -1]))
# generate any seeds that are unspecified
seeds = [sanitize_seed(seed) for seed in seeds]
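The intent of the reordered seeding above is that the RNG is keyed only on the seeds the user actually specified (non-negative entries), so changing the batch count no longer changes which seeds get generated. A simplified standalone sketch of that behaviour, assuming sanitize_seed draws a fresh random seed for negative inputs and seed_random wraps random.seed (state save/restore omitted):

import random

def sanitize_seed(seed):
    # Simplified stand-in: replace "random" (-1) seeds with a concrete value.
    return seed if seed >= 0 else random.randint(0, 2**32 - 1)

def batch_seeds(seeds, batch_count, repeatable=True):
    seeds = seeds[:batch_count] + [-1] * (batch_count - len(seeds))
    if repeatable:
        if all(seed < 0 for seed in seeds):
            seeds[0] = sanitize_seed(seeds[0])
        # Key the RNG only on the user-specified (non-negative) seeds, so the
        # generated tail is identical regardless of batch_count.
        random.seed(str([n for n in seeds if n > -1]))
    return [sanitize_seed(seed) for seed in seeds]

# The first seeds match whether two or four images are requested.
assert batch_seeds([1234], 4)[:2] == batch_seeds([1234], 2)[:2]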
@@ -827,6 +848,8 @@ def clear_all():
elif os.name == "unix":
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
if args.local_tank_cache != "":
shutil.rmtree(args.local_tank_cache)
def get_generated_imgs_path() -> Path:
@@ -872,6 +895,13 @@ def save_output_img(output_img, img_seed, extra_info=None):
pngInfo = PngImagePlugin.PngInfo()
if args.write_metadata_to_png:
# Using a conditional expression caused problems, so setting a new
# variable for now.
if args.use_hiresfix:
png_size_text = f"{args.hiresfix_width}x{args.hiresfix_height}"
else:
png_size_text = f"{args.width}x{args.height}"
pngInfo.add_text(
"parameters",
f"{args.prompts[0]}"
@@ -880,7 +910,7 @@ def save_output_img(output_img, img_seed, extra_info=None):
f"Sampler: {args.scheduler}, "
f"CFG scale: {args.guidance_scale}, "
f"Seed: {img_seed},"
f"Size: {args.width}x{args.height}, "
f"Size: {png_size_text}, "
f"Model: {img_model}, "
f"VAE: {img_vae}, "
f"LoRA: {img_lora}",
@@ -907,8 +937,10 @@ def save_output_img(output_img, img_seed, extra_info=None):
"CFG_SCALE": args.guidance_scale,
"PRECISION": args.precision,
"STEPS": args.steps,
"HEIGHT": args.height,
"WIDTH": args.width,
"HEIGHT": args.height
if not args.use_hiresfix
else args.hiresfix_height,
"WIDTH": args.width if not args.use_hiresfix else args.hiresfix_width,
"MAX_LENGTH": args.max_length,
"OUTPUT": out_img_path,
"VAE": img_vae,
@@ -946,6 +978,10 @@ def get_generation_text_info(seeds, device):
)
text_output += (
f"\nsize={args.height}x{args.width}, "
if not args.use_hiresfix
else f"\nsize={args.hiresfix_height}x{args.hiresfix_width}, "
)
text_output += (
f"batch_count={args.batch_count}, "
f"batch_size={args.batch_size}, "
f"max_length={args.max_length}"

View File

@@ -1,6 +1,7 @@
from multiprocessing import Process, freeze_support
import os
import sys
import logging
if sys.platform == "darwin":
# import before IREE to avoid torch-MLIR library issues
@@ -41,6 +42,8 @@ def launch_app(address):
if __name__ == "__main__":
if args.debug:
logging.basicConfig(level=logging.DEBUG)
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
if args.api or "api" in args.ui.split(","):
@@ -107,7 +110,6 @@ if __name__ == "__main__":
from apps.stable_diffusion.web.ui import (
txt2img_web,
txt2img_custom_model,
txt2img_hf_model_id,
txt2img_gallery,
txt2img_png_info_img,
txt2img_status,
@@ -119,7 +121,6 @@ if __name__ == "__main__":
# h2ogpt_web,
img2img_web,
img2img_custom_model,
img2img_hf_model_id,
img2img_gallery,
img2img_init_image,
img2img_status,
@@ -128,7 +129,6 @@ if __name__ == "__main__":
img2img_sendto_upscaler,
inpaint_web,
inpaint_custom_model,
inpaint_hf_model_id,
inpaint_gallery,
inpaint_init_image,
inpaint_status,
@@ -137,7 +137,6 @@ if __name__ == "__main__":
inpaint_sendto_upscaler,
outpaint_web,
outpaint_custom_model,
outpaint_hf_model_id,
outpaint_gallery,
outpaint_init_image,
outpaint_status,
@@ -146,16 +145,15 @@ if __name__ == "__main__":
outpaint_sendto_upscaler,
upscaler_web,
upscaler_custom_model,
upscaler_hf_model_id,
upscaler_gallery,
upscaler_init_image,
upscaler_status,
upscaler_sendto_img2img,
upscaler_sendto_inpaint,
upscaler_sendto_outpaint,
lora_train_web,
model_web,
model_config_web,
# lora_train_web,
# model_web,
# model_config_web,
hf_models,
modelmanager_sendto_txt2img,
modelmanager_sendto_img2img,
@@ -247,16 +245,16 @@ if __name__ == "__main__":
upscaler_status,
]
)
with gr.TabItem(label="Model Manager", id=6):
model_web.render()
with gr.TabItem(label="LoRA Training (Experimental)", id=7):
lora_train_web.render()
with gr.TabItem(label="Chat Bot (Experimental)", id=8):
# with gr.TabItem(label="Model Manager", id=6):
# model_web.render()
# with gr.TabItem(label="LoRA Training (Experimental)", id=7):
# lora_train_web.render()
with gr.TabItem(label="Chat Bot", id=8):
stablelm_chat.render()
with gr.TabItem(
label="Generate Sharding Config (Experimental)", id=9
):
model_config_web.render()
# with gr.TabItem(
# label="Generate Sharding Config (Experimental)", id=9
# ):
# model_config_web.render()
with gr.TabItem(label="MultiModal (Experimental)", id=10):
minigpt4_web.render()
# with gr.TabItem(label="DocuChat Upload", id=11):
@@ -396,31 +394,31 @@ if __name__ == "__main__":
modelmanager_sendto_txt2img,
0,
[hf_models],
[txt2img_custom_model, txt2img_hf_model_id, tabs],
[txt2img_custom_model, tabs],
)
register_modelmanager_button(
modelmanager_sendto_img2img,
1,
[hf_models],
[img2img_custom_model, img2img_hf_model_id, tabs],
[img2img_custom_model, tabs],
)
register_modelmanager_button(
modelmanager_sendto_inpaint,
2,
[hf_models],
[inpaint_custom_model, inpaint_hf_model_id, tabs],
[inpaint_custom_model, tabs],
)
register_modelmanager_button(
modelmanager_sendto_outpaint,
3,
[hf_models],
[outpaint_custom_model, outpaint_hf_model_id, tabs],
[outpaint_custom_model, tabs],
)
register_modelmanager_button(
modelmanager_sendto_upscaler,
4,
[hf_models],
[upscaler_custom_model, upscaler_hf_model_id, tabs],
[upscaler_custom_model, tabs],
)
sd_web.queue()

View File

@@ -3,7 +3,6 @@ from apps.stable_diffusion.web.ui.txt2img_ui import (
txt2img_api,
txt2img_web,
txt2img_custom_model,
txt2img_hf_model_id,
txt2img_gallery,
txt2img_png_info_img,
txt2img_status,
@@ -17,7 +16,6 @@ from apps.stable_diffusion.web.ui.img2img_ui import (
img2img_api,
img2img_web,
img2img_custom_model,
img2img_hf_model_id,
img2img_gallery,
img2img_init_image,
img2img_status,
@@ -30,7 +28,6 @@ from apps.stable_diffusion.web.ui.inpaint_ui import (
inpaint_api,
inpaint_web,
inpaint_custom_model,
inpaint_hf_model_id,
inpaint_gallery,
inpaint_init_image,
inpaint_status,
@@ -43,7 +40,6 @@ from apps.stable_diffusion.web.ui.outpaint_ui import (
outpaint_api,
outpaint_web,
outpaint_custom_model,
outpaint_hf_model_id,
outpaint_gallery,
outpaint_init_image,
outpaint_status,
@@ -56,7 +52,6 @@ from apps.stable_diffusion.web.ui.upscaler_ui import (
upscaler_api,
upscaler_web,
upscaler_custom_model,
upscaler_hf_model_id,
upscaler_gallery,
upscaler_init_image,
upscaler_status,

View File

@@ -212,6 +212,7 @@ with gr.Blocks(title="DocuChat") as h2ogpt_web:
else "Only CUDA Supported for now",
choices=supported_devices,
interactive=enabled,
allow_custom_value=True,
)
precision = gr.Radio(
label="Precision",

View File

@@ -55,8 +55,7 @@ def img2img_inf(
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
model_id: str,
custom_vae: str,
precision: str,
device: str,
@@ -103,21 +102,17 @@ def img2img_inf(
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
# .safetensor or .chkpt on the custom model path
if model_id in get_custom_model_files():
args.ckpt_loc = get_custom_model_pathfile(model_id)
# civitai download
elif "civitai" in model_id:
args.ckpt_loc = model_id
# either predefined or huggingface
else:
args.hf_model_id = custom_model
args.hf_model_id = model_id
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
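The same three-way resolution (local checkpoint file, civitai download URL, otherwise a HuggingFace model id or predefined model) now applies across the SD tabs; restated as a small illustrative helper, not part of the diff:

def resolve_model_selection(model_id, custom_files, get_pathfile):
    # Illustrative restatement of the branching used by the *_inf functions above.
    if model_id in custom_files:      # a .ckpt/.safetensors in the custom model dir
        return {"ckpt_loc": get_pathfile(model_id), "hf_model_id": ""}
    if "civitai" in model_id:         # civitai download URL, fetched later by the pipeline
        return {"ckpt_loc": model_id, "hf_model_id": ""}
    return {"ckpt_loc": "", "hf_model_id": model_id}  # predefined or HF model id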
@@ -334,8 +329,7 @@ def img2img_api(
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
@@ -382,31 +376,19 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
with gr.Column(scale=1, min_width=600):
with gr.Row():
# janky fix for overflowing text
i2i_model_info = (str(get_custom_model_path())).replace(
"\\", "\n\\"
i2i_model_info = (
f"Custom Model Path: {str(get_custom_model_path())}"
)
i2i_model_info = f"Custom Model Path: {i2i_model_info}"
img2img_custom_model = gr.Dropdown(
label=f"Models",
info=i2i_model_info,
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-2-1-base",
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
)
img2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: SG161222/Realistic_Vision_V1.3, "
"https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
choices=get_custom_model_files() + predefined_models,
allow_custom_value=True,
scale=2,
)
# janky fix for overflowing text
i2i_vae_info = (str(get_custom_model_path("vae"))).replace(
@@ -421,6 +403,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
allow_custom_value=True,
scale=1,
)
with gr.Group(elem_id="prompt_box_outer"):
@@ -452,6 +436,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="Stencil model",
value="None",
choices=["None", "canny", "openpose", "scribble"],
allow_custom_value=True,
)
def show_canvas(choice):
@@ -512,6 +497,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
).replace("\\", "\n\\")
i2i_lora_info = f"LoRA Path: {i2i_lora_info}"
lora_weights = gr.Dropdown(
allow_custom_value=True,
label=f"Standalone LoRA Weights",
info=i2i_lora_info,
elem_id="lora_weights",
@@ -535,6 +521,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="Scheduler",
value="EulerDiscrete",
choices=scheduler_list_cpu_only,
allow_custom_value=True,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -590,6 +577,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
"Cubic",
],
label="Resample Type",
allow_custom_value=True,
)
ondemand = gr.Checkbox(
value=args.ondemand,
@@ -648,17 +636,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -670,13 +649,26 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at "
value=f"{i2i_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
lines=2,
elem_id="std_output",
show_label=False,
)
img2img_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
blank_thing_for_row = None
with gr.Row():
img2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
img2img_sendto_outpaint = gr.Button(
@@ -702,7 +694,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
batch_size,
scheduler,
img2img_custom_model,
img2img_hf_model_id,
custom_vae,
precision,
device,

View File

@@ -53,8 +53,7 @@ def inpaint_inf(
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
model_id: str,
custom_vae: str,
precision: str,
device: str,
@@ -89,21 +88,17 @@ def inpaint_inf(
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
# .safetensor or .chkpt on the custom model path
if model_id in get_custom_model_files(custom_checkpoint_type="inpainting"):
args.ckpt_loc = get_custom_model_pathfile(model_id)
# civitai download
elif "civitai" in model_id:
args.ckpt_loc = model_id
# either predefined or huggingface
else:
args.hf_model_id = custom_model
args.hf_model_id = model_id
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
@@ -282,8 +277,7 @@ def inpaint_api(
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-inpainting",
custom_vae="None",
@@ -327,34 +321,21 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
with gr.Row():
# janky fix for overflowing text
inpaint_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
inpaint_model_info = (
f"Custom Model Path: {inpaint_model_info}"
f"Custom Model Path: {str(get_custom_model_path())}"
)
inpaint_custom_model = gr.Dropdown(
label=f"Models",
info=inpaint_model_info,
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-2-inpainting",
choices=["None"]
+ get_custom_model_files(
choices=get_custom_model_files(
custom_checkpoint_type="inpainting"
)
+ predefined_paint_models,
)
inpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
"https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
allow_custom_value=True,
scale=2,
)
# janky fix for overflowing text
inpaint_vae_info = (
@@ -369,6 +350,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
allow_custom_value=True,
scale=1,
)
with gr.Group(elem_id="prompt_box_outer"):
@@ -406,6 +389,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -424,6 +408,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
label="Scheduler",
value="EulerDiscrete",
choices=scheduler_list_cpu_only,
allow_custom_value=True,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -527,17 +512,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -549,14 +525,26 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at "
value=f"{inpaint_model_info}\n"
"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
lines=2,
elem_id="std_output",
show_label=False,
)
inpaint_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
blank_thing_for_row = None
with gr.Row():
inpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
inpaint_sendto_outpaint = gr.Button(
@@ -583,7 +571,6 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
batch_size,
scheduler,
inpaint_custom_model,
inpaint_hf_model_id,
custom_vae,
precision,
device,

View File

@@ -50,6 +50,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
allow_custom_value=True,
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
@@ -73,6 +74,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -105,6 +107,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
label="Scheduler",
value=args.scheduler,
choices=scheduler_list,
allow_custom_value=True,
)
with gr.Row():
height = gr.Slider(
@@ -177,6 +180,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
with gr.Column(scale=2):

View File

@@ -143,6 +143,7 @@ with gr.Blocks() as minigpt4_web:
# else "Only CUDA Supported for now",
choices=["cuda"],
interactive=False,
allow_custom_value=True,
)
with gr.Column():

View File

@@ -98,6 +98,7 @@ with gr.Blocks() as model_web:
choices=None,
value=None,
visible=False,
allow_custom_value=True,
)
# TODO: select and SendTo
civit_models = gr.Gallery(

View File

@@ -53,8 +53,7 @@ def outpaint_inf(
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
model_id: str,
custom_vae: str,
precision: str,
device: str,
@@ -88,21 +87,17 @@ def outpaint_inf(
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
# .safetensors or .ckpt on the custom model path
if model_id in get_custom_model_files(custom_checkpoint_type="inpainting"):
args.ckpt_loc = get_custom_model_pathfile(model_id)
# civitai download
elif "civitai" in model_id:
args.ckpt_loc = model_id
# either predefined or huggingface
else:
args.hf_model_id = custom_model
args.hf_model_id = model_id
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
@@ -289,8 +284,7 @@ def outpaint_api(
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-inpainting",
custom_vae="None",
@@ -332,36 +326,22 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
# janky fix for overflowing text
outpaint_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
outpaint_model_info = (
f"Custom Model Path: {outpaint_model_info}"
f"Custom Model Path: {str(get_custom_model_path())}"
)
outpaint_custom_model = gr.Dropdown(
label=f"Models",
info=outpaint_model_info,
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-2-inpainting",
choices=["None"]
+ get_custom_model_files(
choices=get_custom_model_files(
custom_checkpoint_type="inpainting"
)
+ predefined_paint_models,
)
outpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
"https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
allow_custom_value=True,
scale=2,
)
# janky fix for overflowing text
outpaint_vae_info = (
@@ -376,8 +356,9 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
allow_custom_value=True,
scale=1,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
@@ -411,6 +392,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -429,6 +411,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
label="Scheduler",
value="EulerDiscrete",
choices=scheduler_list_cpu_only,
allow_custom_value=True,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -555,17 +538,8 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -577,13 +551,26 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at "
value=f"{outpaint_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
lines=2,
elem_id="std_output",
show_label=False,
)
outpaint_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
blank_thing_for_row = None
with gr.Row():
outpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
outpaint_sendto_inpaint = gr.Button(value="SendTo Inpaint")
@@ -611,7 +598,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
batch_size,
scheduler,
outpaint_custom_model,
outpaint_hf_model_id,
custom_vae,
precision,
device,

View File

@@ -109,6 +109,7 @@ with gr.Blocks() as outputgallery_web:
value="",
interactive=True,
elem_classes="dropdown_no_container",
allow_custom_value=True,
)
with gr.Column(
scale=1,

View File

@@ -8,7 +8,7 @@ from transformers import (
from apps.stable_diffusion.web.ui.utils import available_devices
from datetime import datetime as dt
import json
import time
import sys
def user(message, history):
@@ -32,62 +32,73 @@ model_map = {
# NOTE: Each `model_name` should have its own start message
start_message = {
"llama2_7b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
"content. Please ensure that your responses are socially unbiased and positive "
"in nature. If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"llama2_13b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
"content. Please ensure that your responses are socially unbiased and positive "
"in nature. If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"llama2_70b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
"content. Please ensure that your responses are socially unbiased and positive "
"in nature. If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"vicuna": (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's "
"questions.\n"
"A chat between a curious user and an artificial intelligence "
"assistant. The assistant gives helpful, detailed, and "
"polite answers to the user's questions.\n"
),
}
def create_prompt(model_name, history):
system_message = start_message[model_name]
def create_prompt(model_name, history, prompt_prefix):
system_message = ""
if prompt_prefix:
system_message = start_message[model_name]
if model_name in [
"vicuna",
"llama2_7b",
"llama2_13b",
"llama2_70b",
]:
if "llama2" in model_name:
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
conversation = "".join(
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
)
if prompt_prefix:
msg = f"{B_INST} {B_SYS}{system_message}{E_SYS}{history[0][0]} {E_INST} {history[0][1]} {conversation}"
else:
msg = f"{B_INST} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
elif model_name in ["vicuna"]:
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
msg = system_message + conversation
msg = msg.strip()
else:
conversation = "".join(
["".join([item[0], item[1]]) for item in history]
)
msg = system_message + conversation
msg = msg.strip()
msg = system_message + conversation
msg = msg.strip()
return msg
@@ -126,7 +137,7 @@ model_vmfb_key = ""
# TODO: Make chat reusable for UI and API
def chat(
curr_system_message,
prompt_prefix,
history,
model,
device,
@@ -140,6 +151,7 @@ def chat(
global model_vmfb_key
global vicuna_model
device_id = None
model_name, model_path = list(map(str.strip, model.split("=>")))
if "cuda" in device:
device = "cuda"
@@ -148,6 +160,7 @@ def chat(
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
elif "rocm" in device:
device = "rocm"
@@ -158,18 +171,53 @@ def chat(
from apps.language_models.scripts.vicuna import UnshardedVicuna
from apps.stable_diffusion.src import args
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{precision}"
if new_model_vmfb_key != model_vmfb_key:
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}"
if vicuna_model is None or new_model_vmfb_key != model_vmfb_key:
model_vmfb_key = new_model_vmfb_key
max_toks = 128 if model_name == "codegen" else 512
# get iree flags that need to be overridden, from commandline args
_extra_args = []
# vulkan target triple
if args.iree_vulkan_target_triple != "":
vulkan_target_triple = args.iree_vulkan_target_triple
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
get_vulkan_target_triple,
)
if device == "vulkan":
vulkaninfo_list = get_all_vulkan_devices()
if vulkan_target_triple == "":
# We already have the device_id extracted via WebUI, so we directly use
# that to find the target triple.
vulkan_target_triple = get_vulkan_target_triple(
vulkaninfo_list[device_id]
)
_extra_args.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
f"-iree-vulkan-target-triple={vulkan_target_triple}"
)
if "rdna" in vulkan_target_triple:
flags_to_add = [
"--iree-spirv-index-bits=64",
]
_extra_args = _extra_args + flags_to_add
if device_id is None:
id = 0
for device in vulkaninfo_list:
target_triple = get_vulkan_target_triple(
vulkaninfo_list[id]
)
if target_triple == vulkan_target_triple:
device_id = id
break
id += 1
assert (
device_id
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
print(f"Will use target triple : {vulkan_target_triple}")
if model_name == "vicuna4":
vicuna_model = ShardedVicuna(
@@ -188,32 +236,47 @@ def chat(
hf_model_path=model_path,
hf_auth_token=args.hf_auth_token,
device=device,
vulkan_target_triple=vulkan_target_triple,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=download_vmfb,
load_mlir_from_shark_tank=True,
extra_args_cmd=_extra_args,
device_id=device_id,
)
prompt = create_prompt(model_name, history)
if vicuna_model is None:
sys.exit("Unable to instantiate the model object, exiting.")
prompt = create_prompt(model_name, history, prompt_prefix)
partial_text = ""
count = 0
start_time = time.time()
for text, msg in progress.tqdm(
token_count = 0
total_time_ms = 0.001 # In order to avoid divide by zero error
prefill_time = 0
is_first = True
for text, msg, exec_time in progress.tqdm(
vicuna_model.generate(prompt, cli=cli),
desc="generating response",
):
count += 1
if "formatted" in msg:
history[-1][1] = text
end_time = time.time()
tokens_per_sec = count / (end_time - start_time)
yield history, str(format(tokens_per_sec, ".2f")) + " tokens/sec"
else:
if msg is None:
if is_first:
prefill_time = exec_time
is_first = False
else:
total_time_ms += exec_time
token_count += 1
partial_text += text + " "
history[-1][1] = partial_text
yield history, ""
yield history, f"Prefill: {prefill_time:.2f}"
elif "formatted" in msg:
history[-1][1] = text
tokens_per_sec = (token_count / total_time_ms) * 1000
yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
else:
sys.exit(
"unexpected message from the vicuna generate call, exiting."
)
return history, ""
@@ -251,6 +314,7 @@ def llm_chat_api(InputData: dict):
UnshardedVicuna,
)
device_id = None
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
@@ -259,6 +323,7 @@ def llm_chat_api(InputData: dict):
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
else:
print("unrecognized device")
@@ -271,6 +336,7 @@ def llm_chat_api(InputData: dict):
max_num_tokens=max_toks,
download_vmfb=True,
load_mlir_from_shark_tank=True,
device_id=device_id,
)
# TODO: add role dict for different models
@@ -335,6 +401,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
label="Select Model",
value=model_choices[0],
choices=model_choices,
allow_custom_value=True,
)
supported_devices = available_devices
enabled = len(supported_devices) > 0
@@ -348,25 +415,31 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
else "Only CUDA Supported for now",
choices=supported_devices,
interactive=enabled,
allow_custom_value=True,
# multiselect=True,
)
precision = gr.Radio(
label="Precision",
value="int8",
value="int4",
choices=[
"int4",
"int8",
"fp16",
],
visible=True,
visible=False,
)
tokens_time = gr.Textbox(label="Tokens generated per second")
with gr.Column():
download_vmfb = gr.Checkbox(
label="Download vmfb from Shark tank if available",
value=True,
interactive=True,
)
tokens_time = gr.Textbox(label="Tokens generated per second")
prompt_prefix = gr.Checkbox(
label="Add System Prompt",
value=False,
interactive=True,
)
with gr.Row(visible=False):
with gr.Group():
@@ -393,16 +466,17 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
submit = gr.Button("Submit", interactive=enabled)
stop = gr.Button("Stop", interactive=enabled)
clear = gr.Button("Clear", interactive=enabled)
system_msg = gr.Textbox(
start_message, label="System Message", interactive=False, visible=False
)
submit_event = msg.submit(
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
).then(
fn=chat,
inputs=[
system_msg,
prompt_prefix,
chatbot,
model,
device,
@@ -411,14 +485,19 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
config_file,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
)
submit_click_event = submit.click(
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
).then(
fn=chat,
inputs=[
system_msg,
prompt_prefix,
chatbot,
model,
device,
@@ -427,6 +506,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
config_file,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
)
stop.click(

View File

@@ -52,8 +52,7 @@ def txt2img_inf(
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
model_id: str,
custom_vae: str,
precision: str,
device: str,
@@ -91,21 +90,17 @@ def txt2img_inf(
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, "
"both must not be empty",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
# .safetensors or .ckpt on the custom model path
if model_id in get_custom_model_files():
args.ckpt_loc = get_custom_model_pathfile(model_id)
# civitai download
elif "civitai" in model_id:
args.ckpt_loc = model_id
# either predefined or huggingface
else:
args.hf_model_id = custom_model
args.hf_model_id = model_id
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
@@ -145,6 +140,11 @@ def txt2img_inf(
args.max_length = max_length
args.height = height
args.width = width
args.use_hiresfix = use_hiresfix
args.hiresfix_height = hiresfix_height
args.hiresfix_width = hiresfix_width
args.hiresfix_strength = hiresfix_strength
args.resample_type = resample_type
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.iree_metal_target_platform = init_iree_metal_target_platform
@@ -339,8 +339,7 @@ def txt2img_api(
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
@@ -389,32 +388,18 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
with gr.Row():
with gr.Column(scale=10):
with gr.Row():
# janky fix for overflowing text
t2i_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
t2i_model_info = (
f"Custom Model Path: {t2i_model_info}"
)
t2i_model_info = f"Custom Model Path: {str(get_custom_model_path())}"
txt2img_custom_model = gr.Dropdown(
label=f"Models",
info=t2i_model_info,
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-2-1-base",
choices=["None"]
+ get_custom_model_files()
choices=get_custom_model_files()
+ predefined_models,
)
txt2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the dropdown "
"on the left and enter model ID here.",
value="",
label="HuggingFace Model ID or Civitai model "
"download URL.",
lines=3,
allow_custom_value=True,
scale=2,
)
# janky fix for overflowing text
t2i_vae_info = (
@@ -430,6 +415,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
else "None",
choices=["None"]
+ get_custom_model_files("vae"),
allow_custom_value=True,
scale=1,
)
with gr.Column(scale=1, min_width=170):
txt2img_png_info_img = gr.Image(
@@ -466,6 +453,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -484,6 +472,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="Scheduler",
value=args.scheduler,
choices=scheduler_list,
allow_custom_value=True,
)
with gr.Column():
save_metadata_to_png = gr.Checkbox(
@@ -568,6 +557,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
"Cubic",
],
label="Resample Type",
allow_custom_value=True,
)
hiresfix_height = gr.Slider(
384,
@@ -624,6 +614,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
@@ -643,7 +634,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at "
value=f"{t2i_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
@@ -686,7 +678,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
batch_size,
scheduler,
txt2img_custom_model,
txt2img_hf_model_id,
custom_vae,
precision,
device,
@@ -736,7 +727,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
width,
height,
txt2img_custom_model,
txt2img_hf_model_id,
lora_weights,
lora_hf_id,
custom_vae,
@@ -752,7 +742,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
width,
height,
txt2img_custom_model,
txt2img_hf_model_id,
lora_weights,
lora_hf_id,
custom_vae,

View File

@@ -46,8 +46,7 @@ def upscaler_inf(
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
model_id: str,
custom_vae: str,
precision: str,
device: str,
@@ -85,21 +84,17 @@ def upscaler_inf(
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
# .safetensors or .ckpt on the custom model path
if model_id in get_custom_model_files(custom_checkpoint_type="upscaler"):
args.ckpt_loc = get_custom_model_pathfile(model_id)
# civitai download
elif "civitai" in model_id:
args.ckpt_loc = model_id
# either predefined or huggingface
else:
args.hf_model_id = custom_model
args.hf_model_id = model_id
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
@@ -304,8 +299,7 @@ def upscaler_api(
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
@@ -346,36 +340,22 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
# janky fix for overflowing text
upscaler_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
upscaler_model_info = (
f"Custom Model Path: {upscaler_model_info}"
f"Custom Model Path: {str(get_custom_model_path())}"
)
upscaler_custom_model = gr.Dropdown(
label=f"Models",
info=upscaler_model_info,
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-x4-upscaler",
choices=["None"]
+ get_custom_model_files(
choices=get_custom_model_files(
custom_checkpoint_type="upscaler"
)
+ predefined_upscaler_models,
)
upscaler_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: SG161222/Realistic_Vision_V1.3, "
"https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
allow_custom_value=True,
scale=2,
)
# janky fix for overflowing text
upscaler_vae_info = (
@@ -390,6 +370,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
allow_custom_value=True,
scale=1,
)
with gr.Group(elem_id="prompt_box_outer"):
@@ -425,6 +407,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -443,6 +426,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
label="Scheduler",
value="DDIM",
choices=scheduler_list_cpu_only,
allow_custom_value=True,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -547,17 +531,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -569,14 +544,26 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at "
value=f"{upscaler_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
lines=2,
elem_id="std_output",
show_label=False,
)
upscaler_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
blank_thing_for_row = None
with gr.Row():
upscaler_sendto_img2img = gr.Button(value="SendTo Img2Img")
upscaler_sendto_inpaint = gr.Button(value="SendTo Inpaint")
@@ -600,7 +587,6 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
batch_size,
scheduler,
upscaler_custom_model,
upscaler_hf_model_id,
custom_vae,
precision,
device,

View File

@@ -149,7 +149,6 @@ def import_png_metadata(
width,
height,
custom_model,
hf_model_id,
custom_lora,
hf_lora_id,
custom_vae,
@@ -175,10 +174,8 @@ def import_png_metadata(
if "Model" in metadata and png_custom_model:
custom_model = png_custom_model
hf_model_id = ""
if "Model" in metadata and png_hf_model_id:
custom_model = "None"
hf_model_id = png_hf_model_id
elif "Model" in metadata and png_hf_model_id:
custom_model = png_hf_model_id
if "LoRA" in metadata and lora_custom_model:
custom_lora = lora_custom_model
@@ -217,7 +214,6 @@ def import_png_metadata(
width,
height,
custom_model,
hf_model_id,
custom_lora,
hf_lora_id,
custom_vae,

View File

@@ -1,192 +0,0 @@
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required(VERSION 3.17)
project(sharkbackend LANGUAGES C CXX)
#
# Options
#
option(TRITON_ENABLE_GPU "Enable GPU support in backend" ON)
option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON)
set(TRITON_COMMON_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/common repo")
set(TRITON_CORE_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/core repo")
set(TRITON_BACKEND_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/backend repo")
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
endif()
#
# Dependencies
#
# FetchContent requires us to include the transitive closure of all
# repos that we depend on so that we can override the tags.
#
include(FetchContent)
FetchContent_Declare(
repo-common
GIT_REPOSITORY https://github.com/triton-inference-server/common.git
GIT_TAG ${TRITON_COMMON_REPO_TAG}
GIT_SHALLOW ON
)
FetchContent_Declare(
repo-core
GIT_REPOSITORY https://github.com/triton-inference-server/core.git
GIT_TAG ${TRITON_CORE_REPO_TAG}
GIT_SHALLOW ON
)
FetchContent_Declare(
repo-backend
GIT_REPOSITORY https://github.com/triton-inference-server/backend.git
GIT_TAG ${TRITON_BACKEND_REPO_TAG}
GIT_SHALLOW ON
)
FetchContent_MakeAvailable(repo-common repo-core repo-backend)
#
# The backend must be built into a shared library. Use an ldscript to
# hide all symbols except for the TRITONBACKEND API.
#
configure_file(src/libtriton_dshark.ldscript libtriton_dshark.ldscript COPYONLY)
add_library(
triton-dshark-backend SHARED
src/dshark.cc
#src/dshark_driver_module.c
)
add_library(
SharkBackend::triton-dshark-backend ALIAS triton-dshark-backend
)
target_include_directories(
triton-dshark-backend
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/src
)
list(APPEND CMAKE_MODULE_PATH "${PROJECT_BINARY_DIR}/lib/cmake/mlir")
add_subdirectory(thirdparty/srt EXCLUDE_FROM_ALL)
target_link_libraries(triton-dshark-backend PRIVATE iree_base_base
iree_hal_hal
iree_hal_cuda_cuda
iree_hal_cuda_registration_registration
iree_hal_vmvx_registration_registration
iree_hal_dylib_registration_registration
iree_modules_hal_hal
iree_vm_vm
iree_vm_bytecode_module
iree_hal_local_loaders_system_library_loader
iree_hal_local_loaders_vmvx_module_loader
)
target_compile_features(triton-dshark-backend PRIVATE cxx_std_11)
target_link_libraries(
triton-dshark-backend
PRIVATE
triton-core-serverapi # from repo-core
triton-core-backendapi # from repo-core
triton-core-serverstub # from repo-core
triton-backend-utils # from repo-backend
)
if(WIN32)
set_target_properties(
triton-dshark-backend PROPERTIES
POSITION_INDEPENDENT_CODE ON
OUTPUT_NAME triton_dshark
)
else()
set_target_properties(
triton-dshark-backend PROPERTIES
POSITION_INDEPENDENT_CODE ON
OUTPUT_NAME triton_dshark
LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libtriton_dshark.ldscript
LINK_FLAGS "-Wl,--version-script libtriton_dshark.ldscript"
)
endif()
#
# Install
#
include(GNUInstallDirs)
set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/SharkBackend)
install(
TARGETS
triton-dshark-backend
EXPORT
triton-dshark-backend-targets
LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
)
install(
EXPORT
triton-dshark-backend-targets
FILE
SharkBackendTargets.cmake
NAMESPACE
SharkBackend::
DESTINATION
${INSTALL_CONFIGDIR}
)
include(CMakePackageConfigHelpers)
configure_package_config_file(
${CMAKE_CURRENT_LIST_DIR}/cmake/SharkBackendConfig.cmake.in
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
INSTALL_DESTINATION ${INSTALL_CONFIGDIR}
)
install(
FILES
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
DESTINATION ${INSTALL_CONFIGDIR}
)
#
# Export from build tree
#
export(
EXPORT triton-dshark-backend-targets
FILE ${CMAKE_CURRENT_BINARY_DIR}/SharkBackendTargets.cmake
NAMESPACE SharkBackend::
)
export(PACKAGE SharkBackend)

View File

@@ -1,100 +0,0 @@
# SHARK Triton Backend
The Triton backend for SHARK.
# Build
Install SHARK
```
git clone https://github.com/nod-ai/SHARK.git
# skip above step if dshark is already installed
cd SHARK/inference
```
Install dependencies:
```
apt-get install patchelf rapidjson-dev python3-dev
git submodule update --init
```
Update the submodules of IREE:
```
cd thirdparty/srt
git submodule update --init
```
Next, make the backend and install it
```
cd ../..
mkdir build && cd build
cmake -DTRITON_ENABLE_GPU=ON \
-DIREE_HAL_DRIVER_CUDA=ON \
-DIREE_TARGET_BACKEND_CUDA=ON \
-DMLIR_ENABLE_CUDA_RUNNER=ON \
-DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install \
-DTRITON_BACKEND_REPO_TAG=r22.02 \
-DTRITON_CORE_REPO_TAG=r22.02 \
-DTRITON_COMMON_REPO_TAG=r22.02 ..
make install
```
# Incorporating into Triton
There are much more in-depth explanations of the following steps in Triton's documentation:
https://github.com/triton-inference-server/server/blob/main/docs/compose.md#triton-with-unsupported-and-custom-backends
There should be a file at `/build/install/backends/dshark/libtriton_dshark.so`. You will need to copy it into your Triton server image.
More documentation is in the link above, but to create the Docker image you need to run the `compose.py` command from the Triton server repo.
To build your image, first clone the Triton server repo.
```
git clone https://github.com/triton-inference-server/server.git
```
Then run `compose.py` to generate a composed Dockerfile (`Dockerfile.compose`):
```
cd server
python3 compose.py --repoagent checksum --dry-run
```
Because dshark is a third-party backend, you will need to manually modify the generated `Dockerfile.compose` to include it. To do this, add the line below to the `Dockerfile.compose` file that was produced.
The dshark backend will be located in the build folder from earlier, under `/build/install/backends`:
```
COPY /path/to/build/install/backends/dshark /opt/tritonserver/backends/dshark
```
Next, run:
```
docker build -t tritonserver_custom -f Dockerfile.compose .
docker run -it --gpus=1 --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
```
where `/path/to/model_repos` is the directory holding the models you want to serve.
If you're not using GPUs, omit `--gpus=1`:
```
docker run -it --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
```
# Setting up a model
To include a model in your backend, add a directory with your model's name to your model repository directory. Examples of minimal model repositories can be seen here: https://github.com/triton-inference-server/backend/tree/main/examples/model_repos/minimal_models
Make sure to adjust the inputs correctly in the `config.pbtxt` file, and save a vmfb file under `1/model.vmfb`, as sketched below.
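As a rough sketch (the model name `my_model`, the tensor names `INPUT0`/`OUTPUT0`, and the shapes and dtypes below are hypothetical placeholders, and must be adjusted to match whatever module you compiled into `model.vmfb`), a minimal repository entry might be laid out like this:
```
# Hypothetical layout for a single model named "my_model"
mkdir -p /path/to/model_repos/my_model/1
cp /path/to/your/compiled/model.vmfb /path/to/model_repos/my_model/1/model.vmfb
cat > /path/to/model_repos/my_model/config.pbtxt <<'EOF'
name: "my_model"
backend: "dshark"
max_batch_size: 0
input [
  {
    name: "INPUT0"
    data_type: TYPE_FP32
    dims: [ 1, 4 ]
  }
]
output [
  {
    name: "OUTPUT0"
    data_type: TYPE_FP32
    dims: [ 1, 4 ]
  }
]
EOF
```
Triton treats each numbered subdirectory (`1/` here) as a model version, so the vmfb goes inside it rather than next to `config.pbtxt`.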
# CUDA
If you're having issues with CUDA, make sure the correct drivers are installed, check that `nvidia-smi` works, and make sure that the `nvcc` compiler is on your PATH.
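For a quick sanity check with the standard CUDA tools (none of this is specific to the dshark backend), something like the following can be run from a shell:
```
nvidia-smi                    # confirms the driver is installed and can see the GPU(s)
which nvcc && nvcc --version  # confirms the CUDA compiler is on the PATH
```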

View File

@@ -1,39 +0,0 @@
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
include(CMakeFindDependencyMacro)
get_filename_component(
SHARKBACKEND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH
)
list(APPEND CMAKE_MODULE_PATH ${SHARKBACKEND_CMAKE_DIR})
if(NOT TARGET SharkBackend::triton-dshark-backend)
include("${SHARKBACKEND_CMAKE_DIR}/SharkBackendTargets.cmake")
endif()
set(SHARKBACKEND_LIBRARIES SharkBackend::triton-dshark-backend)

File diff suppressed because it is too large

View File

@@ -1,30 +0,0 @@
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
{
global:
TRITONBACKEND_*;
local: *;
};

View File

@@ -8,19 +8,8 @@ torchvision
tqdm
#iree-compiler | iree-runtime should already be installed
#these don't work on osx
#iree-tools-tflite
#iree-tools-xla
#iree-tools-tf
# TensorFlow and JAX.
gin-config
tensorflow-macos
tensorflow-metal
#tf-models-nightly
#tensorflow-text-nightly
transformers
tensorflow-probability
#jax[cpu]
# tflitehub dependencies.

View File

@@ -9,23 +9,13 @@ tabulate
tqdm
#iree-compiler | iree-runtime should already be installed
iree-tools-tflite
iree-tools-xla
iree-tools-tf
# TensorFlow and JAX.
# Modelling and JAX.
gin-config
tf-nightly
keras
#tf-models-nightly
#tensorflow-text-nightly
transformers
diffusers
#tensorflow-probability
#jax[cpu]
# tflitehub dependencies.
Pillow
# Testing and support.

View File

@@ -25,7 +25,7 @@ diffusers
accelerate
scipy
ftfy
gradio
gradio==3.44.3
altair
omegaconf
# 0.3.2 doesn't have binaries for arm64
@@ -47,4 +47,4 @@ pefile
pyinstaller
# vicuna quantization
brevitas @ git+https://github.com/Xilinx/brevitas.git@dev
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea

View File

@@ -86,6 +86,7 @@ $PYTHON -m pip install --upgrade -r "$TD/requirements.txt"
if [ "$torch_mlir_bin" = true ]; then
if [[ $(uname -s) = 'Darwin' ]]; then
echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
$PYTHON -m pip uninstall -y timm #TEMP FIX FOR MAC
$PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
else
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
@@ -128,7 +129,13 @@ if [[ ! -z "${IMPORTER}" ]]; then
fi
fi
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/cpu/
if [[ $(uname -s) = 'Darwin' ]]; then
PYTORCH_URL=https://download.pytorch.org/whl/nightly/torch/
else
PYTORCH_URL=https://download.pytorch.org/whl/nightly/cpu/
fi
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f ${PYTORCH_URL}
if [[ $(uname -s) = 'Linux' && ! -z "${IMPORTER}" ]]; then
T_VER=$($PYTHON -m pip show torch | grep Version)

View File

@@ -177,7 +177,7 @@ def compile_through_fx(model, inputs, mlir_loc=None):
mlir_model = str(module)
func_name = "forward"
shark_module = SharkInference(
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
mlir_model, device=args.device, mlir_dialect="linalg"
)
shark_module.compile()

View File

@@ -54,7 +54,7 @@ if __name__ == "__main__":
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=False, tracing_required=False
)
shark_module = SharkInference(minilm_mlir, func_name, mlir_dialect="mhlo")
shark_module = SharkInference(minilm_mlir, mlir_dialect="mhlo")
shark_module.compile()
output_idx = 0
data_idx = 1

View File

@@ -6,7 +6,7 @@ mlir_model, func_name, inputs, golden_out = download_model(
)
shark_module = SharkInference(
mlir_model, func_name, device="cpu", mlir_dialect="tm_tensor"
mlir_model, device="cpu", mlir_dialect="tm_tensor"
)
shark_module.compile()
result = shark_module.forward(inputs)

View File

@@ -13,9 +13,7 @@ arg0 = np.ones((1, 4)).astype(np.float32)
arg1 = np.ones((4, 1)).astype(np.float32)
print("Running shark on cpu backend")
shark_module = SharkInference(
mhlo_ir, function_name="forward", device="cpu", mlir_dialect="mhlo"
)
shark_module = SharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo")
# Generate the random inputs and feed into the graph.
x = shark_module.generate_random_inputs()
@@ -23,15 +21,11 @@ shark_module.compile()
print(shark_module.forward(x))
print("Running shark on cuda backend")
shark_module = SharkInference(
mhlo_ir, function_name="forward", device="cuda", mlir_dialect="mhlo"
)
shark_module = SharkInference(mhlo_ir, device="cuda", mlir_dialect="mhlo")
shark_module.compile()
print(shark_module.forward(x))
print("Running shark on vulkan backend")
shark_module = SharkInference(
mhlo_ir, function_name="forward", device="vulkan", mlir_dialect="mhlo"
)
shark_module = SharkInference(mhlo_ir, device="vulkan", mlir_dialect="mhlo")
shark_module.compile()
print(shark_module.forward(x))

View File

@@ -8,9 +8,7 @@ mlir_model, func_name, inputs, golden_out = download_model(
)
shark_module = SharkInference(
mlir_model, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module = SharkInference(mlir_model, device="cpu", mlir_dialect="linalg")
shark_module.compile()
result = shark_module.forward(inputs)
print("The obtained result via shark is: ", result)

View File

@@ -33,7 +33,7 @@ mlir_importer = SharkImporter(
print(golden_out)
shark_module = SharkInference(vision_mlir, func_name, mlir_dialect="linalg")
shark_module = SharkInference(vision_mlir, mlir_dialect="linalg")
shark_module.compile()
result = shark_module.forward((input,))
print("Obtained result", result)

View File

@@ -49,9 +49,7 @@ module = torch_mlir.compile(
mlir_model = module
func_name = "forward"
shark_module = SharkInference(
mlir_model, func_name, device="cuda", mlir_dialect="linalg"
)
shark_module = SharkInference(mlir_model, device="cuda", mlir_dialect="linalg")
shark_module.compile()

View File

@@ -360,7 +360,7 @@ mlir_importer = SharkImporter(
)
shark_module = SharkInference(
dlrm_mlir, func_name, device="vulkan", mlir_dialect="linalg"
dlrm_mlir, device="vulkan", mlir_dialect="linalg"
)
shark_module.compile()
result = shark_module.forward(input_dlrm)

View File

@@ -294,7 +294,7 @@ def test_dlrm() -> None:
)
shark_module = SharkInference(
dlrm_mlir, func_name, device="cpu", mlir_dialect="linalg"
dlrm_mlir, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
result = shark_module.forward(inputs)

View File

@@ -33,7 +33,7 @@ mlir_importer = SharkImporter(
tracing_required=False
)
shark_module = SharkInference(vision_mlir, func_name, mlir_dialect="linalg")
shark_module = SharkInference(vision_mlir, mlir_dialect="linalg")
shark_module.compile()
result = shark_module.forward((input,))
np.testing.assert_allclose(golden_out, result, rtol=1e-02, atol=1e-03)

View File

@@ -7,7 +7,7 @@ mlir_model, func_name, inputs, golden_out = download_model(
)
shark_module = SharkInference(
mlir_model, func_name, device="vulkan", mlir_dialect="linalg"
mlir_model, device="vulkan", mlir_dialect="linalg"
)
shark_module.compile()
result = shark_module.forward(inputs)

View File

@@ -43,7 +43,7 @@ def get_iree_device_args(device, extra_args=[]):
else:
device_num = 0
if device_uri[0] == "cpu":
if "cpu" in device:
from shark.iree_utils.cpu_utils import get_iree_cpu_args
data_tiling_flag = ["--iree-opt-data-tiling"]
@@ -55,6 +55,8 @@ def get_iree_device_args(device, extra_args=[]):
+ data_tiling_flag
+ u_kernel_flag
+ stack_size_flag
+ ["--iree-flow-enable-quantized-matmul-reassociation"]
+ ["--iree-llvmcpu-enable-quantized-matmul-reassociation"]
)
if device_uri[0] == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
@@ -292,14 +294,16 @@ def compile_module_to_flatbuffer(
extra_args,
model_name="None",
debug=False,
compile_str=False,
):
# Setup Compile arguments wrt to frontends.
input_type = ""
input_type = "auto"
args = get_iree_frontend_args(frontend)
args += get_iree_device_args(device, extra_args)
args += get_iree_common_args(debug=debug)
args += get_model_specific_args()
args += extra_args
args += shark_args.additional_compile_args
if frontend in ["tensorflow", "tf"]:
input_type = "auto"
@@ -310,10 +314,7 @@ def compile_module_to_flatbuffer(
elif frontend in ["tm_tensor"]:
input_type = ireec.InputType.TM_TENSOR
# TODO: make it simpler.
# Compile according to the input type, else just try compiling.
if input_type != "":
# Currently for MHLO/TOSA.
if compile_str:
flatbuffer_blob = ireec.compile_str(
module,
target_backends=[iree_target_map(device)],
@@ -321,9 +322,10 @@ def compile_module_to_flatbuffer(
input_type=input_type,
)
else:
# Currently for Torch.
flatbuffer_blob = ireec.compile_str(
module,
assert os.path.isfile(module)
flatbuffer_blob = ireec.compile_file(
str(module),
input_type=input_type,
target_backends=[iree_target_map(device)],
extra_args=args,
)
@@ -331,8 +333,12 @@ def compile_module_to_flatbuffer(
return flatbuffer_blob
def get_iree_module(flatbuffer_blob, device, device_idx=None):
def get_iree_module(
flatbuffer_blob, device, device_idx=None, rt_flags: list = []
):
# Returns the compiled module and the configs.
for flag in rt_flags:
ireert.flags.parse_flag(flag)
if device_idx is not None:
device = iree_device_map(device)
print("registering device id: ", device_idx)
@@ -354,9 +360,22 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
def load_vmfb_using_mmap(
flatbuffer_blob_or_path, device: str, device_idx: int = None
flatbuffer_blob_or_path,
device: str,
device_idx: int = None,
rt_flags: list = [],
):
print(f"Loading module {flatbuffer_blob_or_path}...")
if "task" in device:
print(
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
)
for flag in get_iree_cpu_rt_args():
rt_flags.append(flag)
for flag in rt_flags:
print(flag)
ireert.flags.parse_flags(flag)
if "rocm" in device:
device = "rocm"
with DetailLogger(timeout=2.5) as dl:
@@ -383,6 +402,7 @@ def load_vmfb_using_mmap(
)
for flag in get_iree_cpu_rt_args():
ireert.flags.parse_flags(flag)
# Now load vmfb.
# Two scenarios we have here :-
# 1. We either have the vmfb already saved and therefore pass the path of it.
@@ -402,7 +422,14 @@ def load_vmfb_using_mmap(
)
dl.log(f"mmap {flatbuffer_blob_or_path}")
ctx = ireert.SystemContext(config=config)
for flag in shark_args.additional_runtime_args:
ireert.flags.parse_flags(flag)
dl.log(f"ireert.SystemContext created")
if "vulkan" in device:
# Vulkan pipeline creation consumes significant amount of time.
print(
"\tCompiling Vulkan shaders. This may take a few minutes."
)
ctx.add_vm_module(mmaped_vmfb)
dl.log(f"module initialized")
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
@@ -423,13 +450,21 @@ def get_iree_compiled_module(
frontend: str = "torch",
model_config_path: str = None,
extra_args: list = [],
rt_flags: list = [],
device_idx: int = None,
mmap: bool = False,
debug: bool = False,
compile_str: bool = False,
):
"""Given a module returns the compiled .vmfb and configs"""
flatbuffer_blob = compile_module_to_flatbuffer(
module, device, frontend, model_config_path, extra_args, debug
module,
device,
frontend,
model_config_path,
extra_args,
debug,
compile_str,
)
temp_file_to_unlink = None
# TODO: Currently mmap=True control flow path has been switched off for mmap.
@@ -438,11 +473,14 @@ def get_iree_compiled_module(
# I'm getting hold of the name of the temporary file in `temp_file_to_unlink`.
if mmap:
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
flatbuffer_blob, device, device_idx
flatbuffer_blob, device, device_idx, rt_flags
)
else:
vmfb, config = get_iree_module(
flatbuffer_blob, device, device_idx=device_idx
flatbuffer_blob,
device,
device_idx=device_idx,
rt_flags=rt_flags,
)
ret_params = {
"vmfb": vmfb,
@@ -457,17 +495,21 @@ def load_flatbuffer(
device: str,
device_idx: int = None,
mmap: bool = False,
rt_flags: list = [],
):
temp_file_to_unlink = None
if mmap:
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
flatbuffer_path, device, device_idx
flatbuffer_path, device, device_idx, rt_flags
)
else:
with open(os.path.join(flatbuffer_path), "rb") as f:
flatbuffer_blob = f.read()
vmfb, config = get_iree_module(
flatbuffer_blob, device, device_idx=device_idx
flatbuffer_blob,
device,
device_idx=device_idx,
rt_flags=rt_flags,
)
ret_params = {
"vmfb": vmfb,
@@ -486,10 +528,17 @@ def export_iree_module_to_vmfb(
module_name: str = None,
extra_args: list = [],
debug: bool = False,
compile_str: bool = False,
):
# Compiles the module given specs and saves it as .vmfb file.
flatbuffer_blob = compile_module_to_flatbuffer(
module, device, mlir_dialect, model_config_path, extra_args, debug
module,
device,
mlir_dialect,
model_config_path,
extra_args,
debug,
compile_str,
)
if module_name is None:
device_name = (
@@ -497,9 +546,9 @@ def export_iree_module_to_vmfb(
)
module_name = f"{mlir_dialect}_{device_name}"
filename = os.path.join(directory, module_name + ".vmfb")
print(f"Saved vmfb in {filename}.")
with open(filename, "wb") as f:
f.write(flatbuffer_blob)
print(f"Saved vmfb in {filename}.")
return filename

View File

@@ -89,24 +89,10 @@ def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
def get_iree_metal_args(device_num=0, extra_args=[]):
# res_metal_flag = ["--iree-flow-demote-i64-to-i32"]
# Add any metal specific compilation flags here
res_metal_flag = []
metal_triple_flag = None
for arg in extra_args:
if "-iree-metal-target-platform=" in arg:
print(f"Using target triple {arg} from command line args")
metal_triple_flag = arg
break
if metal_triple_flag is None:
metal_triple_flag = get_metal_triple_flag(extra_args=extra_args)
if metal_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(
"-iree-vulkan-target-triple=m1-moltenvk-macos"
)
res_metal_flag.append(vulkan_target_env)
if len(extra_args) > 0:
res_metal_flag.extend(extra_args)
return res_metal_flag

View File

@@ -57,11 +57,8 @@ def get_version(triple):
@functools.cache
def get_extensions(triple):
def make_ext_list(ext_list):
res = ""
for e in ext_list:
res += e + ", "
res = f"[{res[:-2]}]"
return res
res = ", ".join(ext_list)
return f"[{res}]"
arch, product, os = triple
if arch == "m1":
@@ -119,7 +116,7 @@ def get_extensions(triple):
]
if get_vendor(triple) == "NVIDIA" or arch == "rdna3":
ext.append("VK_NV_cooperative_matrix")
ext.append("VK_KHR_cooperative_matrix")
if get_vendor(triple) == ["NVIDIA", "AMD", "Intel"]:
ext.append("VK_KHR_shader_integer_dot_product")
return make_ext_list(ext_list=ext)
@@ -247,7 +244,7 @@ def get_vulkan_target_capabilities(triple):
if arch == "rdna3":
# TODO: Get scope value
cap["coopmatCases"] = [
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>"
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope<Subgroup>"
]
if product == "rx5700xt":
@@ -468,9 +465,9 @@ def get_vulkan_target_capabilities(triple):
cap["variablePointersStorageBuffer"] = True
cap["coopmatCases"] = [
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, scope = #vk.scope<Subgroup>",
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>",
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, scope = #vk.scope<Subgroup>",
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, accSat = false, scope = #vk.scope<Subgroup>",
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope<Subgroup>",
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, accSat = false, scope = #vk.scope<Subgroup>",
]
elif arch == "adreno":
@@ -531,7 +528,7 @@ def get_vulkan_target_capabilities(triple):
cmc = ""
for case in v:
cmc += f"#vk.coop_matrix_props<{case}>, "
res += f"cooperativeMatrixPropertiesNV = [{cmc[:-2]}], "
res += f"cooperativeMatrixPropertiesKHR = [{cmc[:-2]}], "
else:
res += f"{k} = {get_comma_sep_str(v)}, "
else:

View File

@@ -23,11 +23,19 @@ from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
from shark.parser import shark_args
@functools.cache
def get_all_vulkan_devices():
from iree.runtime import get_driver
driver = get_driver("vulkan")
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
return [d["name"] for d in device_list_src]
@functools.cache
def get_vulkan_device_name(device_num=0):
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
vulkaninfo_dump = vulkaninfo_dump.split(linesep)
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
vulkaninfo_list = get_all_vulkan_devices()
if len(vulkaninfo_list) == 0:
raise ValueError("No device name found in VulkanInfo!")
if len(vulkaninfo_list) > 1:
@@ -111,6 +119,8 @@ def get_vulkan_target_triple(device_name):
# Windows: AMD Radeon RX 7900 XTX
elif all(x in device_name for x in ("RX", "7900")):
triple = f"rdna3-7900-{system_os}"
elif all(x in device_name for x in ("Radeon", "780M")):
triple = f"rdna3-780m-{system_os}"
elif all(x in device_name for x in ("AMD", "PRO", "W7900")):
triple = f"rdna3-w7900-{system_os}"
elif any(x in device_name for x in ("AMD", "Radeon")):
@@ -178,9 +188,7 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
@functools.cache
def get_iree_vulkan_runtime_flags():
vulkan_runtime_flags = [
f"--vulkan_large_heap_block_size={shark_args.vulkan_large_heap_block_size}",
f"--vulkan_validation_layers={'true' if shark_args.vulkan_validation_layers else 'false'}",
f"--vulkan_vma_allocator={'true' if shark_args.vulkan_vma_allocator else 'false'}",
]
return vulkan_runtime_flags

View File

@@ -14,8 +14,21 @@
import argparse
import os
import shlex
import subprocess
class SplitStrToListAction(argparse.Action):
def __init__(self, option_strings, dest, *args, **kwargs):
super(SplitStrToListAction, self).__init__(
option_strings=option_strings, dest=dest, *args, **kwargs
)
def __call__(self, parser, namespace, values, option_string=None):
del parser, option_string
setattr(namespace, self.dest, shlex.split(values[0]))
parser = argparse.ArgumentParser(description="SHARK runner.")
parser.add_argument(
@@ -24,6 +37,20 @@ parser.add_argument(
default="cpu",
help="Device on which shark_runner runs. options are cpu, cuda, and vulkan",
)
parser.add_argument(
"--additional_compile_args",
default=list(),
nargs=1,
action=SplitStrToListAction,
help="Additional arguments to pass to the compiler. These are appended as the last arguments.",
)
parser.add_argument(
"--additional_runtime_args",
default=list(),
nargs=1,
action=SplitStrToListAction,
help="Additional arguments to pass to the IREE runtime. These are appended as the last arguments.",
)
parser.add_argument(
"--enable_tf32",
type=bool,
@@ -133,13 +160,6 @@ parser.add_argument(
help="Profiles vulkan device and collects the .rdc info.",
)
parser.add_argument(
"--vulkan_large_heap_block_size",
default="2073741824",
help="Flag for setting VMA preferredLargeHeapBlockSize for "
"vulkan device, default is 4G.",
)
parser.add_argument(
"--vulkan_validation_layers",
default=False,
@@ -147,11 +167,4 @@ parser.add_argument(
help="Flag for disabling vulkan validation layers when benchmarking.",
)
parser.add_argument(
"--vulkan_vma_allocator",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for enabling / disabling Vulkan VMA Allocator.",
)
shark_args, unknown = parser.parse_known_args()
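
To illustrate the new action (the flag values below are only examples, reusing flags that already appear elsewhere in this diff), a single quoted string is shlex-split into individual runtime arguments:

import shlex

# SplitStrToListAction splits the quoted value, so
#   --additional_runtime_args="--executable_plugin=./plugin.so --vulkan_validation_layers=true"
# ends up as a Python list of separate flags handed to the IREE runtime.
value = "--executable_plugin=./plugin.so --vulkan_validation_layers=true"
print(shlex.split(value))
# ['--executable_plugin=./plugin.so', '--vulkan_validation_layers=true']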

View File

@@ -84,6 +84,13 @@ class SharkBenchmarkRunner(SharkRunner):
self.extra_args = extra_args
self.import_args = {}
self.temp_file_to_unlink = None
if not os.path.isfile(mlir_module):
print(
"Warning: Initializing SharkRunner with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to a MLIR module on your hard disk instead."
)
self.compile_str = True
else:
self.compile_str = False
SharkRunner.__init__(
self,
mlir_module,
@@ -98,6 +105,7 @@ class SharkBenchmarkRunner(SharkRunner):
".",
self.mlir_dialect,
extra_args=self.extra_args,
compile_str=self.compile_str,
)
params = load_flatbuffer(
self.vmfb_file,

View File

@@ -1,7 +1,7 @@
import os
import tempfile
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
import torch
import torch_mlir
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
@@ -130,10 +130,17 @@ def compile_int_precision(
mlir_module = mlir_module.encode("UTF-8")
mlir_module = BytesIO(mlir_module)
bytecode = mlir_module.read()
bytecode_path = os.path.join(
os.getcwd(), f"{extended_model_name}_linalg.mlirbc"
)
with open(bytecode_path, "wb") as f:
f.write(bytecode)
del bytecode
del mlir_module
print(f"Elided IR written for {extended_model_name}")
return bytecode
return bytecode_path
shark_module = SharkInference(
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
mlir_module=bytecode_path, device=device, mlir_dialect="tm_tensor"
)
extra_args = [
"--iree-hal-dump-executable-sources-to=ies",
@@ -148,7 +155,7 @@ def compile_int_precision(
generate_vmfb=generate_vmfb,
extra_args=extra_args,
),
bytecode,
bytecode_path,
)
@@ -201,7 +208,7 @@ def shark_compile_through_fx(
]
else:
(
mlir_module,
bytecode,
_,
) = import_with_fx(
model=model,
@@ -212,6 +219,11 @@ def shark_compile_through_fx(
model_name=extended_model_name,
save_dir=save_dir,
)
mlir_module = save_mlir(
mlir_module=bytecode,
model_name=extended_model_name,
mlir_dialect=mlir_dialect,
)
shark_module = SharkInference(
mlir_module,

View File

@@ -275,11 +275,11 @@ def download_model(
model_dir = os.path.join(WORKDIR, model_dir_name)
tuned_str = "" if tuned is None else "_" + tuned
suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
filename = os.path.join(model_dir, model_name + suffix)
mlir_filename = os.path.join(model_dir, model_name + suffix)
print(
f"Verifying that model artifacts were downloaded successfully to {filename}..."
f"Verifying that model artifacts were downloaded successfully to {mlir_filename}..."
)
if not os.path.exists(filename):
if not os.path.exists(mlir_filename):
from tank.generate_sharktank import gen_shark_files
print(
@@ -287,13 +287,11 @@ def download_model(
)
gen_shark_files(model_name, frontend, WORKDIR, import_args)
assert os.path.exists(filename), f"MLIR not found at {filename}"
with open(filename, mode="rb") as f:
mlir_file = f.read()
assert os.path.exists(mlir_filename), f"MLIR not found at {mlir_filename}"
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
inputs_tuple = tuple([inputs[key] for key in inputs])
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
return mlir_file, function_name, inputs_tuple, golden_out_tuple
return mlir_filename, function_name, inputs_tuple, golden_out_tuple

View File

@@ -1,6 +1,6 @@
from typing import Any, Dict, List, Tuple
from collections import defaultdict
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
import torchvision.models as models
import copy
import io
@@ -20,10 +20,16 @@ def shark_backend(fx_g: torch.fx.GraphModule, inputs, device: str = "cpu"):
bytecode_stream = io.BytesIO()
mlir_module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
bytecode_path = save_mlir(
bytecode,
model_name="shark_eager_module",
frontend="torch",
mlir_dialect="tm_tensor",
)
from shark.shark_inference import SharkInference
shark_module = SharkInference(
mlir_module=bytecode,
mlir_module=bytecode_path,
device=device,
mlir_dialect="tm_tensor",
)

View File

@@ -3,8 +3,8 @@ import json
import numpy as np
import torch_mlir
from iree.compiler import compile_str
from shark.shark_importer import import_with_fx, get_f16_inputs
from iree.compiler import compile_file
from shark.shark_importer import import_with_fx, get_f16_inputs, save_mlir
class GenerateConfigFile:
@@ -54,9 +54,15 @@ class GenerateConfigFile:
verbose=False,
)
module = module.operation.get_asm(large_elements_limit=4)
module_file = save_mlir(
module,
model_name="module_pre_split",
frontend="torch",
mlir_dialect="linalg",
)
compiled_module_str = str(
compile_str(
str(module),
compile_file(
module_file,
target_backends=[backend],
extra_args=[
"--compile-to=flow",

View File

@@ -451,6 +451,108 @@ def transform_fx(fx_g, quantized=False):
fx_g.graph.lint()
def gptq_transforms(fx_g):
import torch
for node in fx_g.graph.nodes:
if node.op == "call_function":
if node.target in [
torch.ops.aten.arange,
torch.ops.aten.empty,
torch.ops.aten.ones,
torch.ops.aten._to_copy,
]:
if node.kwargs.get("device") == torch.device(device="cuda:0"):
updated_kwargs = node.kwargs.copy()
updated_kwargs["device"] = torch.device(device="cpu")
node.kwargs = updated_kwargs
if node.target in [
torch.ops.aten._to_copy,
]:
if node.kwargs.get("dtype") == torch.bfloat16:
updated_kwargs = node.kwargs.copy()
updated_kwargs["dtype"] = torch.float16
node.kwargs = updated_kwargs
# Inputs of aten.native_layer_norm should be upcasted to fp32.
if node.target in [torch.ops.aten.native_layer_norm]:
with fx_g.graph.inserting_before(node):
new_node_arg0 = fx_g.graph.call_function(
torch.ops.prims.convert_element_type,
args=(node.args[0], torch.float32),
kwargs={},
)
node.args = (
new_node_arg0,
node.args[1],
node.args[2],
node.args[3],
node.args[4],
)
# Inputs of aten.mm should be upcasted to fp32.
if node.target in [torch.ops.aten.mm]:
with fx_g.graph.inserting_before(node):
new_node_arg0 = fx_g.graph.call_function(
torch.ops.prims.convert_element_type,
args=(node.args[0], torch.float32),
kwargs={},
)
new_node_arg1 = fx_g.graph.call_function(
torch.ops.prims.convert_element_type,
args=(node.args[1], torch.float32),
kwargs={},
)
node.args = (new_node_arg0, new_node_arg1)
# Outputs of aten.mm should be downcasted to fp16.
if type(node.args[0]) == torch.fx.node.Node and node.args[
0
].target in [torch.ops.aten.mm]:
with fx_g.graph.inserting_before(node):
tmp = node.args[0]
new_node = fx_g.graph.call_function(
torch.ops.aten._to_copy,
args=(node.args[0],),
kwargs={"dtype": torch.float16},
)
node.args[0].append(new_node)
node.args[0].replace_all_uses_with(new_node)
new_node.args = (tmp,)
new_node.kwargs = {"dtype": torch.float16}
# Inputs of aten._softmax should be upcasted to fp32.
if node.target in [torch.ops.aten._softmax]:
with fx_g.graph.inserting_before(node):
new_node_arg0 = fx_g.graph.call_function(
torch.ops.prims.convert_element_type,
args=(node.args[0], torch.float32),
kwargs={},
)
node.args = (new_node_arg0, node.args[1], node.args[2])
# Outputs of aten._softmax should be downcasted to fp16.
if (
type(node.args[0]) == torch.fx.node.Node
and node.args[0].target in [torch.ops.aten._softmax]
and node.target in [torch.ops.aten.expand]
):
with fx_g.graph.inserting_before(node):
tmp = node.args[0]
new_node = fx_g.graph.call_function(
torch.ops.aten._to_copy,
args=(node.args[0],),
kwargs={"dtype": torch.float16},
)
node.args[0].append(new_node)
node.args[0].replace_all_uses_with(new_node)
new_node.args = (tmp,)
new_node.kwargs = {"dtype": torch.float16}
fx_g.graph.lint()
# Doesn't replace the None type.
def change_fx_graph_return_to_tuple(fx_g):
for node in fx_g.graph.nodes:
@@ -504,6 +606,7 @@ def import_with_fx(
is_dynamic=False,
tracing_required=False,
precision="fp32",
is_gptq=False,
):
import torch
from torch.fx.experimental.proxy_tensor import make_fx
@@ -584,7 +687,7 @@ def import_with_fx(
torch.ops.aten.index_add,
torch.ops.aten.index_add_,
]
if precision in ["int4", "int8"]:
if precision in ["int4", "int8"] and not is_gptq:
from brevitas_examples.llm.llm_quant.export import (
block_quant_layer_level_manager,
)
@@ -653,6 +756,10 @@ def import_with_fx(
add_upcast(fx_g)
fx_g.recompile()
if is_gptq:
gptq_transforms(fx_g)
fx_g.recompile()
if mlir_type == "fx":
return fx_g
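
A hedged sketch of how the new is_gptq path is meant to be driven (the model object, inputs, and model name are assumptions, not taken from this diff):

# Hypothetical call: trace a GPTQ-quantized torch module; is_gptq=True skips the
# brevitas int4/int8 path and instead runs gptq_transforms() on the FX graph
# (device/dtype rewrites plus the fp32 upcasts shown above) before MLIR import.
mlir_module, func_name = import_with_fx(
    model=gptq_model,        # assumed: an already GPTQ-quantized torch.nn.Module
    inputs=example_inputs,   # assumed: tuple of example input tensors
    precision="int4",
    is_gptq=True,
    model_name="falcon_gptq",  # placeholder name
)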
@@ -685,3 +792,25 @@ def import_with_fx(
mlir_module, func_name = mlir_importer.import_mlir(mlir_type=mlir_type)
return mlir_module, func_name
# Saves a .mlir module python object to the directory 'dir' with 'model_name' and returns a path to the saved file.
def save_mlir(
mlir_module,
model_name,
mlir_dialect="linalg",
frontend="torch",
dir=tempfile.gettempdir(),
):
model_name_mlir = (
model_name + "_" + frontend + "_" + mlir_dialect + ".mlir"
)
if dir == "":
dir = tempfile.gettempdir()
mlir_path = os.path.join(dir, model_name_mlir)
print(f"saving {model_name_mlir} to {dir}")
if frontend == "torch":
with open(mlir_path, "wb") as mlir_file:
mlir_file.write(mlir_module)
return mlir_path
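
A minimal usage sketch for the new helper (the bytecode object and model name are placeholders): save_mlir serializes the module to disk and returns the path, which SharkInference can then take instead of the in-memory object:

# bytecode: bytes produced by import_with_fx / torch-mlir serialization
mlir_path = save_mlir(
    bytecode,
    model_name="my_model",   # placeholder
    frontend="torch",
    mlir_dialect="linalg",
)
print(mlir_path)  # e.g. <tempdir>/my_model_torch_linalg.mlir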

View File

@@ -39,7 +39,7 @@ class SharkInference:
Attributes
----------
mlir_module : str
mlir_module represented in string; modules from torch-mlir are serialized in bytecode format.
mlir_module or path represented in string; modules from torch-mlir are serialized in bytecode format.
device : str
device to execute the mlir_module on.
currently supports cpu, cuda, vulkan, and metal backends.
@@ -65,7 +65,7 @@ class SharkInference:
def __init__(
self,
mlir_module: bytes,
mlir_module,
device: str = "none",
mlir_dialect: str = "linalg",
is_benchmark: bool = False,
@@ -73,8 +73,17 @@ class SharkInference:
dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
device_idx: int = None,
mmap: bool = True,
rt_flags: list = [],
):
self.mlir_module = mlir_module
if mlir_module is not None:
if mlir_module and not os.path.isfile(mlir_module):
print(
"Warning: Initializing SharkInference with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to a MLIR module on your hard disk instead."
)
self.compile_str = True
else:
self.compile_str = False
self.device = shark_args.device if device == "none" else device
self.mlir_dialect = mlir_dialect
self.is_benchmark = is_benchmark
@@ -92,6 +101,7 @@ class SharkInference:
self.shark_runner = None
self.mmap = mmap
self.rt_flags = rt_flags
def compile(self, extra_args=[]):
if self.dispatch_benchmarks is not None:
@@ -126,6 +136,7 @@ class SharkInference:
self.mlir_dialect,
extra_args=extra_args,
device_idx=self.device_idx,
rt_flags=self.rt_flags,
)
if self.dispatch_benchmarks is not None:
@@ -203,6 +214,7 @@ class SharkInference:
module_name=module_name,
extra_args=extra_args,
debug=debug,
compile_str=self.compile_str,
)
# load and return the module.
@@ -211,12 +223,14 @@ class SharkInference:
device=self.device,
compile_vmfb=False,
extra_args=extra_args,
rt_flags=self.rt_flags,
)
params = load_flatbuffer(
path,
self.device,
self.device_idx,
mmap=self.mmap,
rt_flags=self.rt_flags,
)
self.shark_runner.iree_compilation_module = params["vmfb"]
self.shark_runner.iree_config = params["config"]
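
Putting the SharkInference changes together, a hedged end-to-end sketch (the module path and runtime flag value are illustrative):

from shark.shark_inference import SharkInference

# Passing a file path avoids duplicating the module in RAM at compile time,
# and rt_flags are forwarded to the IREE runtime when the vmfb is loaded.
shark_module = SharkInference(
    mlir_module="/path/to/my_model_torch_linalg.mlir",      # illustrative path
    device="cpu-task",
    mlir_dialect="tm_tensor",
    rt_flags=["--executable_plugin=/path/to/plugin.so"],    # illustrative flag
)
shark_module.compile()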

View File

@@ -45,7 +45,7 @@ class SharkRunner:
Attributes
----------
mlir_module : str
mlir_module represented in string.
mlir_module path, string, or bytecode.
device : str
device to execute the mlir_module on.
currently supports cpu, cuda, vulkan, and metal backends.
@@ -72,12 +72,22 @@ class SharkRunner:
extra_args: list = [],
compile_vmfb: bool = True,
device_idx: int = None,
rt_flags: list = [],
):
self.mlir_module = mlir_module
if self.mlir_module is not None:
if not os.path.isfile(mlir_module):
print(
"Warning: Initializing SharkRunner with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to a MLIR module on your hard disk instead."
)
self.compile_str = True
else:
self.compile_str = False
self.device = shark_args.device if device == "none" else device
self.mlir_dialect = mlir_dialect
self.extra_args = extra_args
self.device_idx = device_idx
self.rt_flags = rt_flags
if check_device_drivers(self.device):
print(device_driver_info(self.device))
@@ -91,6 +101,8 @@ class SharkRunner:
self.mlir_dialect,
extra_args=self.extra_args,
device_idx=self.device_idx,
rt_flags=self.rt_flags,
compile_str=self.compile_str,
)
self.iree_compilation_module = params["vmfb"]
self.iree_config = params["config"]

View File

@@ -15,7 +15,7 @@
from shark.parser import shark_args
from shark.shark_runner import SharkRunner
from shark.backward_makefx import MakeFxModule
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
import numpy as np
from tqdm import tqdm
import sys
@@ -84,6 +84,12 @@ class SharkTrainer:
training=True,
mlir_type=mlir_type,
)
mlir_module = save_mlir(
mlir_module,
model_name="shark_model",
frontend="torch",
mlir_dialect=mlir_type,
)
self.shark_runner = SharkRunner(
mlir_module,
self.device,

View File

@@ -1,24 +1,6 @@
resnet50,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
albert-base-v2,stablehlo,tf,1e-2,1e-2,default,None,False,False,False,"",""
roberta-base,stablehlo,tf,1e-02,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
bert-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","enabled_windows"
camembert-base,stablehlo,tf,1e-2,1e-3,default,None,True,True,True,"",""
dbmdz/convbert-base-turkish-cased,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,True,True,False,"https://github.com/iree-org/iree/issues/9971",""
distilbert-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
facebook/convnext-tiny-224,stablehlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,True,True,False,"https://github.com/nod-ai/SHARK/issues/311 & https://github.com/nod-ai/SHARK/issues/342","macos"
funnel-transformer/small,stablehlo,tf,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/201",""
google/electra-small-discriminator,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
google/mobilebert-uncased,stablehlo,tf,1e-2,1e-3,default,None,True,False,False,"Fails during iree-compile","macos"
google/vit-base-patch16-224,stablehlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,False,False,False,"",""
microsoft/MiniLM-L12-H384-uncased,stablehlo,tf,1e-2,1e-3,tf_hf,None,True,False,False,"Fails during iree-compile.",""
microsoft/layoutlm-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
microsoft/mpnet-base,stablehlo,tf,1e-2,1e-2,default,None,True,True,True,"",""
alexnet,linalg,torch,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/879",""
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,True,True,"",""
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
bert-large-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails during iree-compile.",""
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311",""
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
@@ -32,14 +14,8 @@ resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,True,True,"Numerics issues, awaiting cuda-independent fp16 integration",""
squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
efficientnet-v2-s,stablehlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
efficientnet_b0,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
efficientnet_b7,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
gpt2,stablehlo,tf,1e-2,1e-3,default,None,True,False,False,"","macos"
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.","macos"
t5-base,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported","macos"
t5-large,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"

View File

@@ -1,3 +1,4 @@
import argparse
import os
import torch
import numpy as np
@@ -10,21 +11,16 @@ from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from transformers import AutoTokenizer, OPTForCausalLM
OPT_MODEL = "opt-1.3b"
OPT_FS_NAME = "opt-1_3b"
MAX_SEQUENCE_LENGTH = 128
MAX_NEW_TOKENS = 60
def create_module(model_name, tokenizer, device):
opt_base_model = OPTForCausalLM.from_pretrained("facebook/" + model_name)
def create_module(model_name, tokenizer, device, args):
opt_base_model = OPTForCausalLM.from_pretrained(model_name)
opt_base_model.eval()
opt_model = OPTForCausalLMModel(opt_base_model)
encoded_inputs = tokenizer(
"What is the meaning of life?",
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
max_length=args.max_seq_len,
return_tensors="pt",
)
inputs = (
@@ -33,32 +29,34 @@ def create_module(model_name, tokenizer, device):
)
# np.save("model_inputs_0.npy", inputs[0])
# np.save("model_inputs_1.npy", inputs[1])
opt_fs_name = "-".join(
"_".join(args.model_name.split("/")[1].split("-")).split(".")
)
mlir_path = f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
mlir_path = f"./{opt_fs_name}_causallm_{args.max_seq_len}_torch.mlir"
if os.path.isfile(mlir_path):
with open(mlir_path, "r") as f:
model_mlir = f.read()
print(f"Loaded .mlir from {mlir_path}")
print(f"Found .mlir from {mlir_path}")
else:
(model_mlir, func_name) = import_with_fx(
model=opt_model,
inputs=inputs,
is_f16=False,
model_name=OPT_FS_NAME,
model_name=opt_fs_name,
return_str=True,
)
with open(mlir_path, "w") as f:
f.write(model_mlir)
print(f"Saved mlir at {mlir_path}")
del model_mlir
shark_module = SharkInference(
model_mlir,
mlir_path,
device=device,
mlir_dialect="tm_tensor",
is_benchmark=False,
)
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
vmfb_name = f"{opt_fs_name}_causallm_{args.max_seq_len}_torch_cpu"
shark_module.save_module(module_name=vmfb_name, debug=False)
vmfb_path = vmfb_name + ".vmfb"
return vmfb_path
@@ -72,11 +70,11 @@ def shouldStop(tokens):
return False
def generate_new_token(shark_model, tokenizer, new_text):
def generate_new_token(shark_model, tokenizer, new_text, args):
model_inputs = tokenizer(
new_text,
padding="max_length",
max_length=MAX_SEQUENCE_LENGTH,
max_length=args.max_seq_len,
truncation=True,
return_tensors="pt",
)
@@ -105,18 +103,56 @@ def generate_new_token(shark_model, tokenizer, new_text):
return ret_dict
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--max-seq-len", type=int, default=32)
parser.add_argument(
"--model-name",
help="Model name",
type=str,
choices=[
"facebook/opt-125m",
"facebook/opt-350m",
"facebook/opt-1.3b",
"facebook/opt-6.7b",
],
default="facebook/opt-1.3b",
)
parser.add_argument(
"--recompile",
help="If set, recompiles MLIR -> .vmfb",
action=argparse.BooleanOptionalAction,
default=False,
)
parser.add_argument(
"--plugin-path",
help="path to executable plugin",
type=str,
default=None,
)
args = parser.parse_args()
print("args={}".format(args))
return args
if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained(
"facebook/" + OPT_MODEL, use_fast=False
args = parse_args()
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False)
opt_fs_name = "-".join(
"_".join(args.model_name.split("/")[1].split("-")).split(".")
)
vmfb_path = (
f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_cpu-task.vmfb"
vmfb_path = f"./{opt_fs_name}_causallm_{args.max_seq_len}_torch_cpu.vmfb"
if args.plugin_path is not None:
rt_flags = [f"--executable_plugin={args.plugin_path}"]
else:
rt_flags = []
opt_shark_module = SharkInference(
mlir_module=None, device="cpu-task", rt_flags=rt_flags
)
opt_shark_module = SharkInference(mlir_module=None, device="cpu-task")
if os.path.isfile(vmfb_path):
opt_shark_module.load_module(vmfb_path)
else:
vmfb_path = create_module(OPT_MODEL, tokenizer, "cpu-task")
vmfb_path = create_module(args.model_name, tokenizer, "cpu-task", args)
opt_shark_module.load_module(vmfb_path)
while True:
try:
@@ -124,9 +160,9 @@ if __name__ == "__main__":
new_text_init = new_text
words_list = []
for i in range(MAX_NEW_TOKENS):
for i in range(args.max_seq_len):
generated_token_op = generate_new_token(
opt_shark_module, tokenizer, new_text
opt_shark_module, tokenizer, new_text, args
)
detok = generated_token_op["detok"]
stop_generation = generated_token_op["stop_generation"]
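
The opt_fs_name derivation used above maps a Hugging Face model id to a filesystem-safe stem; a small worked example:

model_name = "facebook/opt-1.3b"
opt_fs_name = "-".join("_".join(model_name.split("/")[1].split("-")).split("."))
print(opt_fs_name)  # opt_1-3b -> artifacts like ./opt_1-3b_causallm_32_torch.mlir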

View File

@@ -6,7 +6,7 @@ import numpy as np
from shark_opt_wrapper import OPTForCausalLMModel
from shark.iree_utils._common import check_device_drivers, device_driver_info
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
from transformers import AutoTokenizer, OPTForCausalLM
OPT_MODEL = "facebook/opt-1.3b"
@@ -57,9 +57,10 @@ class OPTModuleTester:
with open(mlir_path, "w") as f:
f.write(mlir_module)
print(f"Saved mlir at {mlir_path}")
del mlir_module
shark_module = SharkInference(
mlir_module,
mlir_path,
device=device,
mlir_dialect="tm_tensor",
is_benchmark=self.benchmark,

View File

@@ -18,14 +18,16 @@ import collections
import json
import os
import psutil
import resource
import time
import numpy as np
from typing import Tuple
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from transformers import AutoTokenizer, OPTForCausalLM
from shark_opt_wrapper import OPTForCausalLMModel
from shark.parser import shark_args
import iree.compiler as ireec
DEVICE = "cpu"
PLATFORM_SHARK = "shark"
@@ -64,12 +66,11 @@ def get_memory_info():
return process.memory_info()
def create_vmfb_module(
def import_mlir_module(
model_name: str,
tokenizer,
device: str,
max_seq_len: int,
recompile_shark: bool,
):
opt_base_model = OPTForCausalLM.from_pretrained(model_name)
opt_base_model.eval()
@@ -88,6 +89,27 @@ def create_vmfb_module(
# np.save("model_inputs_0.npy", inputs[0])
# np.save("model_inputs_1.npy", inputs[1])
opt_fs_name = get_opt_fs_name(model_name)
mlir_path = f"./{opt_fs_name}_causallm_{max_seq_len}_torch.mlir"
(model_mlir, func_name) = import_with_fx(
model=opt_model,
inputs=inputs,
is_f16=False,
model_name=opt_fs_name,
return_str=True,
)
with open(mlir_path, "w") as f:
f.write(model_mlir)
print(f"Saved mlir at {mlir_path}")
def create_vmfb_module(
model_name: str,
tokenizer,
device: str,
max_seq_len: int,
recompile_shark: bool,
):
opt_fs_name = get_opt_fs_name(model_name)
mlir_path = f"./{opt_fs_name}_causallm_{max_seq_len}_torch.mlir"
# If MLIR has already been loaded and recompilation is not requested, use
@@ -97,49 +119,48 @@ def create_vmfb_module(
# compilation time can be correctly measured only when MLIR has already been
# loaded.
assert not recompile_shark or has_mlir
if has_mlir:
with open(mlir_path, "r") as f:
model_mlir = f.read()
print(f"Loaded .mlir from {mlir_path}")
else:
(model_mlir, func_name) = import_with_fx(
model=opt_model,
inputs=inputs,
is_f16=False,
model_name=opt_fs_name,
return_str=True,
if not has_mlir:
import_mlir_module(
model_name,
tokenizer,
device,
max_seq_len,
)
with open(mlir_path, "w") as f:
f.write(model_mlir)
print(f"Saved mlir at {mlir_path}")
shark_module = SharkInference(
model_mlir,
mlir_path,
device=device,
mlir_dialect="tm_tensor",
is_benchmark=False,
rt_flags=[],
)
vmfb_name = (
f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}_tiled_ukernels"
)
vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}"
shark_module.save_module(module_name=vmfb_name)
vmfb_path = vmfb_name + ".vmfb"
return vmfb_path
def load_shark_model(
model_name: str, max_seq_len: int, recompile_shark: bool
model_name: str,
max_seq_len: int,
recompile_shark: bool,
plugin_path: str = None,
) -> ModelWrapper:
opt_fs_name = get_opt_fs_name(model_name)
vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}_tiled_ukernels.vmfb"
vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}.vmfb"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
if recompile_shark or not os.path.isfile(vmfb_name):
print(f"vmfb not found. compiling and saving to {vmfb_name}")
create_vmfb_module(
model_name, tokenizer, DEVICE, max_seq_len, recompile_shark
)
shark_module = SharkInference(mlir_module=None, device="cpu-task")
if plugin_path is not None:
rt_flags = [f"--executable_plugin={plugin_path}"]
else:
rt_flags = []
shark_module = SharkInference(
mlir_module=None, device="cpu-task", rt_flags=rt_flags
)
shark_module.load_module(vmfb_name)
return ModelWrapper(model=shark_module, tokenizer=tokenizer)
@@ -168,7 +189,7 @@ def save_json(data, filename):
def collect_huggingface_logits(
model_name: str, max_seq_len: int, save_json: bool
model_name: str, max_seq_len: int, to_save_json: bool
) -> Tuple[float, float]:
# Load
t0 = time.time()
@@ -194,11 +215,11 @@ def collect_huggingface_logits(
for idx, tokens in enumerate(tokenized_prompts):
print("prompt: {}".format(PROMPTS[idx]))
logits = run_huggingface_model(model_wrapper, tokens)
if save_json:
if to_save_json:
results.append([PROMPTS[idx], logits[0].tolist()])
run_time = time.time() - t0
print("--- Took {} seconds to run Huggingface.".format(run_time))
if save_json:
if to_save_json:
save_json(results, "/tmp/huggingface.json")
run_memory_info = get_memory_info()
return {
@@ -215,11 +236,17 @@ def collect_huggingface_logits(
def collect_shark_logits(
model_name: str, max_seq_len: int, recompile_shark: bool, save_json: bool
model_name: str,
max_seq_len: int,
recompile_shark: bool,
to_save_json: bool,
plugin_path: str,
) -> Tuple[float, float]:
# Load
t0 = time.time()
model_wrapper = load_shark_model(model_name, max_seq_len, recompile_shark)
model_wrapper = load_shark_model(
model_name, max_seq_len, recompile_shark, plugin_path
)
load_time = time.time() - t0
print("--- Took {} seconds to load Shark.".format(load_time))
load_memory_info = get_memory_info()
@@ -246,11 +273,11 @@ def collect_shark_logits(
print("prompt: {}".format(PROMPTS[idx]))
logits = run_shark_model(model_wrapper, tokens)
lst = [e.tolist() for e in logits]
if save_json:
if to_save_json:
results.append([PROMPTS[idx], lst])
run_time = time.time() - t0
print("--- Took {} seconds to run Shark.".format(run_time))
if save_json:
if to_save_json:
save_json(results, "/tmp/shark.json")
platform_postfix = "-compile" if recompile_shark else "-precompiled"
run_memory_info = get_memory_info()
@@ -316,6 +343,12 @@ def parse_args():
choices=[PLATFORM_SHARK, PLATFORM_HUGGINGFACE],
default=PLATFORM_SHARK,
)
parser.add_argument(
"--plugin-path",
help="path to executable plugin",
type=str,
default=None,
)
args = parser.parse_args()
print("args={}".format(args))
return args
@@ -329,6 +362,7 @@ if __name__ == "__main__":
args.max_seq_len,
args.recompile_shark,
args.save_json,
args.plugin_path,
)
print("# Summary: {}".format(json.dumps(shark_report)))
else:

View File

@@ -2,7 +2,7 @@ import os
import torch
from transformers import AutoTokenizer, OPTForCausalLM
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
from shark_opt_wrapper import OPTForCausalLMModel
model_name = "facebook/opt-1.3b"
@@ -25,11 +25,13 @@ inputs = (
model=model,
inputs=inputs,
is_f16=False,
debug=True,
model_name=model_name.split("/")[1],
save_dir=".",
)
mlir_module = save_mlir(
mlir_module,
model_name=model_name.split("/")[1],
frontend="torch",
mlir_dialect="linalg",
)
shark_module = SharkInference(
mlir_module,
device="cpu-sync",

View File

@@ -36,7 +36,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
get_hf_img_cls_model,
get_fp16_model,
)
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
with open(torch_model_list) as csvfile:
torch_reader = csv.reader(csvfile, delimiter=",")
@@ -130,133 +130,6 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
)
def save_tf_model(tf_model_list, local_tank_cache, import_args):
from tank.model_utils_tf import (
get_causal_image_model,
get_masked_lm_model,
get_causal_lm_model,
get_keras_model,
get_TFhf_model,
get_tfhf_seq2seq_model,
)
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
visible_default = tf.config.list_physical_devices("GPU")
try:
tf.config.set_visible_devices([], "GPU")
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
assert device.device_type != "GPU"
except:
# Invalid device or cannot modify virtual devices once initialized.
pass
with open(tf_model_list) as csvfile:
tf_reader = csv.reader(csvfile, delimiter=",")
fields = next(tf_reader)
for row in tf_reader:
tf_model_name = row[0]
model_type = row[1]
model = None
input = None
print(f"Generating artifacts for model {tf_model_name}")
if model_type == "hf":
model, input, _ = get_masked_lm_model(
tf_model_name, import_args
)
elif model_type == "img":
model, input, _ = get_causal_image_model(
tf_model_name, import_args
)
elif model_type == "keras":
model, input, _ = get_keras_model(tf_model_name, import_args)
elif model_type == "TFhf":
model, input, _ = get_TFhf_model(tf_model_name, import_args)
elif model_type == "tfhf_seq2seq":
model, input, _ = get_tfhf_seq2seq_model(
tf_model_name, import_args
)
elif model_type == "hf_causallm":
model, input, _ = get_causal_lm_model(
tf_model_name, import_args
)
tf_model_name = tf_model_name.replace("/", "_")
if import_args["batch_size"] != 1:
tf_model_dir = os.path.join(
local_tank_cache,
str(tf_model_name)
+ "_tf"
+ f"_BS{str(import_args['batch_size'])}",
)
else:
tf_model_dir = os.path.join(
local_tank_cache, str(tf_model_name) + "_tf"
)
os.makedirs(tf_model_dir, exist_ok=True)
mlir_importer = SharkImporter(
model,
inputs=input,
frontend="tf",
)
mlir_importer.import_debug(
is_dynamic=False,
dir=tf_model_dir,
model_name=tf_model_name,
)
def save_tflite_model(tflite_model_list, local_tank_cache, import_args):
from shark.tflite_utils import TFLitePreprocessor
with open(tflite_model_list) as csvfile:
tflite_reader = csv.reader(csvfile, delimiter=",")
for row in tflite_reader:
print("\n")
tflite_model_name = row[0]
tflite_model_link = row[1]
print("tflite_model_name", tflite_model_name)
print("tflite_model_link", tflite_model_link)
tflite_model_name_dir = os.path.join(
local_tank_cache, str(tflite_model_name) + "_tflite"
)
os.makedirs(tflite_model_name_dir, exist_ok=True)
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
# Preprocess to get SharkImporter input import_args
tflite_preprocessor = TFLitePreprocessor(str(tflite_model_name))
raw_model_file_path = tflite_preprocessor.get_raw_model_file()
inputs = tflite_preprocessor.get_inputs()
tflite_interpreter = tflite_preprocessor.get_interpreter()
# Use SharkImporter to get SharkInference input import_args
my_shark_importer = SharkImporter(
module=tflite_interpreter,
inputs=inputs,
frontend="tflite",
raw_model_file=raw_model_file_path,
)
my_shark_importer.import_debug(
dir=tflite_model_name_dir,
model_name=tflite_model_name,
func_name="main",
)
mlir_hash = create_hash(
os.path.join(
tflite_model_name_dir,
tflite_model_name + "_tflite" + ".mlir",
)
)
np.save(
os.path.join(tflite_model_name_dir, "hash"),
np.array(mlir_hash),
)
def check_requirements(frontend):
import importlib
@@ -265,10 +138,6 @@ def check_requirements(frontend):
tv_spec = importlib.util.find_spec("torchvision")
has_pkgs = tv_spec is not None
elif frontend in ["tensorflow", "tf"]:
tf_spec = importlib.util.find_spec("tensorflow")
has_pkgs = tf_spec is not None
return has_pkgs
@@ -287,27 +156,11 @@ def gen_shark_files(modelname, frontend, tank_dir, importer_args):
torch_model_csv = os.path.join(
os.path.dirname(__file__), "torch_model_list.csv"
)
tf_model_csv = os.path.join(
os.path.dirname(__file__), "tf_model_list.csv"
)
custom_model_csv = tempfile.NamedTemporaryFile(
dir=os.path.dirname(__file__),
delete=True,
)
# Create a temporary .csv with only the desired entry.
if frontend == "tf":
with open(tf_model_csv, mode="r") as src:
reader = csv.reader(src)
for row in reader:
if row[0] == modelname:
target = row
with open(custom_model_csv.name, mode="w") as trg:
writer = csv.writer(trg)
writer.writerow(["modelname", "src"])
writer.writerow(target)
save_tf_model(custom_model_csv.name, tank_dir, import_args)
elif frontend == "torch":
if frontend == "torch":
with open(torch_model_csv, mode="r") as src:
reader = csv.reader(src)
for row in reader:
@@ -341,18 +194,6 @@ if __name__ == "__main__":
# Please see: https://github.com/nod-ai/SHARK/blob/main/tank/torch_model_list.csv""",
# )
# parser.add_argument(
# "--tf_model_csv",
# type=lambda x: is_valid_file(x),
# default="./tank/tf_model_list.csv",
# help="Contains the file with tf model name and args.",
# )
# parser.add_argument(
# "--tflite_model_csv",
# type=lambda x: is_valid_file(x),
# default="./tank/tflite/tflite_model_list.csv",
# help="Contains the file with tf model name and args.",
# )
# parser.add_argument(
# "--ci_tank_dir",
# type=bool,
# default=False,
@@ -369,11 +210,5 @@ if __name__ == "__main__":
torch_model_csv = os.path.join(
os.path.dirname(__file__), "torch_model_list.csv"
)
tf_model_csv = os.path.join(os.path.dirname(__file__), "tf_model_list.csv")
tflite_model_csv = os.path.join(
os.path.dirname(__file__), "tflite", "tflite_model_list.csv"
)
save_torch_model(torch_model_csv, WORKDIR, import_args)
# save_tf_model(tf_model_csv, WORKDIR, import_args)
# save_tflite_model(tflite_model_csv, WORKDIR, import_args)

View File

@@ -1,28 +0,0 @@
model_name, model_type
albert-base-v2,hf
bert-base-uncased,hf
camembert-base,hf
dbmdz/convbert-base-turkish-cased,hf
distilbert-base-uncased,hf
google/electra-small-discriminator,hf
funnel-transformer/small,hf
microsoft/layoutlm-base-uncased,hf
google/mobilebert-uncased,hf
microsoft/mpnet-base,hf
roberta-base,hf
resnet50,keras
xlm-roberta-base,hf
microsoft/MiniLM-L12-H384-uncased,TFhf
funnel-transformer/small,hf
microsoft/mpnet-base,hf
facebook/convnext-tiny-224,img
google/vit-base-patch16-224,img
efficientnet-v2-s,keras
bert-large-uncased,hf
t5-base,tfhf_seq2seq
t5-large,tfhf_seq2seq
efficientnet_b0,keras
efficientnet_b7,keras
gpt2,hf_causallm
t5-base,tfhf_seq2seq
t5-large,tfhf_seq2seq