mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
168 Commits
20230820.9
...
fix-shardi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
eab2194ca1 | ||
|
|
93f583f0be | ||
|
|
e5ed167f03 | ||
|
|
051ba5de63 | ||
|
|
6384780d16 | ||
|
|
db0c53ae59 | ||
|
|
ce9ce3a7c8 | ||
|
|
d72da3801f | ||
|
|
9c50edc664 | ||
|
|
a1b7110550 | ||
|
|
ff15fd74f6 | ||
|
|
552b2c3ee3 | ||
|
|
795fc33001 | ||
|
|
2910841fe6 | ||
|
|
396a054856 | ||
|
|
5c66948d4f | ||
|
|
ed3dda94c0 | ||
|
|
d31d28b082 | ||
|
|
78c607e1d3 | ||
|
|
666e601dd9 | ||
|
|
ca58908e5b | ||
|
|
1f5b39f56e | ||
|
|
2da31c4109 | ||
|
|
da50a16242 | ||
|
|
ce38d49f05 | ||
|
|
2f780f0d38 | ||
|
|
d051c3a4a7 | ||
|
|
1b11c82c9d | ||
|
|
80a33d427f | ||
|
|
4125a26294 | ||
|
|
905d0103ff | ||
|
|
192b3b2c61 | ||
|
|
8f9adc4a2a | ||
|
|
70817bb50a | ||
|
|
dd37c26d36 | ||
|
|
a708879c6c | ||
|
|
bb1b49eb6f | ||
|
|
f6d41affd9 | ||
|
|
c2163488d8 | ||
|
|
54bff4611d | ||
|
|
11510d5111 | ||
|
|
32cab73a29 | ||
|
|
392bade0bf | ||
|
|
91df5f0613 | ||
|
|
df20cf9c8a | ||
|
|
c4a908c3ea | ||
|
|
6285430d8a | ||
|
|
51afe19e20 | ||
|
|
31005bcf73 | ||
|
|
f41ad87ef6 | ||
|
|
d811524a00 | ||
|
|
51e1bd1c5d | ||
|
|
db89b1bdc1 | ||
|
|
2754e2e257 | ||
|
|
ab0e870c43 | ||
|
|
fb30e8c226 | ||
|
|
a07d542400 | ||
|
|
ad55cb696f | ||
|
|
488a172292 | ||
|
|
500c4f2306 | ||
|
|
92b694db4d | ||
|
|
322874f7f9 | ||
|
|
5001db3415 | ||
|
|
71846344a2 | ||
|
|
72e27c96fc | ||
|
|
7963abb8ec | ||
|
|
98244232dd | ||
|
|
679a452139 | ||
|
|
72c0a8abc8 | ||
|
|
ea920f2955 | ||
|
|
486202377a | ||
|
|
0c38c33d0a | ||
|
|
841773fa32 | ||
|
|
0361db46f9 | ||
|
|
a012433ffd | ||
|
|
5061193da3 | ||
|
|
bff48924be | ||
|
|
825b36cbdd | ||
|
|
134441957d | ||
|
|
7cd14fdc47 | ||
|
|
e6cb5cef57 | ||
|
|
66abee8e5b | ||
|
|
4797bb89f5 | ||
|
|
205e57683a | ||
|
|
2866d665ee | ||
|
|
71d25ec5d8 | ||
|
|
202ffff67b | ||
|
|
0b77059628 | ||
|
|
a208302bb9 | ||
|
|
b83d32fafe | ||
|
|
0a618e1863 | ||
|
|
a731eb6ed4 | ||
|
|
2004d16945 | ||
|
|
6e409bfb77 | ||
|
|
77727d149c | ||
|
|
66f6e79d68 | ||
|
|
3b825579a7 | ||
|
|
9f0a421764 | ||
|
|
c28682110c | ||
|
|
caf6cc5d8f | ||
|
|
8614a18474 | ||
|
|
86c1c0c215 | ||
|
|
8bb364bcb8 | ||
|
|
7abddd01ec | ||
|
|
2a451fa0c7 | ||
|
|
9c4610b9da | ||
|
|
a38cc9d216 | ||
|
|
1c382449ec | ||
|
|
7cc9b3f8e8 | ||
|
|
e54517e967 | ||
|
|
326327a799 | ||
|
|
785b65c7b0 | ||
|
|
0d16c81687 | ||
|
|
8dd7850c69 | ||
|
|
e930ba85b4 | ||
|
|
cd732e7a38 | ||
|
|
8e0f8b3227 | ||
|
|
b8210ef796 | ||
|
|
94594542a9 | ||
|
|
82f833e87d | ||
|
|
c9d6870105 | ||
|
|
4fec03a6cc | ||
|
|
9a27f51378 | ||
|
|
ad1a0f35ff | ||
|
|
6773278ec2 | ||
|
|
9a0efffcca | ||
|
|
61c6f153d9 | ||
|
|
effd42e8f5 | ||
|
|
b5fbb1a8a0 | ||
|
|
ded74d09cd | ||
|
|
79267931c1 | ||
|
|
9eceba69b7 | ||
|
|
ca609afb6a | ||
|
|
11bdce9790 | ||
|
|
684943a4a6 | ||
|
|
b817bb8455 | ||
|
|
780f520f02 | ||
|
|
c61b6f8d65 | ||
|
|
c854208d49 | ||
|
|
c5dcfc1f13 | ||
|
|
bde63ee8ae | ||
|
|
9681d494eb | ||
|
|
ede6bf83e2 | ||
|
|
2c2693fb7d | ||
|
|
1d31b2b2c6 | ||
|
|
d2f64eefa3 | ||
|
|
87ae14b6ff | ||
|
|
1ccafa1fc1 | ||
|
|
4c3d8a0a7f | ||
|
|
3601dc7c3b | ||
|
|
671881cf87 | ||
|
|
4e9be6be59 | ||
|
|
9c8cbaf498 | ||
|
|
9e348a114e | ||
|
|
51f90a4d56 | ||
|
|
310d5d0a49 | ||
|
|
9697981004 | ||
|
|
450c231171 | ||
|
|
07f6f4a2f7 | ||
|
|
610813c72f | ||
|
|
8e3860c9e6 | ||
|
|
e37d6720eb | ||
|
|
16160d9a7d | ||
|
|
79075a1a07 | ||
|
|
db990826d3 | ||
|
|
7ee3e4ba5d | ||
|
|
05889a8fe1 | ||
|
|
b87efe7686 |
15
.github/workflows/test-models.yml
vendored
15
.github/workflows/test-models.yml
vendored
@@ -112,7 +112,7 @@ jobs:
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
|
||||
pytest --benchmark=native --update_tank -k cpu
|
||||
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
|
||||
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
|
||||
python build_tools/vicuna_testing.py
|
||||
@@ -121,9 +121,9 @@ jobs:
|
||||
if: matrix.suite == 'cuda'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
|
||||
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
|
||||
pytest --benchmark=native --update_tank -k cuda
|
||||
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
|
||||
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
|
||||
# Disabled due to black image bug
|
||||
@@ -137,16 +137,17 @@ jobs:
|
||||
source shark.venv/bin/activate
|
||||
echo $PATH
|
||||
pip list | grep -E "torch|iree"
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
|
||||
# disabled due to a low-visibility memory issue with pytest on macos.
|
||||
# pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
|
||||
|
||||
- name: Validate Vulkan Models (a100)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
|
||||
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --forked --benchmark="native" --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
|
||||
python build_tools/stable_diffusion_testing.py --device=vulkan
|
||||
pytest --update_tank -k vulkan
|
||||
python build_tools/stable_diffusion_testing.py --device=vulkan --no-exit_on_fail
|
||||
|
||||
- name: Validate Vulkan Models (Windows)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
|
||||
|
||||
11
.gitignore
vendored
11
.gitignore
vendored
@@ -182,7 +182,7 @@ generated_imgs/
|
||||
|
||||
# Custom model related artefacts
|
||||
variants.json
|
||||
models/
|
||||
/models/
|
||||
|
||||
# models folder
|
||||
apps/stable_diffusion/web/models/
|
||||
@@ -193,3 +193,12 @@ stencil_annotator/
|
||||
# For DocuChat
|
||||
apps/language_models/langchain/user_path/
|
||||
db_dir_UserData
|
||||
|
||||
# Embeded browser cache and other
|
||||
apps/stable_diffusion/web/EBWebView/
|
||||
|
||||
# Llama2 tokenizer configs
|
||||
llama2_tokenizer_configs/
|
||||
|
||||
# Webview2 runtime artefacts
|
||||
EBWebView/
|
||||
|
||||
12
README.md
12
README.md
@@ -10,7 +10,7 @@ High Performance Machine Learning Distribution
|
||||
<summary>Prerequisites - Drivers </summary>
|
||||
|
||||
#### Install your Windows hardware drivers
|
||||
* [AMD RDNA Users] Download the latest driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-2-1).
|
||||
* [AMD RDNA Users] Download the latest driver (23.2.1 is the oldest supported) [here](https://www.amd.com/en/support).
|
||||
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
|
||||
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
|
||||
|
||||
@@ -254,7 +254,6 @@ if you want to instead incorporate this into a python script, you can pass the `
|
||||
```
|
||||
shark_module = SharkInference(
|
||||
mlir_model,
|
||||
func_name,
|
||||
device=args.device,
|
||||
mlir_dialect="tm_tensor",
|
||||
dispatch_benchmarks="all",
|
||||
@@ -297,7 +296,7 @@ torch_mlir, func_name = mlir_importer.import_mlir(tracing_required=True)
|
||||
# SharkInference accepts mlir in linalg, mhlo, and tosa dialect.
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
shark_module = SharkInference(torch_mlir, func_name, device="cpu", mlir_dialect="linalg")
|
||||
shark_module = SharkInference(torch_mlir, device="cpu", mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((input))
|
||||
|
||||
@@ -320,12 +319,17 @@ mhlo_ir = r"""builtin.module {
|
||||
|
||||
arg0 = np.ones((1, 4)).astype(np.float32)
|
||||
arg1 = np.ones((4, 1)).astype(np.float32)
|
||||
shark_module = SharkInference(mhlo_ir, func_name="forward", device="cpu", mlir_dialect="mhlo")
|
||||
shark_module = SharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((arg0, arg1))
|
||||
```
|
||||
</details>
|
||||
|
||||
## Examples Using the REST API
|
||||
|
||||
* [Setting up SHARK for use with Blender](./docs/shark_sd_blender.md)
|
||||
* [Setting up SHARK for use with Koboldcpp](./docs/shark_sd_koboldcpp.md)
|
||||
|
||||
## Supported and Validated Models
|
||||
|
||||
SHARK is maintained to support the latest innovations in ML Models:
|
||||
|
||||
@@ -20,12 +20,12 @@ import gc
|
||||
from pathlib import Path
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
# Brevitas
|
||||
from typing import List, Tuple
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.common.generative.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
@@ -101,7 +101,7 @@ class H2OGPTModel(torch.nn.Module):
|
||||
dtype=torch.float32,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_scale_precision="float_scale",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=128,
|
||||
@@ -237,7 +237,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
module,
|
||||
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
else:
|
||||
@@ -256,6 +256,11 @@ class H2OGPTSHARKModel(torch.nn.Module):
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
del module
|
||||
|
||||
bytecode = save_mlir(
|
||||
bytecode,
|
||||
model_name=f"h2ogpt_{precision}",
|
||||
frontend="torch",
|
||||
)
|
||||
return bytecode
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
|
||||
@@ -65,8 +65,8 @@ tiktoken==0.4.0
|
||||
openai==0.27.8
|
||||
|
||||
# optional for chat with PDF
|
||||
langchain==0.0.202
|
||||
pypdf==3.12.2
|
||||
langchain==0.0.329
|
||||
pypdf==3.17.0
|
||||
# avoid textract, requires old six
|
||||
#textract==1.6.5
|
||||
|
||||
|
||||
442
apps/language_models/scripts/llama_ir_conversion_utils.py
Normal file
442
apps/language_models/scripts/llama_ir_conversion_utils.py
Normal file
@@ -0,0 +1,442 @@
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
from argparse import RawTextHelpFormatter
|
||||
import re, gc
|
||||
|
||||
"""
|
||||
This script can be used as a standalone utility to convert IRs to dynamic + combine them.
|
||||
Following are the various ways this script can be used :-
|
||||
a. To convert a single Linalg IR to dynamic IR:
|
||||
--dynamic --first_ir_path=<PATH TO FIRST IR>
|
||||
b. To convert two Linalg IRs to dynamic IR:
|
||||
--dynamic --first_ir_path=<PATH TO SECOND IR> --first_ir_path=<PATH TO SECOND IR>
|
||||
c. To combine two Linalg IRs into one:
|
||||
--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
|
||||
d. To convert both IRs into dynamic as well as combine the IRs:
|
||||
--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
|
||||
|
||||
NOTE: For dynamic you'll also need to provide the following set of flags:-
|
||||
i. For First Llama : --dynamic_input_size (DEFAULT: 19)
|
||||
ii. For Second Llama: --model_name (DEFAULT: llama2_7b)
|
||||
--precision (DEFAULT: 'int4')
|
||||
You may use --save_dynamic to also save the dynamic IR in option d above.
|
||||
Else for option a. and b. the dynamic IR(s) will get saved by default.
|
||||
"""
|
||||
|
||||
|
||||
def combine_mlir_scripts(
|
||||
first_vicuna_mlir,
|
||||
second_vicuna_mlir,
|
||||
output_name,
|
||||
return_ir=True,
|
||||
):
|
||||
print(f"[DEBUG] combining first and second mlir")
|
||||
print(f"[DEBUG] output_name = {output_name}")
|
||||
maps1 = []
|
||||
maps2 = []
|
||||
constants = set()
|
||||
f1 = []
|
||||
f2 = []
|
||||
|
||||
print(f"[DEBUG] processing first vicuna mlir")
|
||||
first_vicuna_mlir = first_vicuna_mlir.splitlines()
|
||||
while first_vicuna_mlir:
|
||||
line = first_vicuna_mlir.pop(0)
|
||||
if re.search("#map\d*\s*=", line):
|
||||
maps1.append(line)
|
||||
elif re.search("arith.constant", line):
|
||||
constants.add(line)
|
||||
elif not re.search("module", line):
|
||||
line = re.sub("forward", "first_vicuna_forward", line)
|
||||
f1.append(line)
|
||||
f1 = f1[:-1]
|
||||
del first_vicuna_mlir
|
||||
gc.collect()
|
||||
|
||||
for i, map_line in enumerate(maps1):
|
||||
map_var = map_line.split(" ")[0]
|
||||
map_line = re.sub(f"{map_var}(?!\d)", map_var + "_0", map_line)
|
||||
maps1[i] = map_line
|
||||
f1 = [
|
||||
re.sub(f"{map_var}(?!\d)", map_var + "_0", func_line)
|
||||
for func_line in f1
|
||||
]
|
||||
|
||||
print(f"[DEBUG] processing second vicuna mlir")
|
||||
second_vicuna_mlir = second_vicuna_mlir.splitlines()
|
||||
while second_vicuna_mlir:
|
||||
line = second_vicuna_mlir.pop(0)
|
||||
if re.search("#map\d*\s*=", line):
|
||||
maps2.append(line)
|
||||
elif "global_seed" in line:
|
||||
continue
|
||||
elif re.search("arith.constant", line):
|
||||
constants.add(line)
|
||||
elif not re.search("module", line):
|
||||
line = re.sub("forward", "second_vicuna_forward", line)
|
||||
f2.append(line)
|
||||
f2 = f2[:-1]
|
||||
del second_vicuna_mlir
|
||||
gc.collect()
|
||||
|
||||
for i, map_line in enumerate(maps2):
|
||||
map_var = map_line.split(" ")[0]
|
||||
map_line = re.sub(f"{map_var}(?!\d)", map_var + "_1", map_line)
|
||||
maps2[i] = map_line
|
||||
f2 = [
|
||||
re.sub(f"{map_var}(?!\d)", map_var + "_1", func_line)
|
||||
for func_line in f2
|
||||
]
|
||||
|
||||
module_start = 'module attributes {torch.debug_module_name = "_lambda"} {'
|
||||
module_end = "}"
|
||||
|
||||
global_vars = []
|
||||
vnames = []
|
||||
global_var_loading1 = []
|
||||
global_var_loading2 = []
|
||||
|
||||
print(f"[DEBUG] processing constants")
|
||||
counter = 0
|
||||
constants = list(constants)
|
||||
while constants:
|
||||
constant = constants.pop(0)
|
||||
vname, vbody = constant.split("=")
|
||||
vname = re.sub("%", "", vname)
|
||||
vname = vname.strip()
|
||||
vbody = re.sub("arith.constant", "", vbody)
|
||||
vbody = vbody.strip()
|
||||
if len(vbody.split(":")) < 2:
|
||||
print(constant)
|
||||
vdtype = vbody.split(":")[-1].strip()
|
||||
fixed_vdtype = vdtype
|
||||
if "c1_i64" in vname:
|
||||
print(constant)
|
||||
counter += 1
|
||||
if counter == 2:
|
||||
counter = 0
|
||||
print("detected duplicate")
|
||||
continue
|
||||
vnames.append(vname)
|
||||
if "true" not in vname:
|
||||
global_vars.append(
|
||||
f"ml_program.global private @{vname}({vbody}) : {fixed_vdtype}"
|
||||
)
|
||||
global_var_loading1.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
|
||||
)
|
||||
global_var_loading2.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
|
||||
)
|
||||
else:
|
||||
global_vars.append(
|
||||
f"ml_program.global private @{vname}({vbody}) : i1"
|
||||
)
|
||||
global_var_loading1.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
|
||||
)
|
||||
global_var_loading2.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
|
||||
)
|
||||
|
||||
new_f1, new_f2 = [], []
|
||||
|
||||
print(f"[DEBUG] processing f1")
|
||||
for line in f1:
|
||||
if "func.func" in line:
|
||||
new_f1.append(line)
|
||||
for global_var in global_var_loading1:
|
||||
new_f1.append(global_var)
|
||||
else:
|
||||
new_f1.append(line)
|
||||
|
||||
print(f"[DEBUG] processing f2")
|
||||
for line in f2:
|
||||
if "func.func" in line:
|
||||
new_f2.append(line)
|
||||
for global_var in global_var_loading2:
|
||||
if (
|
||||
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
|
||||
in global_var
|
||||
):
|
||||
print(global_var)
|
||||
new_f2.append(global_var)
|
||||
else:
|
||||
new_f2.append(line)
|
||||
|
||||
f1 = new_f1
|
||||
f2 = new_f2
|
||||
|
||||
del new_f1
|
||||
del new_f2
|
||||
gc.collect()
|
||||
|
||||
print(
|
||||
[
|
||||
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in x
|
||||
for x in [maps1, maps2, global_vars, f1, f2]
|
||||
]
|
||||
)
|
||||
|
||||
# doing it this way rather than assembling the whole string
|
||||
# to prevent OOM with 64GiB RAM when encoding the file.
|
||||
|
||||
print(f"[DEBUG] Saving mlir to {output_name}")
|
||||
with open(output_name, "w+") as f_:
|
||||
f_.writelines(line + "\n" for line in maps1)
|
||||
f_.writelines(line + "\n" for line in maps2)
|
||||
f_.writelines(line + "\n" for line in [module_start])
|
||||
f_.writelines(line + "\n" for line in global_vars)
|
||||
f_.writelines(line + "\n" for line in f1)
|
||||
f_.writelines(line + "\n" for line in f2)
|
||||
f_.writelines(line + "\n" for line in [module_end])
|
||||
|
||||
del maps1
|
||||
del maps2
|
||||
del module_start
|
||||
del global_vars
|
||||
del f1
|
||||
del f2
|
||||
del module_end
|
||||
gc.collect()
|
||||
|
||||
if return_ir:
|
||||
print(f"[DEBUG] Reading combined mlir back in")
|
||||
with open(output_name, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def write_in_dynamic_inputs0(module, dynamic_input_size):
|
||||
print("[DEBUG] writing dynamic inputs to first vicuna")
|
||||
# Current solution for ensuring mlir files support dynamic inputs
|
||||
# TODO: find a more elegant way to implement this
|
||||
new_lines = []
|
||||
module = module.splitlines()
|
||||
while module:
|
||||
line = module.pop(0)
|
||||
line = re.sub(f"{dynamic_input_size}x", "?x", line)
|
||||
if "?x" in line:
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
|
||||
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub(f"c{dynamic_input_size}", "dim", line)
|
||||
if "%0 = tensor.empty(%dim) : tensor<?xi64>" in line:
|
||||
new_lines.append("%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>")
|
||||
if "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" in line:
|
||||
continue
|
||||
|
||||
new_lines.append(line)
|
||||
return "\n".join(new_lines)
|
||||
|
||||
|
||||
def write_in_dynamic_inputs1(module, model_name, precision):
|
||||
print("[DEBUG] writing dynamic inputs to second vicuna")
|
||||
|
||||
def remove_constant_dim(line):
|
||||
if "c19_i64" in line:
|
||||
line = re.sub("c19_i64", "dim_i64", line)
|
||||
if "19x" in line:
|
||||
line = re.sub("19x", "?x", line)
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)",
|
||||
"tensor.empty(%dim, %dim)",
|
||||
line,
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub("c19", "dim", line)
|
||||
if " 19," in line:
|
||||
line = re.sub(" 19,", " %dim,", line)
|
||||
if "x20x" in line or "<20x" in line:
|
||||
line = re.sub("20x", "?x", line)
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dimp1)", line)
|
||||
if " 20," in line:
|
||||
line = re.sub(" 20,", " %dimp1,", line)
|
||||
return line
|
||||
|
||||
module = module.splitlines()
|
||||
new_lines = []
|
||||
|
||||
# Using a while loop and the pop method to avoid creating a copy of module
|
||||
if "llama2_13b" in model_name:
|
||||
pkv_tensor_shape = "tensor<1x40x?x128x"
|
||||
elif "llama2_70b" in model_name:
|
||||
pkv_tensor_shape = "tensor<1x8x?x128x"
|
||||
else:
|
||||
pkv_tensor_shape = "tensor<1x32x?x128x"
|
||||
if precision in ["fp16", "int4", "int8"]:
|
||||
pkv_tensor_shape += "f16>"
|
||||
else:
|
||||
pkv_tensor_shape += "f32>"
|
||||
|
||||
while module:
|
||||
line = module.pop(0)
|
||||
if "%c19_i64 = arith.constant 19 : i64" in line:
|
||||
new_lines.append("%c2 = arith.constant 2 : index")
|
||||
new_lines.append(
|
||||
f"%dim_4_int = tensor.dim %arg1, %c2 : {pkv_tensor_shape}"
|
||||
)
|
||||
new_lines.append(
|
||||
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
|
||||
)
|
||||
continue
|
||||
if "%c2 = arith.constant 2 : index" in line:
|
||||
continue
|
||||
if "%c20_i64 = arith.constant 20 : i64" in line:
|
||||
new_lines.append("%c1_i64 = arith.constant 1 : i64")
|
||||
new_lines.append("%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64")
|
||||
new_lines.append(
|
||||
"%dimp1 = arith.index_cast %c20_i64 : i64 to index"
|
||||
)
|
||||
continue
|
||||
line = remove_constant_dim(line)
|
||||
new_lines.append(line)
|
||||
|
||||
return "\n".join(new_lines)
|
||||
|
||||
|
||||
def save_dynamic_ir(ir_to_save, output_file):
|
||||
if not ir_to_save:
|
||||
return
|
||||
# We only get string output from the dynamic conversion utility.
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
with open(output_file, "w") as f:
|
||||
with redirect_stdout(f):
|
||||
print(ir_to_save)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="llama ir utility",
|
||||
description="\tThis script can be used as a standalone utility to convert IRs to dynamic + combine them.\n"
|
||||
+ "\tFollowing are the various ways this script can be used :-\n"
|
||||
+ "\t\ta. To convert a single Linalg IR to dynamic IR:\n"
|
||||
+ "\t\t\t--dynamic --first_ir_path=<PATH TO FIRST IR>\n"
|
||||
+ "\t\tb. To convert two Linalg IRs to dynamic IR:\n"
|
||||
+ "\t\t\t--dynamic --first_ir_path=<PATH TO SECOND IR> --first_ir_path=<PATH TO SECOND IR>\n"
|
||||
+ "\t\tc. To combine two Linalg IRs into one:\n"
|
||||
+ "\t\t\t--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n"
|
||||
+ "\t\td. To convert both IRs into dynamic as well as combine the IRs:\n"
|
||||
+ "\t\t\t--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n\n"
|
||||
+ "\tNOTE: For dynamic you'll also need to provide the following set of flags:-\n"
|
||||
+ "\t\t i. For First Llama : --dynamic_input_size (DEFAULT: 19)\n"
|
||||
+ "\t\tii. For Second Llama: --model_name (DEFAULT: llama2_7b)\n"
|
||||
+ "\t\t\t--precision (DEFAULT: 'int4')\n"
|
||||
+ "\t You may use --save_dynamic to also save the dynamic IR in option d above.\n"
|
||||
+ "\t Else for option a. and b. the dynamic IR(s) will get saved by default.\n",
|
||||
formatter_class=RawTextHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
"-p",
|
||||
default="int4",
|
||||
choices=["fp32", "fp16", "int8", "int4"],
|
||||
help="Precision of the concerned IR",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default="llama2_7b",
|
||||
choices=["vicuna", "llama2_7b", "llama2_13b", "llama2_70b"],
|
||||
help="Specify which model to run.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--first_ir_path",
|
||||
default=None,
|
||||
help="path to first llama mlir file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--second_ir_path",
|
||||
default=None,
|
||||
help="path to second llama mlir file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dynamic_input_size",
|
||||
type=int,
|
||||
default=19,
|
||||
help="Specify the static input size to replace with dynamic dim.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dynamic",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Converts the IR(s) to dynamic",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_dynamic",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Save the individual IR(s) after converting to dynamic",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--combine",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Converts the IR(s) to dynamic",
|
||||
)
|
||||
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
dynamic = args.dynamic
|
||||
combine = args.combine
|
||||
assert (
|
||||
dynamic or combine
|
||||
), "neither `dynamic` nor `combine` flag is turned on"
|
||||
first_ir_path = args.first_ir_path
|
||||
second_ir_path = args.second_ir_path
|
||||
assert first_ir_path or second_ir_path, "no input ir has been provided"
|
||||
if combine:
|
||||
assert (
|
||||
first_ir_path and second_ir_path
|
||||
), "you will need to provide both IRs to combine"
|
||||
precision = args.precision
|
||||
model_name = args.model_name
|
||||
dynamic_input_size = args.dynamic_input_size
|
||||
save_dynamic = args.save_dynamic
|
||||
|
||||
print(f"Dynamic conversion utility is turned {'ON' if dynamic else 'OFF'}")
|
||||
print(f"Combining IR utility is turned {'ON' if combine else 'OFF'}")
|
||||
|
||||
if dynamic and not combine:
|
||||
save_dynamic = True
|
||||
|
||||
first_ir = None
|
||||
first_dynamic_ir_name = None
|
||||
second_ir = None
|
||||
second_dynamic_ir_name = None
|
||||
if first_ir_path:
|
||||
first_dynamic_ir_name = f"{Path(first_ir_path).stem}_dynamic"
|
||||
with open(first_ir_path, "r") as f:
|
||||
first_ir = f.read()
|
||||
if second_ir_path:
|
||||
second_dynamic_ir_name = f"{Path(second_ir_path).stem}_dynamic"
|
||||
with open(second_ir_path, "r") as f:
|
||||
second_ir = f.read()
|
||||
if dynamic:
|
||||
first_ir = (
|
||||
write_in_dynamic_inputs0(first_ir, dynamic_input_size)
|
||||
if first_ir
|
||||
else None
|
||||
)
|
||||
second_ir = (
|
||||
write_in_dynamic_inputs1(second_ir, model_name, precision)
|
||||
if second_ir
|
||||
else None
|
||||
)
|
||||
if save_dynamic:
|
||||
save_dynamic_ir(first_ir, f"{first_dynamic_ir_name}.mlir")
|
||||
save_dynamic_ir(second_ir, f"{second_dynamic_ir_name}.mlir")
|
||||
|
||||
if combine:
|
||||
combine_mlir_scripts(
|
||||
first_ir,
|
||||
second_ir,
|
||||
f"{model_name}_{precision}.mlir",
|
||||
return_ir=False,
|
||||
)
|
||||
@@ -46,6 +46,7 @@ def compile_stableLM(
|
||||
model_vmfb_name,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
debug=False,
|
||||
):
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
@@ -92,7 +93,7 @@ def compile_stableLM(
|
||||
shark_module.compile()
|
||||
|
||||
path = shark_module.save_module(
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem, debug=debug
|
||||
)
|
||||
print("Saved vmfb at ", str(path))
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
94
apps/language_models/shark_llama_cli.spec
Normal file
94
apps/language_models/shark_llama_cli.spec
Normal file
@@ -0,0 +1,94 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
from PyInstaller.utils.hooks import collect_data_files
|
||||
from PyInstaller.utils.hooks import collect_submodules
|
||||
from PyInstaller.utils.hooks import copy_metadata
|
||||
|
||||
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
|
||||
|
||||
datas = []
|
||||
datas += collect_data_files('torch')
|
||||
datas += copy_metadata('torch')
|
||||
datas += copy_metadata('tqdm')
|
||||
datas += copy_metadata('regex')
|
||||
datas += copy_metadata('requests')
|
||||
datas += copy_metadata('packaging')
|
||||
datas += copy_metadata('filelock')
|
||||
datas += copy_metadata('numpy')
|
||||
datas += copy_metadata('tokenizers')
|
||||
datas += copy_metadata('importlib_metadata')
|
||||
datas += copy_metadata('torch-mlir')
|
||||
datas += copy_metadata('omegaconf')
|
||||
datas += copy_metadata('safetensors')
|
||||
datas += copy_metadata('huggingface-hub')
|
||||
datas += copy_metadata('sentencepiece')
|
||||
datas += copy_metadata("pyyaml")
|
||||
datas += collect_data_files("tokenizers")
|
||||
datas += collect_data_files("tiktoken")
|
||||
datas += collect_data_files("accelerate")
|
||||
datas += collect_data_files('diffusers')
|
||||
datas += collect_data_files('transformers')
|
||||
datas += collect_data_files('opencv-python')
|
||||
datas += collect_data_files('pytorch_lightning')
|
||||
datas += collect_data_files('skimage')
|
||||
datas += collect_data_files('gradio')
|
||||
datas += collect_data_files('gradio_client')
|
||||
datas += collect_data_files('iree')
|
||||
datas += collect_data_files('google-cloud-storage')
|
||||
datas += collect_data_files('py-cpuinfo')
|
||||
datas += collect_data_files("shark", include_py_files=True)
|
||||
datas += collect_data_files("timm", include_py_files=True)
|
||||
datas += collect_data_files("tqdm")
|
||||
datas += collect_data_files("tkinter")
|
||||
datas += collect_data_files("webview")
|
||||
datas += collect_data_files("sentencepiece")
|
||||
datas += collect_data_files("jsonschema")
|
||||
datas += collect_data_files("jsonschema_specifications")
|
||||
datas += collect_data_files("cpuinfo")
|
||||
datas += collect_data_files("langchain")
|
||||
|
||||
binaries = []
|
||||
|
||||
block_cipher = None
|
||||
|
||||
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
|
||||
|
||||
a = Analysis(
|
||||
['scripts/vicuna.py'],
|
||||
pathex=['.'],
|
||||
binaries=binaries,
|
||||
datas=datas,
|
||||
hiddenimports=hiddenimports,
|
||||
hookspath=[],
|
||||
hooksconfig={},
|
||||
runtime_hooks=[],
|
||||
excludes=[],
|
||||
win_no_prefer_redirects=False,
|
||||
win_private_assemblies=False,
|
||||
cipher=block_cipher,
|
||||
noarchive=False,
|
||||
)
|
||||
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||
|
||||
exe = EXE(
|
||||
pyz,
|
||||
a.scripts,
|
||||
a.binaries,
|
||||
a.zipfiles,
|
||||
a.datas,
|
||||
[],
|
||||
name='shark_llama_cli',
|
||||
debug=False,
|
||||
bootloader_ignore_signals=False,
|
||||
strip=False,
|
||||
upx=True,
|
||||
upx_exclude=[],
|
||||
runtime_tmpdir=None,
|
||||
console=True,
|
||||
disable_windowed_traceback=False,
|
||||
argv_emulation=False,
|
||||
target_arch=None,
|
||||
codesign_identity=None,
|
||||
entitlements_file=None,
|
||||
)
|
||||
675
apps/language_models/src/model_wrappers/falcon_sharded_model.py
Normal file
675
apps/language_models/src/model_wrappers/falcon_sharded_model.py
Normal file
@@ -0,0 +1,675 @@
|
||||
import torch
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
class WordEmbeddingsLayer(torch.nn.Module):
|
||||
def __init__(self, word_embedding_layer):
|
||||
super().__init__()
|
||||
self.model = word_embedding_layer
|
||||
|
||||
def forward(self, input_ids):
|
||||
output = self.model.forward(input=input_ids)
|
||||
return output
|
||||
|
||||
|
||||
class CompiledWordEmbeddingsLayer(torch.nn.Module):
|
||||
def __init__(self, compiled_word_embedding_layer):
|
||||
super().__init__()
|
||||
self.model = compiled_word_embedding_layer
|
||||
|
||||
def forward(self, input_ids):
|
||||
input_ids = input_ids.detach().numpy()
|
||||
new_input_ids = self.model("forward", input_ids)
|
||||
new_input_ids = new_input_ids.reshape(
|
||||
[1, new_input_ids.shape[0], new_input_ids.shape[1]]
|
||||
)
|
||||
return torch.tensor(new_input_ids)
|
||||
|
||||
|
||||
class LNFEmbeddingLayer(torch.nn.Module):
|
||||
def __init__(self, ln_f):
|
||||
super().__init__()
|
||||
self.model = ln_f
|
||||
|
||||
def forward(self, hidden_states):
|
||||
output = self.model.forward(input=hidden_states)
|
||||
return output
|
||||
|
||||
|
||||
class CompiledLNFEmbeddingLayer(torch.nn.Module):
|
||||
def __init__(self, ln_f):
|
||||
super().__init__()
|
||||
self.model = ln_f
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = hidden_states.detach().numpy()
|
||||
new_hidden_states = self.model("forward", (hidden_states,))
|
||||
|
||||
return torch.tensor(new_hidden_states)
|
||||
|
||||
|
||||
class LMHeadEmbeddingLayer(torch.nn.Module):
|
||||
def __init__(self, embedding_layer):
|
||||
super().__init__()
|
||||
self.model = embedding_layer
|
||||
|
||||
def forward(self, hidden_states):
|
||||
output = self.model.forward(input=hidden_states)
|
||||
return output
|
||||
|
||||
|
||||
class CompiledLMHeadEmbeddingLayer(torch.nn.Module):
|
||||
def __init__(self, lm_head):
|
||||
super().__init__()
|
||||
self.model = lm_head
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = hidden_states.detach().numpy()
|
||||
new_hidden_states = self.model("forward", (hidden_states,))
|
||||
return torch.tensor(new_hidden_states)
|
||||
|
||||
|
||||
class FourWayShardingDecoderLayer(torch.nn.Module):
|
||||
def __init__(self, decoder_layer_model, falcon_variant):
|
||||
super().__init__()
|
||||
self.model = decoder_layer_model
|
||||
self.falcon_variant = falcon_variant
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
new_pkvs = []
|
||||
for layer in self.model:
|
||||
outputs = layer(
|
||||
hidden_states=hidden_states,
|
||||
alibi=None,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=True,
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
new_pkvs.append(
|
||||
(
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
)
|
||||
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
(new_pkv80, new_pkv81),
|
||||
(new_pkv90, new_pkv91),
|
||||
(new_pkv100, new_pkv101),
|
||||
(new_pkv110, new_pkv111),
|
||||
(new_pkv120, new_pkv121),
|
||||
(new_pkv130, new_pkv131),
|
||||
(new_pkv140, new_pkv141),
|
||||
(new_pkv150, new_pkv151),
|
||||
(new_pkv160, new_pkv161),
|
||||
(new_pkv170, new_pkv171),
|
||||
(new_pkv180, new_pkv181),
|
||||
(new_pkv190, new_pkv191),
|
||||
) = new_pkvs
|
||||
result = (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
new_pkv80,
|
||||
new_pkv81,
|
||||
new_pkv90,
|
||||
new_pkv91,
|
||||
new_pkv100,
|
||||
new_pkv101,
|
||||
new_pkv110,
|
||||
new_pkv111,
|
||||
new_pkv120,
|
||||
new_pkv121,
|
||||
new_pkv130,
|
||||
new_pkv131,
|
||||
new_pkv140,
|
||||
new_pkv141,
|
||||
new_pkv150,
|
||||
new_pkv151,
|
||||
new_pkv160,
|
||||
new_pkv161,
|
||||
new_pkv170,
|
||||
new_pkv171,
|
||||
new_pkv180,
|
||||
new_pkv181,
|
||||
new_pkv190,
|
||||
new_pkv191,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class CompiledFourWayShardingDecoderLayer(torch.nn.Module):
|
||||
def __init__(
|
||||
self, layer_id, device_idx, falcon_variant, device, precision, model
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
self.device_index = device_idx
|
||||
self.falcon_variant = falcon_variant
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
alibi: Optional[torch.Tensor],
|
||||
attention_mask: torch.Tensor,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
import gc
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
if self.model is None:
|
||||
raise ValueError("Layer vmfb not found")
|
||||
|
||||
hidden_states = hidden_states.to(torch.float32).detach().numpy()
|
||||
attention_mask = attention_mask.to(torch.float32).detach().numpy()
|
||||
|
||||
if alibi is not None or layer_past is not None:
|
||||
raise ValueError("Past Key Values and alibi should be None")
|
||||
else:
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
),
|
||||
)
|
||||
|
||||
result = (
|
||||
torch.tensor(output[0]),
|
||||
(
|
||||
torch.tensor(output[1]),
|
||||
torch.tensor(output[2]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[3]),
|
||||
torch.tensor(output[4]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[5]),
|
||||
torch.tensor(output[6]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[7]),
|
||||
torch.tensor(output[8]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[9]),
|
||||
torch.tensor(output[10]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[11]),
|
||||
torch.tensor(output[12]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[13]),
|
||||
torch.tensor(output[14]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[15]),
|
||||
torch.tensor(output[16]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[17]),
|
||||
torch.tensor(output[18]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[19]),
|
||||
torch.tensor(output[20]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[21]),
|
||||
torch.tensor(output[22]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[23]),
|
||||
torch.tensor(output[24]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[25]),
|
||||
torch.tensor(output[26]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[27]),
|
||||
torch.tensor(output[28]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[29]),
|
||||
torch.tensor(output[30]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[31]),
|
||||
torch.tensor(output[32]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[33]),
|
||||
torch.tensor(output[34]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[35]),
|
||||
torch.tensor(output[36]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[37]),
|
||||
torch.tensor(output[38]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[39]),
|
||||
torch.tensor(output[40]),
|
||||
),
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class TwoWayShardingDecoderLayer(torch.nn.Module):
|
||||
def __init__(self, decoder_layer_model, falcon_variant):
|
||||
super().__init__()
|
||||
self.model = decoder_layer_model
|
||||
self.falcon_variant = falcon_variant
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
new_pkvs = []
|
||||
for layer in self.model:
|
||||
outputs = layer(
|
||||
hidden_states=hidden_states,
|
||||
alibi=None,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=True,
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
new_pkvs.append(
|
||||
(
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
)
|
||||
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
(new_pkv80, new_pkv81),
|
||||
(new_pkv90, new_pkv91),
|
||||
(new_pkv100, new_pkv101),
|
||||
(new_pkv110, new_pkv111),
|
||||
(new_pkv120, new_pkv121),
|
||||
(new_pkv130, new_pkv131),
|
||||
(new_pkv140, new_pkv141),
|
||||
(new_pkv150, new_pkv151),
|
||||
(new_pkv160, new_pkv161),
|
||||
(new_pkv170, new_pkv171),
|
||||
(new_pkv180, new_pkv181),
|
||||
(new_pkv190, new_pkv191),
|
||||
(new_pkv200, new_pkv201),
|
||||
(new_pkv210, new_pkv211),
|
||||
(new_pkv220, new_pkv221),
|
||||
(new_pkv230, new_pkv231),
|
||||
(new_pkv240, new_pkv241),
|
||||
(new_pkv250, new_pkv251),
|
||||
(new_pkv260, new_pkv261),
|
||||
(new_pkv270, new_pkv271),
|
||||
(new_pkv280, new_pkv281),
|
||||
(new_pkv290, new_pkv291),
|
||||
(new_pkv300, new_pkv301),
|
||||
(new_pkv310, new_pkv311),
|
||||
(new_pkv320, new_pkv321),
|
||||
(new_pkv330, new_pkv331),
|
||||
(new_pkv340, new_pkv341),
|
||||
(new_pkv350, new_pkv351),
|
||||
(new_pkv360, new_pkv361),
|
||||
(new_pkv370, new_pkv371),
|
||||
(new_pkv380, new_pkv381),
|
||||
(new_pkv390, new_pkv391),
|
||||
) = new_pkvs
|
||||
result = (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
new_pkv80,
|
||||
new_pkv81,
|
||||
new_pkv90,
|
||||
new_pkv91,
|
||||
new_pkv100,
|
||||
new_pkv101,
|
||||
new_pkv110,
|
||||
new_pkv111,
|
||||
new_pkv120,
|
||||
new_pkv121,
|
||||
new_pkv130,
|
||||
new_pkv131,
|
||||
new_pkv140,
|
||||
new_pkv141,
|
||||
new_pkv150,
|
||||
new_pkv151,
|
||||
new_pkv160,
|
||||
new_pkv161,
|
||||
new_pkv170,
|
||||
new_pkv171,
|
||||
new_pkv180,
|
||||
new_pkv181,
|
||||
new_pkv190,
|
||||
new_pkv191,
|
||||
new_pkv200,
|
||||
new_pkv201,
|
||||
new_pkv210,
|
||||
new_pkv211,
|
||||
new_pkv220,
|
||||
new_pkv221,
|
||||
new_pkv230,
|
||||
new_pkv231,
|
||||
new_pkv240,
|
||||
new_pkv241,
|
||||
new_pkv250,
|
||||
new_pkv251,
|
||||
new_pkv260,
|
||||
new_pkv261,
|
||||
new_pkv270,
|
||||
new_pkv271,
|
||||
new_pkv280,
|
||||
new_pkv281,
|
||||
new_pkv290,
|
||||
new_pkv291,
|
||||
new_pkv300,
|
||||
new_pkv301,
|
||||
new_pkv310,
|
||||
new_pkv311,
|
||||
new_pkv320,
|
||||
new_pkv321,
|
||||
new_pkv330,
|
||||
new_pkv331,
|
||||
new_pkv340,
|
||||
new_pkv341,
|
||||
new_pkv350,
|
||||
new_pkv351,
|
||||
new_pkv360,
|
||||
new_pkv361,
|
||||
new_pkv370,
|
||||
new_pkv371,
|
||||
new_pkv380,
|
||||
new_pkv381,
|
||||
new_pkv390,
|
||||
new_pkv391,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class CompiledTwoWayShardingDecoderLayer(torch.nn.Module):
|
||||
def __init__(
|
||||
self, layer_id, device_idx, falcon_variant, device, precision, model
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
self.device_index = device_idx
|
||||
self.falcon_variant = falcon_variant
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
alibi: Optional[torch.Tensor],
|
||||
attention_mask: torch.Tensor,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
import gc
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
if self.model is None:
|
||||
raise ValueError("Layer vmfb not found")
|
||||
|
||||
hidden_states = hidden_states.to(torch.float32).detach().numpy()
|
||||
attention_mask = attention_mask.to(torch.float32).detach().numpy()
|
||||
|
||||
if alibi is not None or layer_past is not None:
|
||||
raise ValueError("Past Key Values and alibi should be None")
|
||||
else:
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
),
|
||||
)
|
||||
|
||||
result = (
|
||||
torch.tensor(output[0]),
|
||||
(
|
||||
torch.tensor(output[1]),
|
||||
torch.tensor(output[2]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[3]),
|
||||
torch.tensor(output[4]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[5]),
|
||||
torch.tensor(output[6]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[7]),
|
||||
torch.tensor(output[8]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[9]),
|
||||
torch.tensor(output[10]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[11]),
|
||||
torch.tensor(output[12]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[13]),
|
||||
torch.tensor(output[14]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[15]),
|
||||
torch.tensor(output[16]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[17]),
|
||||
torch.tensor(output[18]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[19]),
|
||||
torch.tensor(output[20]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[21]),
|
||||
torch.tensor(output[22]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[23]),
|
||||
torch.tensor(output[24]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[25]),
|
||||
torch.tensor(output[26]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[27]),
|
||||
torch.tensor(output[28]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[29]),
|
||||
torch.tensor(output[30]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[31]),
|
||||
torch.tensor(output[32]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[33]),
|
||||
torch.tensor(output[34]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[35]),
|
||||
torch.tensor(output[36]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[37]),
|
||||
torch.tensor(output[38]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[39]),
|
||||
torch.tensor(output[40]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[41]),
|
||||
torch.tensor(output[42]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[43]),
|
||||
torch.tensor(output[44]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[45]),
|
||||
torch.tensor(output[46]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[47]),
|
||||
torch.tensor(output[48]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[49]),
|
||||
torch.tensor(output[50]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[51]),
|
||||
torch.tensor(output[52]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[53]),
|
||||
torch.tensor(output[54]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[55]),
|
||||
torch.tensor(output[56]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[57]),
|
||||
torch.tensor(output[58]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[59]),
|
||||
torch.tensor(output[60]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[61]),
|
||||
torch.tensor(output[62]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[63]),
|
||||
torch.tensor(output[64]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[65]),
|
||||
torch.tensor(output[66]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[67]),
|
||||
torch.tensor(output[68]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[69]),
|
||||
torch.tensor(output[70]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[71]),
|
||||
torch.tensor(output[72]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[73]),
|
||||
torch.tensor(output[74]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[75]),
|
||||
torch.tensor(output[76]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[77]),
|
||||
torch.tensor(output[78]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[79]),
|
||||
torch.tensor(output[80]),
|
||||
),
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class ShardedFalconModel:
|
||||
def __init__(self, model, layers, word_embeddings, ln_f, lm_head):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.model.transformer.h = torch.nn.modules.container.ModuleList(
|
||||
layers
|
||||
)
|
||||
self.model.transformer.word_embeddings = word_embeddings
|
||||
self.model.transformer.ln_f = ln_f
|
||||
self.model.lm_head = lm_head
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
):
|
||||
return self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
).logits[:, -1, :]
|
||||
@@ -5,7 +5,7 @@ from typing import List, Any
|
||||
from transformers import StoppingCriteria
|
||||
|
||||
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.common.generative.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ class VisionModel(torch.nn.Module):
|
||||
dtype=torch.float32,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_scale_precision="float_scale",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
@@ -52,7 +52,7 @@ class VisionModel(torch.nn.Module):
|
||||
dtype=torch.float32,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_scale_precision="float_scale",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
@@ -93,7 +93,7 @@ class FirstLlamaModel(torch.nn.Module):
|
||||
dtype=torch.float32,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_scale_precision="float_scale",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
@@ -157,7 +157,7 @@ class SecondLlamaModel(torch.nn.Module):
|
||||
dtype=torch.float32,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_scale_precision="float_scale",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
|
||||
@@ -47,18 +47,15 @@ from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
|
||||
)
|
||||
from apps.language_models.src.model_wrappers.vicuna_model import (
|
||||
FirstVicuna,
|
||||
SecondVicuna,
|
||||
SecondVicuna7B,
|
||||
)
|
||||
from apps.language_models.utils import (
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import get_f16_inputs
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaDecoderLayer,
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
class FirstVicuna(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_path,
|
||||
precision="fp32",
|
||||
accumulates="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
hf_auth_token: str = None,
|
||||
@@ -18,18 +16,29 @@ class FirstVicuna(torch.nn.Module):
|
||||
kwargs = {"torch_dtype": torch.float32}
|
||||
if "llama2" in model_name:
|
||||
kwargs["use_auth_token"] = hf_auth_token
|
||||
self.accumulates = (
|
||||
torch.float32 if accumulates == "fp32" else torch.float16
|
||||
)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.common.generative.quantize import (
|
||||
quantize_model,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
get_model_impl,
|
||||
)
|
||||
|
||||
print("First Vicuna applying weight quantization..")
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
get_model_impl(self.model).layers,
|
||||
dtype=torch.float32,
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_scale_precision="float_scale",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
@@ -40,7 +49,9 @@ class FirstVicuna(torch.nn.Module):
|
||||
def forward(self, input_ids):
|
||||
op = self.model(input_ids=input_ids, use_cache=True)
|
||||
return_vals = []
|
||||
return_vals.append(op.logits)
|
||||
token = torch.argmax(op.logits[:, -1, :], dim=1)
|
||||
return_vals.append(token)
|
||||
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
@@ -48,11 +59,12 @@ class FirstVicuna(torch.nn.Module):
|
||||
return tuple(return_vals)
|
||||
|
||||
|
||||
class SecondVicuna(torch.nn.Module):
|
||||
class SecondVicuna7B(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_path,
|
||||
precision="fp32",
|
||||
accumulates="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
hf_auth_token: str = None,
|
||||
@@ -64,15 +76,26 @@ class SecondVicuna(torch.nn.Module):
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
self.accumulates = (
|
||||
torch.float32 if accumulates == "fp32" else torch.float16
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.common.generative.quantize import (
|
||||
quantize_model,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
get_model_impl,
|
||||
)
|
||||
|
||||
print("Second Vicuna applying weight quantization..")
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
get_model_impl(self.model).layers,
|
||||
dtype=torch.float32,
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_scale_precision="float_scale",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
@@ -148,8 +171,6 @@ class SecondVicuna(torch.nn.Module):
|
||||
i63,
|
||||
i64,
|
||||
):
|
||||
# input_ids = input_tuple[0]
|
||||
# input_tuple = torch.unbind(pkv, dim=0)
|
||||
token = i0
|
||||
past_key_values = (
|
||||
(i1, i2),
|
||||
@@ -282,6 +303,846 @@ class SecondVicuna(torch.nn.Module):
|
||||
input_ids=token, use_cache=True, past_key_values=past_key_values
|
||||
)
|
||||
return_vals = []
|
||||
token = torch.argmax(op.logits[:, -1, :], dim=1)
|
||||
return_vals.append(token)
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
return_vals.append(item[1])
|
||||
return tuple(return_vals)
|
||||
|
||||
|
||||
class SecondVicuna13B(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_path,
|
||||
precision="int8",
|
||||
accumulates="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
hf_auth_token: str = None,
|
||||
):
|
||||
super().__init__()
|
||||
kwargs = {"torch_dtype": torch.float32}
|
||||
if "llama2" in model_name:
|
||||
kwargs["use_auth_token"] = hf_auth_token
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
self.accumulates = (
|
||||
torch.float32 if accumulates == "fp32" else torch.float16
|
||||
)
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.common.generative.quantize import (
|
||||
quantize_model,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
get_model_impl,
|
||||
)
|
||||
|
||||
print("Second Vicuna applying weight quantization..")
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
get_model_impl(self.model).layers,
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float_scale",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
quantize_weight_zero_point=False,
|
||||
)
|
||||
print("Weight quantization applied.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
i0,
|
||||
i1,
|
||||
i2,
|
||||
i3,
|
||||
i4,
|
||||
i5,
|
||||
i6,
|
||||
i7,
|
||||
i8,
|
||||
i9,
|
||||
i10,
|
||||
i11,
|
||||
i12,
|
||||
i13,
|
||||
i14,
|
||||
i15,
|
||||
i16,
|
||||
i17,
|
||||
i18,
|
||||
i19,
|
||||
i20,
|
||||
i21,
|
||||
i22,
|
||||
i23,
|
||||
i24,
|
||||
i25,
|
||||
i26,
|
||||
i27,
|
||||
i28,
|
||||
i29,
|
||||
i30,
|
||||
i31,
|
||||
i32,
|
||||
i33,
|
||||
i34,
|
||||
i35,
|
||||
i36,
|
||||
i37,
|
||||
i38,
|
||||
i39,
|
||||
i40,
|
||||
i41,
|
||||
i42,
|
||||
i43,
|
||||
i44,
|
||||
i45,
|
||||
i46,
|
||||
i47,
|
||||
i48,
|
||||
i49,
|
||||
i50,
|
||||
i51,
|
||||
i52,
|
||||
i53,
|
||||
i54,
|
||||
i55,
|
||||
i56,
|
||||
i57,
|
||||
i58,
|
||||
i59,
|
||||
i60,
|
||||
i61,
|
||||
i62,
|
||||
i63,
|
||||
i64,
|
||||
i65,
|
||||
i66,
|
||||
i67,
|
||||
i68,
|
||||
i69,
|
||||
i70,
|
||||
i71,
|
||||
i72,
|
||||
i73,
|
||||
i74,
|
||||
i75,
|
||||
i76,
|
||||
i77,
|
||||
i78,
|
||||
i79,
|
||||
i80,
|
||||
):
|
||||
token = i0
|
||||
past_key_values = (
|
||||
(i1, i2),
|
||||
(
|
||||
i3,
|
||||
i4,
|
||||
),
|
||||
(
|
||||
i5,
|
||||
i6,
|
||||
),
|
||||
(
|
||||
i7,
|
||||
i8,
|
||||
),
|
||||
(
|
||||
i9,
|
||||
i10,
|
||||
),
|
||||
(
|
||||
i11,
|
||||
i12,
|
||||
),
|
||||
(
|
||||
i13,
|
||||
i14,
|
||||
),
|
||||
(
|
||||
i15,
|
||||
i16,
|
||||
),
|
||||
(
|
||||
i17,
|
||||
i18,
|
||||
),
|
||||
(
|
||||
i19,
|
||||
i20,
|
||||
),
|
||||
(
|
||||
i21,
|
||||
i22,
|
||||
),
|
||||
(
|
||||
i23,
|
||||
i24,
|
||||
),
|
||||
(
|
||||
i25,
|
||||
i26,
|
||||
),
|
||||
(
|
||||
i27,
|
||||
i28,
|
||||
),
|
||||
(
|
||||
i29,
|
||||
i30,
|
||||
),
|
||||
(
|
||||
i31,
|
||||
i32,
|
||||
),
|
||||
(
|
||||
i33,
|
||||
i34,
|
||||
),
|
||||
(
|
||||
i35,
|
||||
i36,
|
||||
),
|
||||
(
|
||||
i37,
|
||||
i38,
|
||||
),
|
||||
(
|
||||
i39,
|
||||
i40,
|
||||
),
|
||||
(
|
||||
i41,
|
||||
i42,
|
||||
),
|
||||
(
|
||||
i43,
|
||||
i44,
|
||||
),
|
||||
(
|
||||
i45,
|
||||
i46,
|
||||
),
|
||||
(
|
||||
i47,
|
||||
i48,
|
||||
),
|
||||
(
|
||||
i49,
|
||||
i50,
|
||||
),
|
||||
(
|
||||
i51,
|
||||
i52,
|
||||
),
|
||||
(
|
||||
i53,
|
||||
i54,
|
||||
),
|
||||
(
|
||||
i55,
|
||||
i56,
|
||||
),
|
||||
(
|
||||
i57,
|
||||
i58,
|
||||
),
|
||||
(
|
||||
i59,
|
||||
i60,
|
||||
),
|
||||
(
|
||||
i61,
|
||||
i62,
|
||||
),
|
||||
(
|
||||
i63,
|
||||
i64,
|
||||
),
|
||||
(
|
||||
i65,
|
||||
i66,
|
||||
),
|
||||
(
|
||||
i67,
|
||||
i68,
|
||||
),
|
||||
(
|
||||
i69,
|
||||
i70,
|
||||
),
|
||||
(
|
||||
i71,
|
||||
i72,
|
||||
),
|
||||
(
|
||||
i73,
|
||||
i74,
|
||||
),
|
||||
(
|
||||
i75,
|
||||
i76,
|
||||
),
|
||||
(
|
||||
i77,
|
||||
i78,
|
||||
),
|
||||
(
|
||||
i79,
|
||||
i80,
|
||||
),
|
||||
)
|
||||
op = self.model(
|
||||
input_ids=token, use_cache=True, past_key_values=past_key_values
|
||||
)
|
||||
return_vals = []
|
||||
return_vals.append(op.logits)
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
return_vals.append(item[1])
|
||||
return tuple(return_vals)
|
||||
|
||||
|
||||
class SecondVicuna70B(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_path,
|
||||
precision="fp32",
|
||||
accumulates="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
hf_auth_token: str = None,
|
||||
):
|
||||
super().__init__()
|
||||
kwargs = {"torch_dtype": torch.float32}
|
||||
if "llama2" in model_name:
|
||||
kwargs["use_auth_token"] = hf_auth_token
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
self.accumulates = (
|
||||
torch.float32 if accumulates == "fp32" else torch.float16
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.common.generative.quantize import (
|
||||
quantize_model,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
get_model_impl,
|
||||
)
|
||||
|
||||
print("Second Vicuna applying weight quantization..")
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
get_model_impl(self.model).layers,
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float_scale",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
quantize_weight_zero_point=False,
|
||||
)
|
||||
print("Weight quantization applied.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
i0,
|
||||
i1,
|
||||
i2,
|
||||
i3,
|
||||
i4,
|
||||
i5,
|
||||
i6,
|
||||
i7,
|
||||
i8,
|
||||
i9,
|
||||
i10,
|
||||
i11,
|
||||
i12,
|
||||
i13,
|
||||
i14,
|
||||
i15,
|
||||
i16,
|
||||
i17,
|
||||
i18,
|
||||
i19,
|
||||
i20,
|
||||
i21,
|
||||
i22,
|
||||
i23,
|
||||
i24,
|
||||
i25,
|
||||
i26,
|
||||
i27,
|
||||
i28,
|
||||
i29,
|
||||
i30,
|
||||
i31,
|
||||
i32,
|
||||
i33,
|
||||
i34,
|
||||
i35,
|
||||
i36,
|
||||
i37,
|
||||
i38,
|
||||
i39,
|
||||
i40,
|
||||
i41,
|
||||
i42,
|
||||
i43,
|
||||
i44,
|
||||
i45,
|
||||
i46,
|
||||
i47,
|
||||
i48,
|
||||
i49,
|
||||
i50,
|
||||
i51,
|
||||
i52,
|
||||
i53,
|
||||
i54,
|
||||
i55,
|
||||
i56,
|
||||
i57,
|
||||
i58,
|
||||
i59,
|
||||
i60,
|
||||
i61,
|
||||
i62,
|
||||
i63,
|
||||
i64,
|
||||
i65,
|
||||
i66,
|
||||
i67,
|
||||
i68,
|
||||
i69,
|
||||
i70,
|
||||
i71,
|
||||
i72,
|
||||
i73,
|
||||
i74,
|
||||
i75,
|
||||
i76,
|
||||
i77,
|
||||
i78,
|
||||
i79,
|
||||
i80,
|
||||
i81,
|
||||
i82,
|
||||
i83,
|
||||
i84,
|
||||
i85,
|
||||
i86,
|
||||
i87,
|
||||
i88,
|
||||
i89,
|
||||
i90,
|
||||
i91,
|
||||
i92,
|
||||
i93,
|
||||
i94,
|
||||
i95,
|
||||
i96,
|
||||
i97,
|
||||
i98,
|
||||
i99,
|
||||
i100,
|
||||
i101,
|
||||
i102,
|
||||
i103,
|
||||
i104,
|
||||
i105,
|
||||
i106,
|
||||
i107,
|
||||
i108,
|
||||
i109,
|
||||
i110,
|
||||
i111,
|
||||
i112,
|
||||
i113,
|
||||
i114,
|
||||
i115,
|
||||
i116,
|
||||
i117,
|
||||
i118,
|
||||
i119,
|
||||
i120,
|
||||
i121,
|
||||
i122,
|
||||
i123,
|
||||
i124,
|
||||
i125,
|
||||
i126,
|
||||
i127,
|
||||
i128,
|
||||
i129,
|
||||
i130,
|
||||
i131,
|
||||
i132,
|
||||
i133,
|
||||
i134,
|
||||
i135,
|
||||
i136,
|
||||
i137,
|
||||
i138,
|
||||
i139,
|
||||
i140,
|
||||
i141,
|
||||
i142,
|
||||
i143,
|
||||
i144,
|
||||
i145,
|
||||
i146,
|
||||
i147,
|
||||
i148,
|
||||
i149,
|
||||
i150,
|
||||
i151,
|
||||
i152,
|
||||
i153,
|
||||
i154,
|
||||
i155,
|
||||
i156,
|
||||
i157,
|
||||
i158,
|
||||
i159,
|
||||
i160,
|
||||
):
|
||||
token = i0
|
||||
past_key_values = (
|
||||
(i1, i2),
|
||||
(
|
||||
i3,
|
||||
i4,
|
||||
),
|
||||
(
|
||||
i5,
|
||||
i6,
|
||||
),
|
||||
(
|
||||
i7,
|
||||
i8,
|
||||
),
|
||||
(
|
||||
i9,
|
||||
i10,
|
||||
),
|
||||
(
|
||||
i11,
|
||||
i12,
|
||||
),
|
||||
(
|
||||
i13,
|
||||
i14,
|
||||
),
|
||||
(
|
||||
i15,
|
||||
i16,
|
||||
),
|
||||
(
|
||||
i17,
|
||||
i18,
|
||||
),
|
||||
(
|
||||
i19,
|
||||
i20,
|
||||
),
|
||||
(
|
||||
i21,
|
||||
i22,
|
||||
),
|
||||
(
|
||||
i23,
|
||||
i24,
|
||||
),
|
||||
(
|
||||
i25,
|
||||
i26,
|
||||
),
|
||||
(
|
||||
i27,
|
||||
i28,
|
||||
),
|
||||
(
|
||||
i29,
|
||||
i30,
|
||||
),
|
||||
(
|
||||
i31,
|
||||
i32,
|
||||
),
|
||||
(
|
||||
i33,
|
||||
i34,
|
||||
),
|
||||
(
|
||||
i35,
|
||||
i36,
|
||||
),
|
||||
(
|
||||
i37,
|
||||
i38,
|
||||
),
|
||||
(
|
||||
i39,
|
||||
i40,
|
||||
),
|
||||
(
|
||||
i41,
|
||||
i42,
|
||||
),
|
||||
(
|
||||
i43,
|
||||
i44,
|
||||
),
|
||||
(
|
||||
i45,
|
||||
i46,
|
||||
),
|
||||
(
|
||||
i47,
|
||||
i48,
|
||||
),
|
||||
(
|
||||
i49,
|
||||
i50,
|
||||
),
|
||||
(
|
||||
i51,
|
||||
i52,
|
||||
),
|
||||
(
|
||||
i53,
|
||||
i54,
|
||||
),
|
||||
(
|
||||
i55,
|
||||
i56,
|
||||
),
|
||||
(
|
||||
i57,
|
||||
i58,
|
||||
),
|
||||
(
|
||||
i59,
|
||||
i60,
|
||||
),
|
||||
(
|
||||
i61,
|
||||
i62,
|
||||
),
|
||||
(
|
||||
i63,
|
||||
i64,
|
||||
),
|
||||
(
|
||||
i65,
|
||||
i66,
|
||||
),
|
||||
(
|
||||
i67,
|
||||
i68,
|
||||
),
|
||||
(
|
||||
i69,
|
||||
i70,
|
||||
),
|
||||
(
|
||||
i71,
|
||||
i72,
|
||||
),
|
||||
(
|
||||
i73,
|
||||
i74,
|
||||
),
|
||||
(
|
||||
i75,
|
||||
i76,
|
||||
),
|
||||
(
|
||||
i77,
|
||||
i78,
|
||||
),
|
||||
(
|
||||
i79,
|
||||
i80,
|
||||
),
|
||||
(
|
||||
i81,
|
||||
i82,
|
||||
),
|
||||
(
|
||||
i83,
|
||||
i84,
|
||||
),
|
||||
(
|
||||
i85,
|
||||
i86,
|
||||
),
|
||||
(
|
||||
i87,
|
||||
i88,
|
||||
),
|
||||
(
|
||||
i89,
|
||||
i90,
|
||||
),
|
||||
(
|
||||
i91,
|
||||
i92,
|
||||
),
|
||||
(
|
||||
i93,
|
||||
i94,
|
||||
),
|
||||
(
|
||||
i95,
|
||||
i96,
|
||||
),
|
||||
(
|
||||
i97,
|
||||
i98,
|
||||
),
|
||||
(
|
||||
i99,
|
||||
i100,
|
||||
),
|
||||
(
|
||||
i101,
|
||||
i102,
|
||||
),
|
||||
(
|
||||
i103,
|
||||
i104,
|
||||
),
|
||||
(
|
||||
i105,
|
||||
i106,
|
||||
),
|
||||
(
|
||||
i107,
|
||||
i108,
|
||||
),
|
||||
(
|
||||
i109,
|
||||
i110,
|
||||
),
|
||||
(
|
||||
i111,
|
||||
i112,
|
||||
),
|
||||
(
|
||||
i113,
|
||||
i114,
|
||||
),
|
||||
(
|
||||
i115,
|
||||
i116,
|
||||
),
|
||||
(
|
||||
i117,
|
||||
i118,
|
||||
),
|
||||
(
|
||||
i119,
|
||||
i120,
|
||||
),
|
||||
(
|
||||
i121,
|
||||
i122,
|
||||
),
|
||||
(
|
||||
i123,
|
||||
i124,
|
||||
),
|
||||
(
|
||||
i125,
|
||||
i126,
|
||||
),
|
||||
(
|
||||
i127,
|
||||
i128,
|
||||
),
|
||||
(
|
||||
i129,
|
||||
i130,
|
||||
),
|
||||
(
|
||||
i131,
|
||||
i132,
|
||||
),
|
||||
(
|
||||
i133,
|
||||
i134,
|
||||
),
|
||||
(
|
||||
i135,
|
||||
i136,
|
||||
),
|
||||
(
|
||||
i137,
|
||||
i138,
|
||||
),
|
||||
(
|
||||
i139,
|
||||
i140,
|
||||
),
|
||||
(
|
||||
i141,
|
||||
i142,
|
||||
),
|
||||
(
|
||||
i143,
|
||||
i144,
|
||||
),
|
||||
(
|
||||
i145,
|
||||
i146,
|
||||
),
|
||||
(
|
||||
i147,
|
||||
i148,
|
||||
),
|
||||
(
|
||||
i149,
|
||||
i150,
|
||||
),
|
||||
(
|
||||
i151,
|
||||
i152,
|
||||
),
|
||||
(
|
||||
i153,
|
||||
i154,
|
||||
),
|
||||
(
|
||||
i155,
|
||||
i156,
|
||||
),
|
||||
(
|
||||
i157,
|
||||
i158,
|
||||
),
|
||||
(
|
||||
i159,
|
||||
i160,
|
||||
),
|
||||
)
|
||||
op = self.model(
|
||||
input_ids=token, use_cache=True, past_key_values=past_key_values
|
||||
)
|
||||
return_vals = []
|
||||
return_vals.append(op.logits)
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
@@ -298,7 +1159,8 @@ class CombinedModel(torch.nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.first_vicuna = FirstVicuna(first_vicuna_model_path)
|
||||
self.second_vicuna = SecondVicuna(second_vicuna_model_path)
|
||||
# NOT using this path for 13B currently, hence using `SecondVicuna7B`.
|
||||
self.second_vicuna = SecondVicuna7B(second_vicuna_model_path)
|
||||
|
||||
def forward(self, input_ids):
|
||||
first_output = self.first_vicuna(input_ids=input_ids)
|
||||
|
||||
1173
apps/language_models/src/model_wrappers/vicuna_model_gpu.py
Normal file
1173
apps/language_models/src/model_wrappers/vicuna_model_gpu.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import time
|
||||
|
||||
|
||||
class FirstVicunaLayer(torch.nn.Module):
|
||||
@@ -110,9 +111,11 @@ class LMHeadCompiled(torch.nn.Module):
|
||||
self.model = shark_module
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = hidden_states.detach()
|
||||
hidden_states_sample = hidden_states.detach()
|
||||
|
||||
output = self.model("forward", (hidden_states,))
|
||||
output = torch.tensor(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@@ -136,8 +139,9 @@ class VicunaNormCompiled(torch.nn.Module):
|
||||
hidden_states.detach()
|
||||
except:
|
||||
pass
|
||||
output = self.model("forward", (hidden_states,))
|
||||
output = self.model("forward", (hidden_states,), send_to_host=True)
|
||||
output = torch.tensor(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@@ -158,8 +162,9 @@ class VicunaEmbeddingCompiled(torch.nn.Module):
|
||||
|
||||
def forward(self, input_ids):
|
||||
input_ids.detach()
|
||||
output = self.model("forward", (input_ids,))
|
||||
output = self.model("forward", (input_ids,), send_to_host=True)
|
||||
output = torch.tensor(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@@ -178,9 +183,10 @@ class CompiledVicunaLayer(torch.nn.Module):
|
||||
use_cache=True,
|
||||
):
|
||||
if past_key_value is None:
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
# hidden_states = hidden_states.detach()
|
||||
# attention_mask = attention_mask.detach()
|
||||
# position_ids = position_ids.detach()
|
||||
|
||||
output = self.model(
|
||||
"first_vicuna_forward",
|
||||
(
|
||||
@@ -188,11 +194,17 @@ class CompiledVicunaLayer(torch.nn.Module):
|
||||
attention_mask,
|
||||
position_ids,
|
||||
),
|
||||
send_to_host=True,
|
||||
)
|
||||
|
||||
### send_to_host=True
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
### send_to_host=False
|
||||
# output0 = output[0]
|
||||
# output1 = output[1]
|
||||
# output2 = output[2]
|
||||
|
||||
return (
|
||||
output0,
|
||||
@@ -202,11 +214,12 @@ class CompiledVicunaLayer(torch.nn.Module):
|
||||
),
|
||||
)
|
||||
else:
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
pkv0 = past_key_value[0].detach()
|
||||
pkv1 = past_key_value[1].detach()
|
||||
# hidden_states = hidden_states.detach()
|
||||
# attention_mask = attention_mask.detach()
|
||||
# position_ids = position_ids.detach()
|
||||
# pkv0 = past_key_value[0].detach()
|
||||
pkv0 = past_key_value[0]
|
||||
pkv1 = past_key_value[1]
|
||||
output = self.model(
|
||||
"second_vicuna_forward",
|
||||
(
|
||||
@@ -216,11 +229,17 @@ class CompiledVicunaLayer(torch.nn.Module):
|
||||
pkv0,
|
||||
pkv1,
|
||||
),
|
||||
send_to_host=True,
|
||||
)
|
||||
|
||||
### send_to_host=True
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
### send_to_host=False
|
||||
# output0 = output[0]
|
||||
# output1 = output[1]
|
||||
# output2 = output[2]
|
||||
|
||||
return (
|
||||
output0,
|
||||
|
||||
@@ -3,7 +3,10 @@ from abc import ABC, abstractmethod
|
||||
|
||||
class SharkLLMBase(ABC):
|
||||
def __init__(
|
||||
self, model_name, hf_model_path=None, max_num_tokens=512
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path=None,
|
||||
max_num_tokens=512,
|
||||
) -> None:
|
||||
self.model_name = model_name
|
||||
self.hf_model_path = hf_model_path
|
||||
|
||||
@@ -1,4 +1,17 @@
|
||||
from apps.language_models.src.model_wrappers.falcon_model import FalconModel
|
||||
from apps.language_models.src.model_wrappers.falcon_sharded_model import (
|
||||
WordEmbeddingsLayer,
|
||||
CompiledWordEmbeddingsLayer,
|
||||
LNFEmbeddingLayer,
|
||||
CompiledLNFEmbeddingLayer,
|
||||
LMHeadEmbeddingLayer,
|
||||
CompiledLMHeadEmbeddingLayer,
|
||||
FourWayShardingDecoderLayer,
|
||||
TwoWayShardingDecoderLayer,
|
||||
CompiledFourWayShardingDecoderLayer,
|
||||
CompiledTwoWayShardingDecoderLayer,
|
||||
ShardedFalconModel,
|
||||
)
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from apps.language_models.utils import (
|
||||
get_vmfb_from_path,
|
||||
@@ -7,30 +20,39 @@ from io import BytesIO
|
||||
from pathlib import Path
|
||||
from contextlib import redirect_stdout
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
from shark.shark_inference import SharkInference
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, GPTQConfig
|
||||
from transformers.generation import (
|
||||
GenerationConfig,
|
||||
LogitsProcessorList,
|
||||
StoppingCriteriaList,
|
||||
)
|
||||
import copy
|
||||
|
||||
import time
|
||||
import re
|
||||
import torch
|
||||
import torch_mlir
|
||||
import os
|
||||
import argparse
|
||||
import gc
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="falcon runner",
|
||||
description="runs a falcon model",
|
||||
)
|
||||
|
||||
parser.add_argument("--falcon_variant_to_use", default="7b", help="7b, 40b")
|
||||
parser.add_argument(
|
||||
"--precision", "-p", default="fp16", help="fp32, fp16, int8, int4"
|
||||
"--falcon_variant_to_use", default="7b", help="7b, 40b, 180b"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compressed",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Do the compression of sharded layers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"]
|
||||
)
|
||||
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
|
||||
parser.add_argument(
|
||||
@@ -49,7 +71,7 @@ parser.add_argument(
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_mlir_from_shark_tank",
|
||||
default=False,
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="download precompile mlir from shark tank",
|
||||
)
|
||||
@@ -59,32 +81,74 @@ parser.add_argument(
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Run model in cli mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf_auth_token",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specify your own huggingface authentication token for falcon-180B model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--sharded",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Run model as sharded",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_shards",
|
||||
type=int,
|
||||
default=4,
|
||||
choices=[2, 4],
|
||||
help="Number of shards.",
|
||||
)
|
||||
|
||||
|
||||
class Falcon(SharkLLMBase):
|
||||
class ShardedFalcon(SharkLLMBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path,
|
||||
hf_model_path="tiiuae/falcon-7b-instruct",
|
||||
hf_auth_token: str = None,
|
||||
max_num_tokens=150,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
falcon_mlir_path=None,
|
||||
falcon_vmfb_path=None,
|
||||
debug=False,
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
print("hf_model_path: ", self.hf_model_path)
|
||||
|
||||
if (
|
||||
"180b" in self.model_name
|
||||
and precision != "int4"
|
||||
and hf_auth_token == None
|
||||
):
|
||||
raise ValueError(
|
||||
""" HF auth token required for falcon-180b. Pass it using
|
||||
--hf_auth_token flag. You can ask for the access to the model
|
||||
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
|
||||
)
|
||||
|
||||
if args.sharded and "180b" not in self.model_name:
|
||||
raise ValueError("Sharding supported only for Falcon-180B")
|
||||
|
||||
self.hf_auth_token = hf_auth_token
|
||||
self.max_padding_length = 100
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.falcon_vmfb_path = falcon_vmfb_path
|
||||
self.falcon_mlir_path = falcon_mlir_path
|
||||
self.debug = debug
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.shark_model = self.compile()
|
||||
self.src_model = self.get_src_model()
|
||||
self.shark_model = self.compile()
|
||||
|
||||
def get_tokenizer(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path, trust_remote_code=True
|
||||
self.hf_model_path,
|
||||
trust_remote_code=True,
|
||||
token=self.hf_auth_token,
|
||||
)
|
||||
tokenizer.padding_side = "left"
|
||||
tokenizer.pad_token_id = 11
|
||||
@@ -92,13 +156,535 @@ class Falcon(SharkLLMBase):
|
||||
|
||||
def get_src_model(self):
|
||||
print("Loading src model: ", self.model_name)
|
||||
kwargs = {"torch_dtype": torch.float, "trust_remote_code": True}
|
||||
kwargs = {
|
||||
"torch_dtype": torch.float32,
|
||||
"trust_remote_code": True,
|
||||
"token": self.hf_auth_token,
|
||||
}
|
||||
if self.precision == "int4":
|
||||
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
|
||||
kwargs["quantization_config"] = quantization_config
|
||||
kwargs["device_map"] = "cpu"
|
||||
falcon_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
)
|
||||
return falcon_model
|
||||
|
||||
def compile_falcon(self):
|
||||
def compile_layer(
|
||||
self, layer, falconCompileInput, layer_id, device_idx=None
|
||||
):
|
||||
self.falcon_mlir_path = Path(
|
||||
f"falcon_{args.falcon_variant_to_use}_layer_{layer_id}_{self.precision}.mlir"
|
||||
)
|
||||
self.falcon_vmfb_path = Path(
|
||||
f"falcon_{args.falcon_variant_to_use}_layer_{layer_id}_{self.precision}_{self.device}.vmfb"
|
||||
)
|
||||
|
||||
if args.use_precompiled_model:
|
||||
if not self.falcon_vmfb_path.exists():
|
||||
# Downloading VMFB from shark_tank
|
||||
print(f"[DEBUG] Trying to download vmfb from shark_tank")
|
||||
download_public_file(
|
||||
f"gs://shark_tank/falcon/sharded/falcon_{args.falcon_variant_to_use}/vmfb/"
|
||||
+ str(self.falcon_vmfb_path),
|
||||
self.falcon_vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
vmfb = get_vmfb_from_path(
|
||||
self.falcon_vmfb_path,
|
||||
self.device,
|
||||
"linalg",
|
||||
device_id=device_idx,
|
||||
)
|
||||
if vmfb is not None:
|
||||
return vmfb, device_idx
|
||||
|
||||
print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
|
||||
if self.falcon_mlir_path.exists():
|
||||
print(f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}")
|
||||
with open(self.falcon_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
mlir_generated = False
|
||||
print(
|
||||
f"[DEBUG] mlir not found at {self.falcon_mlir_path.absolute()}"
|
||||
)
|
||||
if args.load_mlir_from_shark_tank:
|
||||
# Downloading MLIR from shark_tank
|
||||
print(f"[DEBUG] Trying to download mlir from shark_tank")
|
||||
download_public_file(
|
||||
f"gs://shark_tank/falcon/sharded/falcon_{args.falcon_variant_to_use}/mlir/"
|
||||
+ str(self.falcon_mlir_path),
|
||||
self.falcon_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if self.falcon_mlir_path.exists():
|
||||
print(
|
||||
f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}"
|
||||
)
|
||||
with open(self.falcon_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
mlir_generated = True
|
||||
|
||||
if not mlir_generated:
|
||||
print(f"[DEBUG] generating MLIR locally")
|
||||
if layer_id == "word_embeddings":
|
||||
f16_input_mask = [False]
|
||||
elif layer_id in ["ln_f", "lm_head"]:
|
||||
f16_input_mask = [True]
|
||||
elif "_" in layer_id or type(layer_id) == int:
|
||||
f16_input_mask = [True, True]
|
||||
else:
|
||||
raise ValueError("Unsupported layer: ", layer_id)
|
||||
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
ts_graph = import_with_fx(
|
||||
layer,
|
||||
falconCompileInput,
|
||||
is_f16=True,
|
||||
f16_input_mask=f16_input_mask,
|
||||
mlir_type="torchscript",
|
||||
is_gptq=True,
|
||||
)
|
||||
del layer
|
||||
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
falconCompileInput,
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
del ts_graph
|
||||
|
||||
print(f"[DEBUG] converting to bytecode")
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
del module
|
||||
|
||||
f_ = open(self.falcon_mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
print("Saved falcon mlir at ", str(self.falcon_mlir_path))
|
||||
f_.close()
|
||||
del bytecode
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=self.falcon_mlir_path,
|
||||
device=self.device,
|
||||
mlir_dialect="linalg",
|
||||
device_idx=device_idx,
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
self.falcon_vmfb_path.parent.absolute(),
|
||||
self.falcon_vmfb_path.stem,
|
||||
extra_args=[
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
]
|
||||
+ [
|
||||
"--iree-llvmcpu-use-fast-min-max-ops",
|
||||
]
|
||||
if self.precision == "int4"
|
||||
else [],
|
||||
debug=self.debug,
|
||||
)
|
||||
print("Saved falcon vmfb at ", str(path))
|
||||
shark_module.load_module(path)
|
||||
|
||||
return shark_module, device_idx
|
||||
|
||||
def compile(self):
|
||||
sample_input_ids = torch.zeros([100], dtype=torch.int64)
|
||||
sample_attention_mask = torch.zeros(
|
||||
[1, 1, 100, 100], dtype=torch.float32
|
||||
)
|
||||
num_group_layers = int(
|
||||
20 * (4 / args.num_shards)
|
||||
) # 4 is the number of default shards
|
||||
sample_hidden_states = torch.zeros(
|
||||
[1, 100, 14848], dtype=torch.float32
|
||||
)
|
||||
|
||||
# Determine number of available devices
|
||||
num_devices = 1
|
||||
if self.device == "rocm":
|
||||
import iree.runtime as ireert
|
||||
|
||||
haldriver = ireert.get_driver(self.device)
|
||||
num_devices = len(haldriver.query_available_devices())
|
||||
if num_devices < 2:
|
||||
raise ValueError(
|
||||
"Cannot run Falcon-180B on a single ROCM device."
|
||||
)
|
||||
|
||||
lm_head = LMHeadEmbeddingLayer(self.src_model.lm_head)
|
||||
print("Compiling Layer lm_head")
|
||||
shark_lm_head, _ = self.compile_layer(
|
||||
lm_head,
|
||||
[sample_hidden_states],
|
||||
"lm_head",
|
||||
device_idx=(0 % num_devices) % args.num_shards
|
||||
if self.device == "rocm"
|
||||
else None,
|
||||
)
|
||||
shark_lm_head = CompiledLMHeadEmbeddingLayer(shark_lm_head)
|
||||
|
||||
word_embedding = WordEmbeddingsLayer(
|
||||
self.src_model.transformer.word_embeddings
|
||||
)
|
||||
print("Compiling Layer word_embeddings")
|
||||
shark_word_embedding, _ = self.compile_layer(
|
||||
word_embedding,
|
||||
[sample_input_ids],
|
||||
"word_embeddings",
|
||||
device_idx=(1 % num_devices) % args.num_shards
|
||||
if self.device == "rocm"
|
||||
else None,
|
||||
)
|
||||
shark_word_embedding = CompiledWordEmbeddingsLayer(
|
||||
shark_word_embedding
|
||||
)
|
||||
|
||||
ln_f = LNFEmbeddingLayer(self.src_model.transformer.ln_f)
|
||||
print("Compiling Layer ln_f")
|
||||
shark_ln_f, _ = self.compile_layer(
|
||||
ln_f,
|
||||
[sample_hidden_states],
|
||||
"ln_f",
|
||||
device_idx=(2 % num_devices) % args.num_shards
|
||||
if self.device == "rocm"
|
||||
else None,
|
||||
)
|
||||
shark_ln_f = CompiledLNFEmbeddingLayer(shark_ln_f)
|
||||
|
||||
shark_layers = []
|
||||
for i in range(
|
||||
int(len(self.src_model.transformer.h) / num_group_layers)
|
||||
):
|
||||
device_idx = i % num_devices if self.device == "rocm" else None
|
||||
layer_id = i
|
||||
layer_id = (
|
||||
str(i * num_group_layers)
|
||||
+ "_"
|
||||
+ str((i + 1) * num_group_layers)
|
||||
)
|
||||
pytorch_class = FourWayShardingDecoderLayer
|
||||
compiled_class = CompiledFourWayShardingDecoderLayer
|
||||
if args.num_shards == 2:
|
||||
pytorch_class = TwoWayShardingDecoderLayer
|
||||
compiled_class = CompiledTwoWayShardingDecoderLayer
|
||||
|
||||
print("Compiling Layer {}".format(layer_id))
|
||||
layer_i = self.src_model.transformer.h[
|
||||
i * num_group_layers : (i + 1) * num_group_layers
|
||||
]
|
||||
|
||||
pytorch_layer_i = pytorch_class(
|
||||
layer_i, args.falcon_variant_to_use
|
||||
)
|
||||
shark_module, device_idx = self.compile_layer(
|
||||
pytorch_layer_i,
|
||||
[sample_hidden_states, sample_attention_mask],
|
||||
layer_id,
|
||||
device_idx=device_idx,
|
||||
)
|
||||
shark_layer_i = compiled_class(
|
||||
layer_id,
|
||||
device_idx,
|
||||
args.falcon_variant_to_use,
|
||||
self.device,
|
||||
self.precision,
|
||||
shark_module,
|
||||
)
|
||||
shark_layers.append(shark_layer_i)
|
||||
|
||||
sharded_model = ShardedFalconModel(
|
||||
self.src_model,
|
||||
shark_layers,
|
||||
shark_word_embedding,
|
||||
shark_ln_f,
|
||||
shark_lm_head,
|
||||
)
|
||||
return sharded_model
|
||||
|
||||
def generate(self, prompt):
|
||||
model_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.max_padding_length,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
model_inputs["prompt_text"] = prompt
|
||||
|
||||
input_ids = model_inputs["input_ids"]
|
||||
attention_mask = model_inputs.get("attention_mask", None)
|
||||
|
||||
# Allow empty prompts
|
||||
if input_ids.shape[1] == 0:
|
||||
input_ids = None
|
||||
attention_mask = None
|
||||
|
||||
generate_kwargs = {
|
||||
"max_length": self.max_num_tokens,
|
||||
"do_sample": True,
|
||||
"top_k": 10,
|
||||
"num_return_sequences": 1,
|
||||
"eos_token_id": 11,
|
||||
}
|
||||
generate_kwargs["input_ids"] = input_ids
|
||||
generate_kwargs["attention_mask"] = attention_mask
|
||||
generation_config_ = GenerationConfig.from_model_config(
|
||||
self.src_model.config
|
||||
)
|
||||
generation_config = copy.deepcopy(generation_config_)
|
||||
model_kwargs = generation_config.update(**generate_kwargs)
|
||||
|
||||
logits_processor = LogitsProcessorList()
|
||||
stopping_criteria = StoppingCriteriaList()
|
||||
|
||||
eos_token_id = generation_config.eos_token_id
|
||||
generation_config.pad_token_id = eos_token_id
|
||||
|
||||
(
|
||||
inputs_tensor,
|
||||
model_input_name,
|
||||
model_kwargs,
|
||||
) = self.src_model._prepare_model_inputs(
|
||||
None, generation_config.bos_token_id, model_kwargs
|
||||
)
|
||||
|
||||
model_kwargs["output_attentions"] = generation_config.output_attentions
|
||||
model_kwargs[
|
||||
"output_hidden_states"
|
||||
] = generation_config.output_hidden_states
|
||||
model_kwargs["use_cache"] = generation_config.use_cache
|
||||
|
||||
input_ids = (
|
||||
inputs_tensor
|
||||
if model_input_name == "input_ids"
|
||||
else model_kwargs.pop("input_ids")
|
||||
)
|
||||
|
||||
self.logits_processor = self.src_model._get_logits_processor(
|
||||
generation_config=generation_config,
|
||||
input_ids_seq_length=input_ids.shape[-1],
|
||||
encoder_input_ids=inputs_tensor,
|
||||
prefix_allowed_tokens_fn=None,
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
|
||||
self.stopping_criteria = self.src_model._get_stopping_criteria(
|
||||
generation_config=generation_config,
|
||||
stopping_criteria=stopping_criteria,
|
||||
)
|
||||
|
||||
self.logits_warper = self.src_model._get_logits_warper(
|
||||
generation_config
|
||||
)
|
||||
|
||||
(
|
||||
self.input_ids,
|
||||
self.model_kwargs,
|
||||
) = self.src_model._expand_inputs_for_generation(
|
||||
input_ids=input_ids,
|
||||
expand_size=generation_config.num_return_sequences, # 1
|
||||
is_encoder_decoder=self.src_model.config.is_encoder_decoder, # False
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
self.eos_token_id_tensor = (
|
||||
torch.tensor(eos_token_id) if eos_token_id is not None else None
|
||||
)
|
||||
|
||||
self.pad_token_id = generation_config.pad_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
output_scores = generation_config.output_scores # False
|
||||
return_dict_in_generate = (
|
||||
generation_config.return_dict_in_generate # False
|
||||
)
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
self.scores = (
|
||||
() if (return_dict_in_generate and output_scores) else None
|
||||
)
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
self.unfinished_sequences = torch.ones(
|
||||
input_ids.shape[0], dtype=torch.long, device=input_ids.device
|
||||
)
|
||||
|
||||
all_text = prompt
|
||||
|
||||
start = time.time()
|
||||
count = 0
|
||||
for i in range(self.max_num_tokens - 1):
|
||||
count = count + 1
|
||||
|
||||
next_token = self.generate_new_token()
|
||||
new_word = self.tokenizer.decode(
|
||||
next_token.cpu().numpy(),
|
||||
add_special_tokens=False,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
|
||||
all_text = all_text + new_word
|
||||
|
||||
print(f"{new_word}", end="", flush=True)
|
||||
print(f"{all_text}", end="", flush=True)
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if self.eos_token_id_tensor is not None:
|
||||
self.unfinished_sequences = self.unfinished_sequences.mul(
|
||||
next_token.tile(self.eos_token_id_tensor.shape[0], 1)
|
||||
.ne(self.eos_token_id_tensor.unsqueeze(1))
|
||||
.prod(dim=0)
|
||||
)
|
||||
# stop when each sentence is finished
|
||||
if (
|
||||
self.unfinished_sequences.max() == 0
|
||||
or self.stopping_criteria(input_ids, self.scores)
|
||||
):
|
||||
break
|
||||
|
||||
end = time.time()
|
||||
print(
|
||||
"\n\nTime taken is {:.2f} seconds/token\n".format(
|
||||
(end - start) / count
|
||||
)
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
return all_text
|
||||
|
||||
def generate_new_token(self):
|
||||
model_inputs = self.src_model.prepare_inputs_for_generation(
|
||||
self.input_ids, **self.model_kwargs
|
||||
)
|
||||
outputs = self.shark_model.forward(
|
||||
input_ids=model_inputs["input_ids"],
|
||||
attention_mask=model_inputs["attention_mask"],
|
||||
)
|
||||
if self.precision in ["fp16", "int4"]:
|
||||
outputs = outputs.to(dtype=torch.float32)
|
||||
next_token_logits = outputs
|
||||
|
||||
# pre-process distribution
|
||||
next_token_scores = self.logits_processor(
|
||||
self.input_ids, next_token_logits
|
||||
)
|
||||
next_token_scores = self.logits_warper(
|
||||
self.input_ids, next_token_scores
|
||||
)
|
||||
|
||||
# sample
|
||||
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
|
||||
|
||||
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
|
||||
# finished sentences should have their next token be a padding token
|
||||
if self.eos_token_id is not None:
|
||||
if self.pad_token_id is None:
|
||||
raise ValueError(
|
||||
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
|
||||
)
|
||||
next_token = (
|
||||
next_token * self.unfinished_sequences
|
||||
+ self.pad_token_id * (1 - self.unfinished_sequences)
|
||||
)
|
||||
|
||||
self.input_ids = torch.cat(
|
||||
[self.input_ids, next_token[:, None]], dim=-1
|
||||
)
|
||||
|
||||
self.model_kwargs["past_key_values"] = None
|
||||
if "attention_mask" in self.model_kwargs:
|
||||
attention_mask = self.model_kwargs["attention_mask"]
|
||||
self.model_kwargs["attention_mask"] = torch.cat(
|
||||
[
|
||||
attention_mask,
|
||||
attention_mask.new_ones((attention_mask.shape[0], 1)),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
self.input_ids = self.input_ids[:, 1:]
|
||||
self.model_kwargs["attention_mask"] = self.model_kwargs[
|
||||
"attention_mask"
|
||||
][:, 1:]
|
||||
|
||||
return next_token
|
||||
|
||||
|
||||
class UnshardedFalcon(SharkLLMBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path="tiiuae/falcon-7b-instruct",
|
||||
hf_auth_token: str = "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",
|
||||
max_num_tokens=150,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
falcon_mlir_path=None,
|
||||
falcon_vmfb_path=None,
|
||||
debug=False,
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
print("hf_model_path: ", self.hf_model_path)
|
||||
|
||||
if "180b" in self.model_name and hf_auth_token == None:
|
||||
raise ValueError(
|
||||
""" HF auth token required for falcon-180b. Pass it using
|
||||
--hf_auth_token flag. You can ask for the access to the model
|
||||
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
|
||||
)
|
||||
self.hf_auth_token = hf_auth_token
|
||||
self.max_padding_length = 100
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.falcon_vmfb_path = falcon_vmfb_path
|
||||
self.falcon_mlir_path = falcon_mlir_path
|
||||
self.debug = debug
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.src_model = self.get_src_model()
|
||||
self.shark_model = self.compile()
|
||||
|
||||
def get_tokenizer(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path,
|
||||
trust_remote_code=True,
|
||||
token=self.hf_auth_token,
|
||||
)
|
||||
tokenizer.padding_side = "left"
|
||||
tokenizer.pad_token_id = 11
|
||||
return tokenizer
|
||||
|
||||
def get_src_model(self):
|
||||
print("Loading src model: ", self.model_name)
|
||||
kwargs = {
|
||||
"torch_dtype": torch.float32,
|
||||
"trust_remote_code": True,
|
||||
"token": self.hf_auth_token,
|
||||
}
|
||||
if self.precision == "int4":
|
||||
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
|
||||
kwargs["quantization_config"] = quantization_config
|
||||
kwargs["device_map"] = "cpu"
|
||||
falcon_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
)
|
||||
return falcon_model
|
||||
|
||||
def compile(self):
|
||||
if args.use_precompiled_model:
|
||||
if not self.falcon_vmfb_path.exists():
|
||||
# Downloading VMFB from shark_tank
|
||||
@@ -120,37 +706,37 @@ class Falcon(SharkLLMBase):
|
||||
if vmfb is not None:
|
||||
return vmfb
|
||||
|
||||
print(
|
||||
f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}. Trying to work with"
|
||||
f"[DEBUG] mlir path { self.falcon_mlir_path} {'exists' if self.falcon_mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
|
||||
if self.falcon_mlir_path.exists():
|
||||
print(f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}")
|
||||
with open(self.falcon_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
mlir_generated = False
|
||||
# Downloading MLIR from shark_tank
|
||||
download_public_file(
|
||||
"gs://shark_tank/falcon/"
|
||||
+ "falcon_"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "_"
|
||||
+ self.precision
|
||||
+ ".mlir",
|
||||
self.falcon_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
print(
|
||||
f"[DEBUG] mlir not found at {self.falcon_mlir_path.absolute()}"
|
||||
)
|
||||
if self.falcon_mlir_path.exists():
|
||||
with open(self.falcon_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
mlir_generated = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"MLIR not found at {self.falcon_mlir_path.absolute()}"
|
||||
" after downloading! Please check path and try again"
|
||||
if args.load_mlir_from_shark_tank:
|
||||
# Downloading MLIR from shark_tank
|
||||
print(f"[DEBUG] Trying to download mlir from shark_tank")
|
||||
download_public_file(
|
||||
"gs://shark_tank/falcon/"
|
||||
+ "falcon_"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "_"
|
||||
+ self.precision
|
||||
+ ".mlir",
|
||||
self.falcon_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if self.falcon_mlir_path.exists():
|
||||
print(
|
||||
f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}"
|
||||
)
|
||||
mlir_generated = True
|
||||
|
||||
if not mlir_generated:
|
||||
print(f"[DEBUG] generating MLIR locally")
|
||||
compilation_input_ids = torch.randint(
|
||||
low=1, high=10000, size=(1, 100)
|
||||
)
|
||||
@@ -167,9 +753,10 @@ class Falcon(SharkLLMBase):
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
falconCompileInput,
|
||||
is_f16=self.precision == "fp16",
|
||||
is_f16=self.precision in ["fp16", "int4"],
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
is_gptq=self.precision == "int4",
|
||||
)
|
||||
del model
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
@@ -189,35 +776,37 @@ class Falcon(SharkLLMBase):
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
del module
|
||||
|
||||
print(f"[DEBUG] writing mlir to file")
|
||||
with open(f"{self.model_name}.mlir", "wb") as f_:
|
||||
with redirect_stdout(f_):
|
||||
print(module.operation.get_asm())
|
||||
f_ = open(self.falcon_mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
print("Saved falcon mlir at ", str(self.falcon_mlir_path))
|
||||
f_.close()
|
||||
del bytecode
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=self.device, mlir_dialect="linalg"
|
||||
mlir_module=self.falcon_mlir_path,
|
||||
device=self.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
self.falcon_vmfb_path.parent.absolute(),
|
||||
self.falcon_vmfb_path.stem,
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
"--iree-spirv-index-bits=64",
|
||||
],
|
||||
]
|
||||
+ [
|
||||
"--iree-llvmcpu-use-fast-min-max-ops",
|
||||
]
|
||||
if self.precision == "int4"
|
||||
else [],
|
||||
debug=self.debug,
|
||||
)
|
||||
print("Saved falcon vmfb at ", str(path))
|
||||
shark_module.load_module(path)
|
||||
|
||||
return shark_module
|
||||
|
||||
def compile(self):
|
||||
falcon_shark_model = self.compile_falcon()
|
||||
return falcon_shark_model
|
||||
|
||||
def generate(self, prompt):
|
||||
model_inputs = self.tokenizer(
|
||||
prompt,
|
||||
@@ -345,7 +934,11 @@ class Falcon(SharkLLMBase):
|
||||
|
||||
all_text = prompt
|
||||
|
||||
start = time.time()
|
||||
count = 0
|
||||
for i in range(self.max_num_tokens - 1):
|
||||
count = count + 1
|
||||
|
||||
next_token = self.generate_new_token()
|
||||
new_word = self.tokenizer.decode(
|
||||
next_token.cpu().numpy(),
|
||||
@@ -372,6 +965,13 @@ class Falcon(SharkLLMBase):
|
||||
):
|
||||
break
|
||||
|
||||
end = time.time()
|
||||
print(
|
||||
"\n\nTime taken is {:.2f} seconds/token\n".format(
|
||||
(end - start) / count
|
||||
)
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
@@ -387,7 +987,7 @@ class Falcon(SharkLLMBase):
|
||||
(model_inputs["input_ids"], model_inputs["attention_mask"]),
|
||||
)
|
||||
)
|
||||
if self.precision == "fp16":
|
||||
if self.precision in ["fp16", "int4"]:
|
||||
outputs = outputs.to(dtype=torch.float32)
|
||||
next_token_logits = outputs
|
||||
|
||||
@@ -466,18 +1066,39 @@ if __name__ == "__main__":
|
||||
else Path(args.falcon_vmfb_path)
|
||||
)
|
||||
|
||||
falcon = Falcon(
|
||||
"falcon_" + args.falcon_variant_to_use,
|
||||
hf_model_path="tiiuae/falcon-"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "-instruct",
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
falcon_mlir_path=falcon_mlir_path,
|
||||
falcon_vmfb_path=falcon_vmfb_path,
|
||||
)
|
||||
if args.precision == "int4":
|
||||
if args.falcon_variant_to_use == "180b":
|
||||
hf_model_path_value = "TheBloke/Falcon-180B-Chat-GPTQ"
|
||||
else:
|
||||
hf_model_path_value = (
|
||||
"TheBloke/falcon-"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "-instruct-GPTQ"
|
||||
)
|
||||
else:
|
||||
if args.falcon_variant_to_use == "180b":
|
||||
hf_model_path_value = "tiiuae/falcon-180B-chat"
|
||||
else:
|
||||
hf_model_path_value = (
|
||||
"tiiuae/falcon-" + args.falcon_variant_to_use + "-instruct"
|
||||
)
|
||||
|
||||
import gc
|
||||
if not args.sharded:
|
||||
falcon = UnshardedFalcon(
|
||||
model_name="falcon_" + args.falcon_variant_to_use,
|
||||
hf_model_path=hf_model_path_value,
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
falcon_mlir_path=falcon_mlir_path,
|
||||
falcon_vmfb_path=falcon_vmfb_path,
|
||||
)
|
||||
else:
|
||||
falcon = ShardedFalcon(
|
||||
model_name="falcon_" + args.falcon_variant_to_use,
|
||||
hf_model_path=hf_model_path_value,
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
)
|
||||
|
||||
default_prompt_text = "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:"
|
||||
continue_execution = True
|
||||
@@ -497,7 +1118,11 @@ if __name__ == "__main__":
|
||||
prompt = input("Please enter the prompt text: ")
|
||||
print("\nPrompt Text: ", prompt)
|
||||
|
||||
res_str = falcon.generate(prompt)
|
||||
prompt_template = f"""A helpful assistant who helps the user with any questions asked.
|
||||
User: {prompt}
|
||||
Assistant:"""
|
||||
|
||||
res_str = falcon.generate(prompt_template)
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
print(
|
||||
|
||||
@@ -126,13 +126,13 @@ def is_url(input_url):
|
||||
import os
|
||||
import tempfile
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
import torch
|
||||
import torch_mlir
|
||||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||
from typing import List, Tuple
|
||||
from io import BytesIO
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.common.generative.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
@@ -178,7 +178,7 @@ def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
|
||||
|
||||
|
||||
def compile_module(
|
||||
shark_module, extended_model_name, generate_vmfb, extra_args=[]
|
||||
shark_module, extended_model_name, generate_vmfb, extra_args=[], debug=False,
|
||||
):
|
||||
if generate_vmfb:
|
||||
vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
|
||||
@@ -190,7 +190,7 @@ def compile_module(
|
||||
"No vmfb found. Compiling and saving to {}".format(vmfb_path)
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
os.getcwd(), extended_model_name, extra_args
|
||||
os.getcwd(), extended_model_name, extra_args, debug=debug
|
||||
)
|
||||
shark_module.load_module(path, extra_args=extra_args)
|
||||
else:
|
||||
@@ -199,7 +199,7 @@ def compile_module(
|
||||
|
||||
|
||||
def compile_int_precision(
|
||||
model, inputs, precision, device, generate_vmfb, extended_model_name
|
||||
model, inputs, precision, device, generate_vmfb, extended_model_name, debug=False
|
||||
):
|
||||
torchscript_module = import_with_fx(
|
||||
model,
|
||||
@@ -219,7 +219,7 @@ def compile_int_precision(
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
mlir_module,
|
||||
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
from contextlib import redirect_stdout
|
||||
@@ -235,6 +235,12 @@ def compile_int_precision(
|
||||
mlir_module = BytesIO(mlir_module)
|
||||
bytecode = mlir_module.read()
|
||||
print(f"Elided IR written for {extended_model_name}")
|
||||
bytecode = save_mlir(
|
||||
bytecode,
|
||||
model_name=extended_model_name,
|
||||
frontend="torch",
|
||||
dir=os.getcwd(),
|
||||
)
|
||||
return bytecode
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
|
||||
@@ -251,6 +257,7 @@ def compile_int_precision(
|
||||
extended_model_name=extended_model_name,
|
||||
generate_vmfb=generate_vmfb,
|
||||
extra_args=extra_args,
|
||||
debug=debug,
|
||||
),
|
||||
bytecode,
|
||||
)
|
||||
@@ -294,6 +301,7 @@ def shark_compile_through_fx_int(
|
||||
device,
|
||||
generate_or_load_vmfb,
|
||||
extended_model_name,
|
||||
debug,
|
||||
)
|
||||
extra_args = [
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
|
||||
@@ -32,11 +32,13 @@ class SharkStableLM(SharkLLMBase):
|
||||
max_num_tokens=512,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
debug="False",
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_sequence_len = 256
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.debug = debug
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.shark_model = self.compile()
|
||||
|
||||
@@ -111,7 +113,7 @@ class SharkStableLM(SharkLLMBase):
|
||||
shark_module.compile()
|
||||
|
||||
path = shark_module.save_module(
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem, debug=self.debug
|
||||
)
|
||||
print("Saved vmfb at ", str(path))
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from shark.shark_downloader import download_public_file
|
||||
|
||||
# expects a Path / str as arg
|
||||
# returns None if path not found or SharkInference module
|
||||
def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
|
||||
def get_vmfb_from_path(vmfb_path, device, mlir_dialect, device_id=None):
|
||||
if not isinstance(vmfb_path, Path):
|
||||
vmfb_path = Path(vmfb_path)
|
||||
|
||||
@@ -20,7 +20,7 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
|
||||
print("Loading vmfb from: ", vmfb_path)
|
||||
print("Device from get_vmfb_from_path - ", device)
|
||||
shark_module = SharkInference(
|
||||
None, device=device, mlir_dialect=mlir_dialect
|
||||
None, device=device, mlir_dialect=mlir_dialect, device_idx=device_id
|
||||
)
|
||||
shark_module.load_module(vmfb_path)
|
||||
print("Successfully loaded vmfb")
|
||||
@@ -28,7 +28,13 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
|
||||
|
||||
|
||||
def get_vmfb_from_config(
|
||||
shark_container, model, precision, device, vmfb_path, padding=None
|
||||
shark_container,
|
||||
model,
|
||||
precision,
|
||||
device,
|
||||
vmfb_path,
|
||||
padding=None,
|
||||
device_id=None,
|
||||
):
|
||||
vmfb_url = (
|
||||
f"gs://shark_tank/{shark_container}/{model}_{precision}_{device}"
|
||||
@@ -37,4 +43,6 @@ def get_vmfb_from_config(
|
||||
vmfb_url = vmfb_url + f"_{padding}"
|
||||
vmfb_url = vmfb_url + ".vmfb"
|
||||
download_public_file(vmfb_url, vmfb_path.absolute(), single_file=True)
|
||||
return get_vmfb_from_path(vmfb_path, device, "tm_tensor")
|
||||
return get_vmfb_from_path(
|
||||
vmfb_path, device, "tm_tensor", device_id=device_id
|
||||
)
|
||||
|
||||
91
apps/shark_studio/api/llm.py
Normal file
91
apps/shark_studio/api/llm.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from turbine_models.custom_models import stateless_llama
|
||||
from shark.iree_utils.compile_utils import get_iree_compiled_module
|
||||
from apps.shark_studio.api.utils import get_resource_path
|
||||
import iree.runtime as ireert
|
||||
import gc
|
||||
import torch
|
||||
|
||||
llm_model_map = {
|
||||
"llama2_7b": {
|
||||
"initializer": stateless_llama.export_transformer_model,
|
||||
"hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
|
||||
"stop_token": 2,
|
||||
"max_tokens": 4096,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class LanguageModel:
|
||||
def __init__(
|
||||
self, model_name, hf_auth_token=None, device=None, precision="fp32"
|
||||
):
|
||||
print(llm_model_map[model_name])
|
||||
self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
|
||||
self.torch_ir, self.tokenizer = llm_model_map[model_name][
|
||||
"initializer"
|
||||
](self.hf_model_name, hf_auth_token, compile_to="torch")
|
||||
self.tempfile_name = get_resource_path("llm.torch.tempfile")
|
||||
with open(self.tempfile_name, "w+") as f:
|
||||
f.write(self.torch_ir)
|
||||
del self.torch_ir
|
||||
gc.collect()
|
||||
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.max_tokens = llm_model_map[model_name]["max_tokens"]
|
||||
self.iree_module_dict = None
|
||||
self.compile()
|
||||
|
||||
def compile(self) -> None:
|
||||
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
|
||||
self.iree_module_dict = get_iree_compiled_module(
|
||||
self.tempfile_name, device=self.device, frontend="torch"
|
||||
)
|
||||
# TODO: delete the temp file
|
||||
|
||||
def chat(self, prompt):
|
||||
history = []
|
||||
for iter in range(self.max_tokens):
|
||||
input_tensor = self.tokenizer(
|
||||
prompt, return_tensors="pt"
|
||||
).input_ids
|
||||
device_inputs = [
|
||||
ireert.asdevicearray(
|
||||
self.iree_module_dict["config"], input_tensor
|
||||
)
|
||||
]
|
||||
if iter == 0:
|
||||
token = torch.tensor(
|
||||
self.iree_module_dict["vmfb"]["run_initialize"](
|
||||
*device_inputs
|
||||
).to_host()[0][0]
|
||||
)
|
||||
else:
|
||||
token = torch.tensor(
|
||||
self.iree_module_dict["vmfb"]["run_forward"](
|
||||
*device_inputs
|
||||
).to_host()[0][0]
|
||||
)
|
||||
|
||||
history.append(token)
|
||||
yield self.tokenizer.decode(history)
|
||||
|
||||
if token == llm_model_map["llama2_7b"]["stop_token"]:
|
||||
break
|
||||
|
||||
for i in range(len(history)):
|
||||
if type(history[i]) != int:
|
||||
history[i] = int(history[i])
|
||||
result_output = self.tokenizer.decode(history)
|
||||
yield result_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
lm = LanguageModel(
|
||||
"llama2_7b",
|
||||
hf_auth_token="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",
|
||||
device="cpu-task",
|
||||
)
|
||||
print("model loaded")
|
||||
for i in lm.chat("Hello, I am a robot."):
|
||||
print(i)
|
||||
14
apps/shark_studio/api/utils.py
Normal file
14
apps/shark_studio/api/utils.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def get_available_devices():
|
||||
return ["cpu-task"]
|
||||
|
||||
|
||||
def get_resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
base_path = getattr(
|
||||
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
|
||||
)
|
||||
return os.path.join(base_path, relative_path)
|
||||
428
apps/shark_studio/web/index.py
Normal file
428
apps/shark_studio/web/index.py
Normal file
@@ -0,0 +1,428 @@
|
||||
from multiprocessing import Process, freeze_support
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from ui.chat import chat_element
|
||||
|
||||
if sys.platform == "darwin":
|
||||
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
|
||||
# import before IREE to avoid MLIR library issues
|
||||
import torch_mlir
|
||||
|
||||
# import PIL, transformers, sentencepiece # ensures inclusion in pysintaller exe generation
|
||||
# from apps.stable_diffusion.src import args, clear_all
|
||||
# import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
|
||||
|
||||
def launch_app(address):
|
||||
from tkinter import Tk
|
||||
import webview
|
||||
|
||||
window = Tk()
|
||||
|
||||
# get screen width and height of display and make it more reasonably
|
||||
# sized as we aren't making it full-screen or maximized
|
||||
width = int(window.winfo_screenwidth() * 0.81)
|
||||
height = int(window.winfo_screenheight() * 0.91)
|
||||
webview.create_window(
|
||||
"SHARK AI Studio",
|
||||
url=address,
|
||||
width=width,
|
||||
height=height,
|
||||
text_select=True,
|
||||
)
|
||||
webview.start(private_mode=False, storage_path=os.getcwd())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# if args.debug:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
# required to do multiprocessing in a pyinstaller freeze
|
||||
freeze_support()
|
||||
# if args.api or "api" in args.ui.split(","):
|
||||
# from apps.stable_diffusion.web.ui import (
|
||||
# txt2img_api,
|
||||
# img2img_api,
|
||||
# upscaler_api,
|
||||
# inpaint_api,
|
||||
# outpaint_api,
|
||||
# llm_chat_api,
|
||||
# )
|
||||
#
|
||||
# from fastapi import FastAPI, APIRouter
|
||||
# import uvicorn
|
||||
#
|
||||
# # init global sd pipeline and config
|
||||
# global_obj._init()
|
||||
#
|
||||
# app = FastAPI()
|
||||
# app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
|
||||
# app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
|
||||
# app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
|
||||
# app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
|
||||
# app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
|
||||
#
|
||||
# # chat APIs needed for compatibility with multiple extensions using OpenAI API
|
||||
# app.add_api_route(
|
||||
# "/v1/chat/completions", llm_chat_api, methods=["post"]
|
||||
# )
|
||||
# app.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
|
||||
# app.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
|
||||
# app.add_api_route("/completions", llm_chat_api, methods=["post"])
|
||||
# app.add_api_route(
|
||||
# "/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
|
||||
# )
|
||||
# app.include_router(APIRouter())
|
||||
# uvicorn.run(app, host="0.0.0.0", port=args.server_port)
|
||||
# sys.exit(0)
|
||||
#
|
||||
# Setup to use shark_tmp for gradio's temporary image files and clear any
|
||||
# existing temporary images there if they exist. Then we can import gradio.
|
||||
# It has to be in this order or gradio ignores what we've set up.
|
||||
# from apps.stable_diffusion.web.utils.gradio_configs import (
|
||||
# config_gradio_tmp_imgs_folder,
|
||||
# )
|
||||
|
||||
# config_gradio_tmp_imgs_folder()
|
||||
import gradio as gr
|
||||
|
||||
# Create custom models folders if they don't exist
|
||||
# from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
|
||||
|
||||
# create_custom_models_folders()
|
||||
|
||||
def resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
base_path = getattr(
|
||||
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
|
||||
)
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
dark_theme = resource_path("ui/css/sd_dark_theme.css")
|
||||
|
||||
# from apps.stable_diffusion.web.ui import (
|
||||
# txt2img_web,
|
||||
# txt2img_custom_model,
|
||||
# txt2img_gallery,
|
||||
# txt2img_png_info_img,
|
||||
# txt2img_status,
|
||||
# txt2img_sendto_img2img,
|
||||
# txt2img_sendto_inpaint,
|
||||
# txt2img_sendto_outpaint,
|
||||
# txt2img_sendto_upscaler,
|
||||
## h2ogpt_upload,
|
||||
## h2ogpt_web,
|
||||
# img2img_web,
|
||||
# img2img_custom_model,
|
||||
# img2img_gallery,
|
||||
# img2img_init_image,
|
||||
# img2img_status,
|
||||
# img2img_sendto_inpaint,
|
||||
# img2img_sendto_outpaint,
|
||||
# img2img_sendto_upscaler,
|
||||
# inpaint_web,
|
||||
# inpaint_custom_model,
|
||||
# inpaint_gallery,
|
||||
# inpaint_init_image,
|
||||
# inpaint_status,
|
||||
# inpaint_sendto_img2img,
|
||||
# inpaint_sendto_outpaint,
|
||||
# inpaint_sendto_upscaler,
|
||||
# outpaint_web,
|
||||
# outpaint_custom_model,
|
||||
# outpaint_gallery,
|
||||
# outpaint_init_image,
|
||||
# outpaint_status,
|
||||
# outpaint_sendto_img2img,
|
||||
# outpaint_sendto_inpaint,
|
||||
# outpaint_sendto_upscaler,
|
||||
# upscaler_web,
|
||||
# upscaler_custom_model,
|
||||
# upscaler_gallery,
|
||||
# upscaler_init_image,
|
||||
# upscaler_status,
|
||||
# upscaler_sendto_img2img,
|
||||
# upscaler_sendto_inpaint,
|
||||
# upscaler_sendto_outpaint,
|
||||
## lora_train_web,
|
||||
## model_web,
|
||||
## model_config_web,
|
||||
# hf_models,
|
||||
# modelmanager_sendto_txt2img,
|
||||
# modelmanager_sendto_img2img,
|
||||
# modelmanager_sendto_inpaint,
|
||||
# modelmanager_sendto_outpaint,
|
||||
# modelmanager_sendto_upscaler,
|
||||
# stablelm_chat,
|
||||
# minigpt4_web,
|
||||
# outputgallery_web,
|
||||
# outputgallery_tab_select,
|
||||
# outputgallery_watch,
|
||||
# outputgallery_filename,
|
||||
# outputgallery_sendto_txt2img,
|
||||
# outputgallery_sendto_img2img,
|
||||
# outputgallery_sendto_inpaint,
|
||||
# outputgallery_sendto_outpaint,
|
||||
# outputgallery_sendto_upscaler,
|
||||
# )
|
||||
|
||||
# init global sd pipeline and config
|
||||
# global_obj._init()
|
||||
|
||||
def register_button_click(button, selectedid, inputs, outputs):
|
||||
button.click(
|
||||
lambda x: (
|
||||
x[0]["name"] if len(x) != 0 else None,
|
||||
gr.Tabs.update(selected=selectedid),
|
||||
),
|
||||
inputs,
|
||||
outputs,
|
||||
)
|
||||
|
||||
def register_modelmanager_button(button, selectedid, inputs, outputs):
|
||||
button.click(
|
||||
lambda x: (
|
||||
"None",
|
||||
x,
|
||||
gr.Tabs.update(selected=selectedid),
|
||||
),
|
||||
inputs,
|
||||
outputs,
|
||||
)
|
||||
|
||||
def register_outputgallery_button(button, selectedid, inputs, outputs):
|
||||
button.click(
|
||||
lambda x: (
|
||||
x,
|
||||
gr.Tabs.update(selected=selectedid),
|
||||
),
|
||||
inputs,
|
||||
outputs,
|
||||
)
|
||||
|
||||
with gr.Blocks(
|
||||
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
|
||||
) as sd_web:
|
||||
with gr.Tabs() as tabs:
|
||||
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
|
||||
# have a unique id that doesn't clash with any of the other tabs,
|
||||
# and that the order in the code here is the order they should
|
||||
# appear in the ui, as the id value doesn't determine the order.
|
||||
|
||||
# Where possible, avoid changing the id of any tab that is the
|
||||
# destination of one of the 'send to' buttons. If you do have to change
|
||||
# that id, make sure you update the relevant register_button_click calls
|
||||
# further down with the new id.
|
||||
# with gr.TabItem(label="Text-to-Image", id=0):
|
||||
# txt2img_web.render()
|
||||
# with gr.TabItem(label="Image-to-Image", id=1):
|
||||
# img2img_web.render()
|
||||
# with gr.TabItem(label="Inpainting", id=2):
|
||||
# inpaint_web.render()
|
||||
# with gr.TabItem(label="Outpainting", id=3):
|
||||
# outpaint_web.render()
|
||||
# with gr.TabItem(label="Upscaler", id=4):
|
||||
# upscaler_web.render()
|
||||
# if args.output_gallery:
|
||||
# with gr.TabItem(label="Output Gallery", id=5) as og_tab:
|
||||
# outputgallery_web.render()
|
||||
|
||||
# # extra output gallery configuration
|
||||
# outputgallery_tab_select(og_tab.select)
|
||||
# outputgallery_watch(
|
||||
# [
|
||||
# txt2img_status,
|
||||
# img2img_status,
|
||||
# inpaint_status,
|
||||
# outpaint_status,
|
||||
# upscaler_status,
|
||||
# ]
|
||||
# )
|
||||
## with gr.TabItem(label="Model Manager", id=6):
|
||||
## model_web.render()
|
||||
## with gr.TabItem(label="LoRA Training (Experimental)", id=7):
|
||||
## lora_train_web.render()
|
||||
with gr.TabItem(label="Chat Bot", id=0):
|
||||
chat_element.render()
|
||||
## with gr.TabItem(
|
||||
## label="Generate Sharding Config (Experimental)", id=9
|
||||
## ):
|
||||
## model_config_web.render()
|
||||
# with gr.TabItem(label="MultiModal (Experimental)", id=10):
|
||||
# minigpt4_web.render()
|
||||
# with gr.TabItem(label="DocuChat Upload", id=11):
|
||||
# h2ogpt_upload.render()
|
||||
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
|
||||
# h2ogpt_web.render()
|
||||
|
||||
# send to buttons
|
||||
# register_button_click(
|
||||
# txt2img_sendto_img2img,
|
||||
# 1,
|
||||
# [txt2img_gallery],
|
||||
# [img2img_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# txt2img_sendto_inpaint,
|
||||
# 2,
|
||||
# [txt2img_gallery],
|
||||
# [inpaint_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# txt2img_sendto_outpaint,
|
||||
# 3,
|
||||
# [txt2img_gallery],
|
||||
# [outpaint_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# txt2img_sendto_upscaler,
|
||||
# 4,
|
||||
# [txt2img_gallery],
|
||||
# [upscaler_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# img2img_sendto_inpaint,
|
||||
# 2,
|
||||
# [img2img_gallery],
|
||||
# [inpaint_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# img2img_sendto_outpaint,
|
||||
# 3,
|
||||
# [img2img_gallery],
|
||||
# [outpaint_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# img2img_sendto_upscaler,
|
||||
# 4,
|
||||
# [img2img_gallery],
|
||||
# [upscaler_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# inpaint_sendto_img2img,
|
||||
# 1,
|
||||
# [inpaint_gallery],
|
||||
# [img2img_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# inpaint_sendto_outpaint,
|
||||
# 3,
|
||||
# [inpaint_gallery],
|
||||
# [outpaint_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# inpaint_sendto_upscaler,
|
||||
# 4,
|
||||
# [inpaint_gallery],
|
||||
# [upscaler_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# outpaint_sendto_img2img,
|
||||
# 1,
|
||||
# [outpaint_gallery],
|
||||
# [img2img_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# outpaint_sendto_inpaint,
|
||||
# 2,
|
||||
# [outpaint_gallery],
|
||||
# [inpaint_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# outpaint_sendto_upscaler,
|
||||
# 4,
|
||||
# [outpaint_gallery],
|
||||
# [upscaler_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# upscaler_sendto_img2img,
|
||||
# 1,
|
||||
# [upscaler_gallery],
|
||||
# [img2img_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# upscaler_sendto_inpaint,
|
||||
# 2,
|
||||
# [upscaler_gallery],
|
||||
# [inpaint_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# upscaler_sendto_outpaint,
|
||||
# 3,
|
||||
# [upscaler_gallery],
|
||||
# [outpaint_init_image, tabs],
|
||||
# )
|
||||
# if args.output_gallery:
|
||||
# register_outputgallery_button(
|
||||
# outputgallery_sendto_txt2img,
|
||||
# 0,
|
||||
# [outputgallery_filename],
|
||||
# [txt2img_png_info_img, tabs],
|
||||
# )
|
||||
# register_outputgallery_button(
|
||||
# outputgallery_sendto_img2img,
|
||||
# 1,
|
||||
# [outputgallery_filename],
|
||||
# [img2img_init_image, tabs],
|
||||
# )
|
||||
# register_outputgallery_button(
|
||||
# outputgallery_sendto_inpaint,
|
||||
# 2,
|
||||
# [outputgallery_filename],
|
||||
# [inpaint_init_image, tabs],
|
||||
# )
|
||||
# register_outputgallery_button(
|
||||
# outputgallery_sendto_outpaint,
|
||||
# 3,
|
||||
# [outputgallery_filename],
|
||||
# [outpaint_init_image, tabs],
|
||||
# )
|
||||
# register_outputgallery_button(
|
||||
# outputgallery_sendto_upscaler,
|
||||
# 4,
|
||||
# [outputgallery_filename],
|
||||
# [upscaler_init_image, tabs],
|
||||
# )
|
||||
# register_modelmanager_button(
|
||||
# modelmanager_sendto_txt2img,
|
||||
# 0,
|
||||
# [hf_models],
|
||||
# [txt2img_custom_model, tabs],
|
||||
# )
|
||||
# register_modelmanager_button(
|
||||
# modelmanager_sendto_img2img,
|
||||
# 1,
|
||||
# [hf_models],
|
||||
# [img2img_custom_model, tabs],
|
||||
# )
|
||||
# register_modelmanager_button(
|
||||
# modelmanager_sendto_inpaint,
|
||||
# 2,
|
||||
# [hf_models],
|
||||
# [inpaint_custom_model, tabs],
|
||||
# )
|
||||
# register_modelmanager_button(
|
||||
# modelmanager_sendto_outpaint,
|
||||
# 3,
|
||||
# [hf_models],
|
||||
# [outpaint_custom_model, tabs],
|
||||
# )
|
||||
# register_modelmanager_button(
|
||||
# modelmanager_sendto_upscaler,
|
||||
# 4,
|
||||
# [hf_models],
|
||||
# [upscaler_custom_model, tabs],
|
||||
# )
|
||||
|
||||
sd_web.queue()
|
||||
# if args.ui == "app":
|
||||
# t = Process(
|
||||
# target=launch_app, args=[f"http://localhost:{args.server_port}"]
|
||||
# )
|
||||
# t.start()
|
||||
sd_web.launch(
|
||||
share=True,
|
||||
inbrowser=True,
|
||||
server_name="0.0.0.0",
|
||||
server_port=11911, # args.server_port,
|
||||
)
|
||||
0
apps/shark_studio/web/ui/__init__.py
Normal file
0
apps/shark_studio/web/ui/__init__.py
Normal file
517
apps/shark_studio/web/ui/chat.py
Normal file
517
apps/shark_studio/web/ui/chat.py
Normal file
@@ -0,0 +1,517 @@
|
||||
import gradio as gr
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime as dt
|
||||
import json
|
||||
import sys
|
||||
from apps.shark_studio.api.utils import (
|
||||
get_available_devices,
|
||||
)
|
||||
from apps.shark_studio.api.llm import (
|
||||
llm_model_map,
|
||||
LanguageModel,
|
||||
)
|
||||
|
||||
|
||||
def user(message, history):
|
||||
# Append the user's message to the conversation history
|
||||
return "", history + [[message, ""]]
|
||||
|
||||
|
||||
language_model = None
|
||||
|
||||
|
||||
# NOTE: Each `model_name` should have its own start message
|
||||
start_message = {
|
||||
"llama2_7b": (
|
||||
"You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
|
||||
"illegal content. Please ensure that your responses are socially "
|
||||
"unbiased and positive in nature. If a question does not make any "
|
||||
"sense, or is not factually coherent, explain why instead of "
|
||||
"answering something not correct. If you don't know the answer "
|
||||
"to a question, please don't share false information."
|
||||
),
|
||||
"llama2_13b": (
|
||||
"You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
|
||||
"illegal content. Please ensure that your responses are socially "
|
||||
"unbiased and positive in nature. If a question does not make any "
|
||||
"sense, or is not factually coherent, explain why instead of "
|
||||
"answering something not correct. If you don't know the answer "
|
||||
"to a question, please don't share false information."
|
||||
),
|
||||
"llama2_70b": (
|
||||
"You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
|
||||
"illegal content. Please ensure that your responses are socially "
|
||||
"unbiased and positive in nature. If a question does not make any "
|
||||
"sense, or is not factually coherent, explain why instead of "
|
||||
"answering something not correct. If you don't know the answer "
|
||||
"to a question, please don't share false information."
|
||||
),
|
||||
"vicuna": (
|
||||
"A chat between a curious user and an artificial intelligence "
|
||||
"assistant. The assistant gives helpful, detailed, and "
|
||||
"polite answers to the user's questions.\n"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def create_prompt(model_name, history, prompt_prefix):
|
||||
return ""
|
||||
system_message = ""
|
||||
if prompt_prefix:
|
||||
system_message = start_message[model_name]
|
||||
|
||||
if "llama2" in model_name:
|
||||
B_INST, E_INST = "[INST]", "[/INST]"
|
||||
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
||||
conversation = "".join(
|
||||
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
|
||||
)
|
||||
if prompt_prefix:
|
||||
msg = f"{B_INST} {B_SYS}{system_message}{E_SYS}{history[0][0]} {E_INST} {history[0][1]} {conversation}"
|
||||
else:
|
||||
msg = f"{B_INST} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
|
||||
elif model_name in ["vicuna"]:
|
||||
conversation = "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
else:
|
||||
conversation = "".join(
|
||||
["".join([item[0], item[1]]) for item in history]
|
||||
)
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
return msg
|
||||
|
||||
|
||||
def get_default_config():
|
||||
return False
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
|
||||
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
|
||||
compilation_prompt = "".join(["0" for _ in range(17)])
|
||||
compilation_input_ids = tokenizer(
|
||||
compilation_prompt,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
|
||||
[1, 19]
|
||||
)
|
||||
firstVicunaCompileInput = (compilation_input_ids,)
|
||||
from apps.language_models.src.model_wrappers.vicuna_model import (
|
||||
CombinedModel,
|
||||
)
|
||||
from shark.shark_generate_model_config import GenerateConfigFile
|
||||
|
||||
model = CombinedModel()
|
||||
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
|
||||
c.split_into_layers()
|
||||
|
||||
|
||||
# model_vmfb_key = ""
|
||||
|
||||
|
||||
def chat_fn(
|
||||
prompt_prefix,
|
||||
history,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
cli=False,
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
global language_model
|
||||
if language_model is None:
|
||||
language_model = LanguageModel(
|
||||
model, device=device, precision=precision
|
||||
)
|
||||
|
||||
language_model.chat(prompt_prefix)
|
||||
return "", ""
|
||||
global past_key_values
|
||||
global model_vmfb_key
|
||||
|
||||
device_id = None
|
||||
model_name, model_path = list(map(str.strip, model.split("=>")))
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
elif "sync" in device:
|
||||
device = "cpu-sync"
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device_id = int(device.split("://")[1])
|
||||
device = "vulkan"
|
||||
elif "rocm" in device:
|
||||
device = "rocm"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
|
||||
from apps.language_models.scripts.vicuna import ShardedVicuna
|
||||
from apps.language_models.scripts.vicuna import UnshardedVicuna
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}"
|
||||
if vicuna_model is None or new_model_vmfb_key != model_vmfb_key:
|
||||
model_vmfb_key = new_model_vmfb_key
|
||||
max_toks = 128 if model_name == "codegen" else 512
|
||||
|
||||
# get iree flags that need to be overridden, from commandline args
|
||||
_extra_args = []
|
||||
# vulkan target triple
|
||||
vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
get_all_vulkan_devices,
|
||||
get_vulkan_target_triple,
|
||||
)
|
||||
|
||||
if device == "vulkan":
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
if vulkan_target_triple == "":
|
||||
# We already have the device_id extracted via WebUI, so we directly use
|
||||
# that to find the target triple.
|
||||
vulkan_target_triple = get_vulkan_target_triple(
|
||||
vulkaninfo_list[device_id]
|
||||
)
|
||||
_extra_args.append(
|
||||
f"-iree-vulkan-target-triple={vulkan_target_triple}"
|
||||
)
|
||||
if "rdna" in vulkan_target_triple:
|
||||
flags_to_add = [
|
||||
"--iree-spirv-index-bits=64",
|
||||
]
|
||||
_extra_args = _extra_args + flags_to_add
|
||||
|
||||
if device_id is None:
|
||||
id = 0
|
||||
for device in vulkaninfo_list:
|
||||
target_triple = get_vulkan_target_triple(
|
||||
vulkaninfo_list[id]
|
||||
)
|
||||
if target_triple == vulkan_target_triple:
|
||||
device_id = id
|
||||
break
|
||||
id += 1
|
||||
|
||||
assert (
|
||||
device_id
|
||||
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
|
||||
print(f"Will use vulkan target triple : {vulkan_target_triple}")
|
||||
|
||||
elif "rocm" in device:
|
||||
# add iree rocm flags
|
||||
_extra_args.append(
|
||||
f"--iree-rocm-target-chip={args.iree_rocm_target_chip}"
|
||||
)
|
||||
print(f"extra args = {_extra_args}")
|
||||
|
||||
if model_name == "vicuna4":
|
||||
vicuna_model = ShardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
compressed=True,
|
||||
extra_args_cmd=_extra_args,
|
||||
)
|
||||
else:
|
||||
# if config_file is None:
|
||||
vicuna_model = UnshardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
hf_auth_token=args.hf_auth_token,
|
||||
device=device,
|
||||
vulkan_target_triple=vulkan_target_triple,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
download_vmfb=download_vmfb,
|
||||
load_mlir_from_shark_tank=True,
|
||||
extra_args_cmd=_extra_args,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
if vicuna_model is None:
|
||||
sys.exit("Unable to instantiate the model object, exiting.")
|
||||
|
||||
prompt = create_prompt(model_name, history, prompt_prefix)
|
||||
|
||||
partial_text = ""
|
||||
token_count = 0
|
||||
total_time_ms = 0.001 # In order to avoid divide by zero error
|
||||
prefill_time = 0
|
||||
is_first = True
|
||||
for text, msg, exec_time in progress.tqdm(
|
||||
vicuna_model.generate(prompt, cli=cli),
|
||||
desc="generating response",
|
||||
):
|
||||
if msg is None:
|
||||
if is_first:
|
||||
prefill_time = exec_time
|
||||
is_first = False
|
||||
else:
|
||||
total_time_ms += exec_time
|
||||
token_count += 1
|
||||
partial_text += text + " "
|
||||
history[-1][1] = partial_text
|
||||
yield history, f"Prefill: {prefill_time:.2f}"
|
||||
elif "formatted" in msg:
|
||||
history[-1][1] = text
|
||||
tokens_per_sec = (token_count / total_time_ms) * 1000
|
||||
yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
|
||||
else:
|
||||
sys.exit(
|
||||
"unexpected message from the vicuna generate call, exiting."
|
||||
)
|
||||
|
||||
return history, ""
|
||||
|
||||
|
||||
def llm_chat_api(InputData: dict):
|
||||
return None
|
||||
print(f"Input keys : {InputData.keys()}")
|
||||
# print(f"model : {InputData['model']}")
|
||||
is_chat_completion_api = (
|
||||
"messages" in InputData.keys()
|
||||
) # else it is the legacy `completion` api
|
||||
# For Debugging input data from API
|
||||
# if is_chat_completion_api:
|
||||
# print(f"message -> role : {InputData['messages'][0]['role']}")
|
||||
# print(f"message -> content : {InputData['messages'][0]['content']}")
|
||||
# else:
|
||||
# print(f"prompt : {InputData['prompt']}")
|
||||
# print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now
|
||||
global vicuna_model
|
||||
model_name = (
|
||||
InputData["model"] if "model" in InputData.keys() else "codegen"
|
||||
)
|
||||
model_path = llm_model_map[model_name]
|
||||
device = "cpu-task"
|
||||
precision = "fp16"
|
||||
max_toks = (
|
||||
None
|
||||
if "max_tokens" not in InputData.keys()
|
||||
else InputData["max_tokens"]
|
||||
)
|
||||
if max_toks is None:
|
||||
max_toks = 128 if model_name == "codegen" else 512
|
||||
|
||||
# make it working for codegen first
|
||||
from apps.language_models.scripts.vicuna import (
|
||||
UnshardedVicuna,
|
||||
)
|
||||
|
||||
device_id = None
|
||||
if vicuna_model == 0:
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
elif "sync" in device:
|
||||
device = "cpu-sync"
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device_id = int(device.split("://")[1])
|
||||
device = "vulkan"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
|
||||
vicuna_model = UnshardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
download_vmfb=True,
|
||||
load_mlir_from_shark_tank=True,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
# TODO: add role dict for different models
|
||||
if is_chat_completion_api:
|
||||
# TODO: add funtionality for multiple messages
|
||||
prompt = create_prompt(
|
||||
model_name, [(InputData["messages"][0]["content"], "")]
|
||||
)
|
||||
else:
|
||||
prompt = InputData["prompt"]
|
||||
print("prompt = ", prompt)
|
||||
|
||||
res = vicuna_model.generate(prompt)
|
||||
res_op = None
|
||||
for op in res:
|
||||
res_op = op
|
||||
|
||||
if is_chat_completion_api:
|
||||
choices = [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": res_op, # since we are yeilding the result
|
||||
},
|
||||
"finish_reason": "stop", # or length
|
||||
}
|
||||
]
|
||||
else:
|
||||
choices = [
|
||||
{
|
||||
"text": res_op,
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop", # or length
|
||||
}
|
||||
]
|
||||
end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
|
||||
return {
|
||||
"id": end_time,
|
||||
"object": "chat.completion"
|
||||
if is_chat_completion_api
|
||||
else "text_completion",
|
||||
"created": int(end_time),
|
||||
"choices": choices,
|
||||
}
|
||||
|
||||
|
||||
def view_json_file(file_obj):
|
||||
content = ""
|
||||
with open(file_obj.name, "r") as fopen:
|
||||
content = fopen.read()
|
||||
return content
|
||||
|
||||
|
||||
with gr.Blocks(title="Chat") as chat_element:
|
||||
with gr.Row():
|
||||
model_choices = list(llm_model_map.keys())
|
||||
model = gr.Dropdown(
|
||||
label="Select Model",
|
||||
value=model_choices[0],
|
||||
choices=model_choices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
supported_devices = get_available_devices()
|
||||
enabled = True
|
||||
if len(supported_devices) == 0:
|
||||
supported_devices = ["cpu-task"]
|
||||
supported_devices = [x for x in supported_devices if "sync" not in x]
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=supported_devices[0],
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="int4",
|
||||
choices=[
|
||||
# "int4",
|
||||
# "int8",
|
||||
# "fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
tokens_time = gr.Textbox(label="Tokens generated per second")
|
||||
with gr.Column():
|
||||
download_vmfb = gr.Checkbox(
|
||||
label="Download vmfb from Shark tank if available",
|
||||
value=True,
|
||||
interactive=True,
|
||||
)
|
||||
prompt_prefix = gr.Checkbox(
|
||||
label="Add System Prompt",
|
||||
value=False,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
chatbot = gr.Chatbot(height=500)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
msg = gr.Textbox(
|
||||
label="Chat Message Box",
|
||||
placeholder="Chat Message Box",
|
||||
show_label=False,
|
||||
interactive=enabled,
|
||||
container=False,
|
||||
)
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
submit = gr.Button("Submit", interactive=enabled)
|
||||
stop = gr.Button("Stop", interactive=enabled)
|
||||
clear = gr.Button("Clear", interactive=enabled)
|
||||
|
||||
with gr.Row(visible=False):
|
||||
with gr.Group():
|
||||
config_file = gr.File(
|
||||
label="Upload sharding configuration", visible=False
|
||||
)
|
||||
json_view_button = gr.Button(label="View as JSON", visible=False)
|
||||
json_view = gr.JSON(interactive=True, visible=False)
|
||||
json_view_button.click(
|
||||
fn=view_json_file, inputs=[config_file], outputs=[json_view]
|
||||
)
|
||||
submit_event = msg.submit(
|
||||
fn=user,
|
||||
inputs=[msg, chatbot],
|
||||
outputs=[msg, chatbot],
|
||||
show_progress=False,
|
||||
queue=False,
|
||||
).then(
|
||||
fn=chat_fn,
|
||||
inputs=[
|
||||
prompt_prefix,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
show_progress=False,
|
||||
queue=True,
|
||||
)
|
||||
submit_click_event = submit.click(
|
||||
fn=user,
|
||||
inputs=[msg, chatbot],
|
||||
outputs=[msg, chatbot],
|
||||
show_progress=False,
|
||||
queue=False,
|
||||
).then(
|
||||
fn=chat_fn,
|
||||
inputs=[
|
||||
prompt_prefix,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
show_progress=False,
|
||||
queue=True,
|
||||
)
|
||||
stop.click(
|
||||
fn=None,
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
cancels=[submit_event, submit_click_event],
|
||||
queue=False,
|
||||
)
|
||||
clear.click(lambda: None, None, [chatbot], queue=False)
|
||||
@@ -7,16 +7,16 @@ Compile Commands FP32/FP16:
|
||||
|
||||
```shell
|
||||
Vulkan AMD:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux /path/to/input/mlir -o /path/to/output/vmfb
|
||||
|
||||
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
|
||||
# use –iree-input-type=auto or "mhlo_legacy" or "stablehlo" for TF models
|
||||
|
||||
CUDA NVIDIA:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda /path/to/input/mlir -o /path/to/output/vmfb
|
||||
|
||||
CPU:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu /path/to/input/mlir -o /path/to/output/vmfb
|
||||
```
|
||||
|
||||
|
||||
|
||||
@@ -105,6 +105,7 @@ def main():
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
use_stencil=use_stencil,
|
||||
control_mode=args.control_mode,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
|
||||
96
apps/stable_diffusion/scripts/txt2img_sdxl.py
Normal file
96
apps/stable_diffusion/scripts/txt2img_sdxl.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import torch
|
||||
import time
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Text2ImageSDXLPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
# TODO: prompt_embeds and text_embeds form base_model.json requires fixing
|
||||
args.precision = "fp16"
|
||||
args.height = 1024
|
||||
args.width = 1024
|
||||
args.max_length = 77
|
||||
args.scheduler = "DDIM"
|
||||
print(
|
||||
"Using default supported configuration for SDXL :-\nprecision=fp16, width*height= 1024*1024, max_length=77 and scheduler=DDIM"
|
||||
)
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
set_init_device_flags()
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
seed = args.seed
|
||||
txt2img_obj = Text2ImageSDXLPipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
model_id=args.hf_model_id,
|
||||
ckpt_loc=args.ckpt_loc,
|
||||
precision=args.precision,
|
||||
max_length=args.max_length,
|
||||
batch_size=args.batch_size,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
use_tuned=args.use_tuned,
|
||||
custom_vae=args.custom_vae,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
use_quantize=args.use_quantize,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
|
||||
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
|
||||
for current_batch in range(args.batch_count):
|
||||
start_time = time.time()
|
||||
generated_imgs = txt2img_obj.generate_images(
|
||||
args.prompts,
|
||||
args.negative_prompts,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.steps,
|
||||
args.guidance_scale,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += (
|
||||
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
)
|
||||
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
||||
text_output += (
|
||||
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
|
||||
)
|
||||
text_output += (
|
||||
f"seed={seeds[current_batch]}, size={args.height}x{args.width}"
|
||||
)
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
# TODO: if using --batch_count=x txt2img_obj.log will output on each display every iteration infos from the start
|
||||
text_output += txt2img_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
save_output_img(generated_imgs[0], seed)
|
||||
print(text_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -19,6 +19,9 @@ a = Analysis(
|
||||
win_private_assemblies=False,
|
||||
cipher=block_cipher,
|
||||
noarchive=False,
|
||||
module_collection_mode={
|
||||
'gradio': 'py', # Collect gradio package as source .py files
|
||||
},
|
||||
)
|
||||
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||
|
||||
|
||||
@@ -15,8 +15,8 @@ pathex = [
|
||||
|
||||
# datafiles for pyinstaller
|
||||
datas = []
|
||||
datas += collect_data_files("torch")
|
||||
datas += copy_metadata("torch")
|
||||
datas += copy_metadata("tokenizers")
|
||||
datas += copy_metadata("tqdm")
|
||||
datas += copy_metadata("regex")
|
||||
datas += copy_metadata("requests")
|
||||
@@ -31,20 +31,21 @@ datas += copy_metadata("Pillow")
|
||||
datas += copy_metadata("sentencepiece")
|
||||
datas += copy_metadata("pyyaml")
|
||||
datas += copy_metadata("huggingface-hub")
|
||||
datas += copy_metadata("gradio")
|
||||
datas += collect_data_files("torch")
|
||||
datas += collect_data_files("tokenizers")
|
||||
datas += collect_data_files("tiktoken")
|
||||
datas += collect_data_files("accelerate")
|
||||
datas += collect_data_files("diffusers")
|
||||
datas += collect_data_files("transformers")
|
||||
datas += collect_data_files("pytorch_lightning")
|
||||
datas += collect_data_files("opencv_python")
|
||||
datas += collect_data_files("skimage")
|
||||
datas += collect_data_files("gradio")
|
||||
datas += collect_data_files("gradio_client")
|
||||
datas += collect_data_files("iree")
|
||||
datas += collect_data_files("google_cloud_storage")
|
||||
datas += collect_data_files("shark", include_py_files=True)
|
||||
datas += collect_data_files("timm", include_py_files=True)
|
||||
datas += collect_data_files("tqdm")
|
||||
datas += collect_data_files("tkinter")
|
||||
datas += collect_data_files("webview")
|
||||
datas += collect_data_files("sentencepiece")
|
||||
@@ -52,6 +53,8 @@ datas += collect_data_files("jsonschema")
|
||||
datas += collect_data_files("jsonschema_specifications")
|
||||
datas += collect_data_files("cpuinfo")
|
||||
datas += collect_data_files("langchain")
|
||||
datas += collect_data_files("cv2")
|
||||
datas += collect_data_files("einops")
|
||||
datas += [
|
||||
("src/utils/resources/prompts.json", "resources"),
|
||||
("src/utils/resources/model_db.json", "resources"),
|
||||
@@ -73,8 +76,15 @@ datas += [
|
||||
# hidden imports for pyinstaller
|
||||
hiddenimports = ["shark", "shark.shark_inference", "apps"]
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x]
|
||||
hiddenimports += [
|
||||
x for x in collect_submodules("transformers") if "tests" not in x
|
||||
x for x in collect_submodules("diffusers") if "tests" not in x
|
||||
]
|
||||
blacklist = ["tests", "convert"]
|
||||
hiddenimports += [
|
||||
x
|
||||
for x in collect_submodules("transformers")
|
||||
if not any(kw in x for kw in blacklist)
|
||||
]
|
||||
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
|
||||
hiddenimports += ["iree._runtime", "iree._runtime_libs"]
|
||||
hiddenimports += ["iree._runtime"]
|
||||
|
||||
@@ -9,6 +9,7 @@ from apps.stable_diffusion.src.utils import (
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines import (
|
||||
Text2ImagePipeline,
|
||||
Text2ImageSDXLPipeline,
|
||||
Image2ImagePipeline,
|
||||
InpaintPipeline,
|
||||
OutpaintPipeline,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
|
||||
from transformers import CLIPTextModel
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
import torch
|
||||
@@ -8,6 +8,7 @@ import traceback
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import requests
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
compile_through_fx,
|
||||
get_opt_flags,
|
||||
@@ -16,12 +17,15 @@ from apps.stable_diffusion.src.utils import (
|
||||
preprocessCKPT,
|
||||
convert_original_vae,
|
||||
get_path_to_diffusers_checkpoint,
|
||||
get_civitai_checkpoint,
|
||||
fetch_and_update_base_model_id,
|
||||
get_path_stem,
|
||||
get_extended_name,
|
||||
get_stencil_model_id,
|
||||
update_lora_weight,
|
||||
)
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
|
||||
# These shapes are parameter dependent.
|
||||
@@ -53,6 +57,10 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
|
||||
new_shape.append(math.ceil(height / div_val))
|
||||
elif "width" in shape[i]:
|
||||
new_shape.append(math.ceil(width / div_val))
|
||||
elif "+" in shape[i]:
|
||||
# Currently this case only hits for SDXL. So, in case any other
|
||||
# case requires this operator, change this.
|
||||
new_shape.append(height + width)
|
||||
else:
|
||||
new_shape.append(shape[i])
|
||||
return new_shape
|
||||
@@ -65,6 +73,70 @@ def check_compilation(model, model_name):
|
||||
)
|
||||
|
||||
|
||||
def shark_compile_after_ir(
|
||||
module_name,
|
||||
device,
|
||||
vmfb_path,
|
||||
precision,
|
||||
ir_path=None,
|
||||
):
|
||||
if ir_path:
|
||||
print(f"[DEBUG] mlir found at {ir_path.absolute()}")
|
||||
|
||||
module = SharkInference(
|
||||
mlir_module=ir_path,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
print(f"Will get extra flag for {module_name} and precision = {precision}")
|
||||
path = module.save_module(
|
||||
vmfb_path.parent.absolute(),
|
||||
vmfb_path.stem,
|
||||
extra_args=get_opt_flags(module_name, precision=precision),
|
||||
)
|
||||
print(f"Saved {module_name} vmfb at {path}")
|
||||
module.load_module(path)
|
||||
return module
|
||||
|
||||
|
||||
def process_vmfb_ir_sdxl(extended_model_name, model_name, device, precision):
|
||||
name_split = extended_model_name.split("_")
|
||||
if "vae" in model_name:
|
||||
name_split[5] = "fp32"
|
||||
extended_model_name_for_vmfb = "_".join(name_split)
|
||||
extended_model_name_for_mlir = "_".join(name_split[:-1])
|
||||
vmfb_path = Path(extended_model_name_for_vmfb + ".vmfb")
|
||||
if "vulkan" in device:
|
||||
_device = args.iree_vulkan_target_triple
|
||||
_device = _device.replace("-", "_")
|
||||
vmfb_path = Path(extended_model_name_for_vmfb + f"_vulkan.vmfb")
|
||||
if vmfb_path.exists():
|
||||
shark_module = SharkInference(
|
||||
None,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
print(f"loading existing vmfb from: {vmfb_path}")
|
||||
shark_module.load_module(vmfb_path, extra_args=[])
|
||||
return shark_module, None
|
||||
mlir_path = Path(extended_model_name_for_mlir + ".mlir")
|
||||
if not mlir_path.exists():
|
||||
print(f"Looking into gs://shark_tank/SDXL/mlir/{mlir_path.name}")
|
||||
download_public_file(
|
||||
f"gs://shark_tank/SDXL/mlir/{mlir_path.name}",
|
||||
mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if mlir_path.exists():
|
||||
return (
|
||||
shark_compile_after_ir(
|
||||
model_name, device, vmfb_path, precision, mlir_path
|
||||
),
|
||||
None,
|
||||
)
|
||||
return None, None
|
||||
|
||||
|
||||
class SharkifyStableDiffusionModel:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -84,31 +156,31 @@ class SharkifyStableDiffusionModel:
|
||||
generate_vmfb: bool = True,
|
||||
is_inpaint: bool = False,
|
||||
is_upscaler: bool = False,
|
||||
use_stencil: str = None,
|
||||
is_sdxl: bool = False,
|
||||
stencils: list[str] = [],
|
||||
use_lora: str = "",
|
||||
use_quantize: str = None,
|
||||
return_mlir: bool = False,
|
||||
):
|
||||
self.check_params(max_len, width, height)
|
||||
self.max_len = max_len
|
||||
self.is_sdxl = is_sdxl
|
||||
self.height = height // 8
|
||||
self.width = width // 8
|
||||
self.batch_size = batch_size
|
||||
self.custom_weights = custom_weights
|
||||
self.custom_weights = custom_weights.strip()
|
||||
self.use_quantize = use_quantize
|
||||
if custom_weights != "":
|
||||
if "civitai" in custom_weights:
|
||||
weights_id = custom_weights.split("/")[-1]
|
||||
# TODO: use model name and identify file type by civitai rest api
|
||||
weights_path = (
|
||||
str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
|
||||
)
|
||||
if not os.path.isfile(weights_path):
|
||||
subprocess.run(
|
||||
["wget", custom_weights, "-O", weights_path]
|
||||
)
|
||||
if custom_weights.startswith("https://civitai.com/api/"):
|
||||
# download the checkpoint from civitai if we don't already have it
|
||||
weights_path = get_civitai_checkpoint(custom_weights)
|
||||
|
||||
# act as if we were given the local file as custom_weights originally
|
||||
custom_weights = get_path_to_diffusers_checkpoint(weights_path)
|
||||
self.custom_weights = weights_path
|
||||
|
||||
# needed to ensure webui sets the correct model name metadata
|
||||
args.ckpt_loc = weights_path
|
||||
else:
|
||||
assert custom_weights.lower().endswith(
|
||||
(".ckpt", ".safetensors")
|
||||
@@ -116,6 +188,7 @@ class SharkifyStableDiffusionModel:
|
||||
custom_weights = get_path_to_diffusers_checkpoint(
|
||||
custom_weights
|
||||
)
|
||||
|
||||
self.model_id = model_id if custom_weights == "" else custom_weights
|
||||
# TODO: remove the following line when stable-diffusion-2-1 works
|
||||
if self.model_id == "stabilityai/stable-diffusion-2-1":
|
||||
@@ -143,7 +216,7 @@ class SharkifyStableDiffusionModel:
|
||||
self.low_cpu_mem_usage = low_cpu_mem_usage
|
||||
self.is_inpaint = is_inpaint
|
||||
self.is_upscaler = is_upscaler
|
||||
self.use_stencil = get_stencil_model_id(use_stencil)
|
||||
self.stencils = [get_stencil_model_id(x) for x in stencils]
|
||||
if use_lora != "":
|
||||
self.model_name = self.model_name + "_" + get_path_stem(use_lora)
|
||||
self.use_lora = use_lora
|
||||
@@ -174,13 +247,15 @@ class SharkifyStableDiffusionModel:
|
||||
model_name = {}
|
||||
sub_model_list = [
|
||||
"clip",
|
||||
"clip2",
|
||||
"unet",
|
||||
"unet512",
|
||||
"stencil_unet",
|
||||
"stencil_unet_512",
|
||||
"vae",
|
||||
"vae_encode",
|
||||
"stencil_adaptor",
|
||||
"stencil_adaptor_512",
|
||||
"stencil_adapter",
|
||||
"stencil_adapter_512",
|
||||
]
|
||||
index = 0
|
||||
for model in sub_model_list:
|
||||
@@ -193,10 +268,19 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
if self.base_vae:
|
||||
sub_model = "base_vae"
|
||||
if "stencil_adaptor" == model and self.use_stencil is not None:
|
||||
model_config = model_config + get_path_stem(self.use_stencil)
|
||||
model_name[model] = get_extended_name(sub_model + model_config)
|
||||
index += 1
|
||||
if "stencil_adapter" in model:
|
||||
stencil_names = []
|
||||
for i, stencil in enumerate(self.stencils):
|
||||
if stencil is not None:
|
||||
cnet_config = model_config + stencil.split("_")[-1]
|
||||
stencil_names.append(
|
||||
get_extended_name(sub_model + cnet_config)
|
||||
)
|
||||
|
||||
model_name[model] = stencil_names
|
||||
else:
|
||||
model_name[model] = get_extended_name(sub_model + model_config)
|
||||
index += 1
|
||||
return model_name
|
||||
|
||||
def check_params(self, max_len, width, height):
|
||||
@@ -340,7 +424,106 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
return shark_vae, vae_mlir
|
||||
|
||||
def get_controlled_unet(self):
|
||||
def get_vae_sdxl(self):
|
||||
# TODO: Remove this after convergence with shark_tank. This should just be part of
|
||||
# opt_params.py.
|
||||
shark_module_or_none = process_vmfb_ir_sdxl(
|
||||
self.model_name["vae"], "vae", args.device, self.precision
|
||||
)
|
||||
if shark_module_or_none[0]:
|
||||
return shark_module_or_none
|
||||
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_id=self.model_id,
|
||||
base_vae=self.base_vae,
|
||||
custom_vae=self.custom_vae,
|
||||
low_cpu_mem_usage=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.vae = None
|
||||
if custom_vae == "":
|
||||
print(f"Loading default vae, with target {model_id}")
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
elif not isinstance(custom_vae, dict):
|
||||
precision = "fp16" if "fp16" in custom_vae else None
|
||||
print(f"Loading custom vae, with target {custom_vae}")
|
||||
if os.path.exists(custom_vae):
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
custom_vae,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
else:
|
||||
custom_vae = "/".join(
|
||||
[
|
||||
custom_vae.split("/")[-2].split("\\")[-1],
|
||||
custom_vae.split("/")[-1],
|
||||
]
|
||||
)
|
||||
print("Using hub to get custom vae")
|
||||
try:
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
custom_vae,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
variant=precision,
|
||||
)
|
||||
except:
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
custom_vae,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
else:
|
||||
print(f"Loading custom vae, with state {custom_vae}")
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.vae.load_state_dict(custom_vae)
|
||||
self.base_vae = base_vae
|
||||
|
||||
def forward(self, latents):
|
||||
image = self.vae.decode(latents / 0.13025, return_dict=False)[
|
||||
0
|
||||
]
|
||||
return image
|
||||
|
||||
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
inputs = tuple(self.inputs["vae"])
|
||||
# Make sure the VAE is in float32 mode, as it overflows in float16 as per SDXL
|
||||
# pipeline.
|
||||
if not self.custom_vae:
|
||||
is_f16 = False
|
||||
elif "16" in self.custom_vae:
|
||||
is_f16 = True
|
||||
else:
|
||||
is_f16 = False
|
||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
|
||||
if self.debug:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
shark_vae, vae_mlir = compile_through_fx(
|
||||
vae,
|
||||
inputs,
|
||||
is_f16=is_f16,
|
||||
use_tuned=self.use_tuned,
|
||||
extended_model_name=self.model_name["vae"],
|
||||
debug=self.debug,
|
||||
generate_vmfb=self.generate_vmfb,
|
||||
save_dir=save_dir,
|
||||
extra_args=get_opt_flags("vae", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
model_name="vae",
|
||||
precision=self.precision,
|
||||
return_mlir=self.return_mlir,
|
||||
)
|
||||
return shark_vae, vae_mlir
|
||||
|
||||
def get_controlled_unet(self, use_large=False):
|
||||
class ControlledUnetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -378,25 +561,54 @@ class SharkifyStableDiffusionModel:
|
||||
control11,
|
||||
control12,
|
||||
control13,
|
||||
scale1,
|
||||
scale2,
|
||||
scale3,
|
||||
scale4,
|
||||
scale5,
|
||||
scale6,
|
||||
scale7,
|
||||
scale8,
|
||||
scale9,
|
||||
scale10,
|
||||
scale11,
|
||||
scale12,
|
||||
scale13,
|
||||
):
|
||||
# TODO: Average pooling
|
||||
db_res_samples = [
|
||||
control1,
|
||||
control2,
|
||||
control3,
|
||||
control4,
|
||||
control5,
|
||||
control6,
|
||||
control7,
|
||||
control8,
|
||||
control9,
|
||||
control10,
|
||||
control11,
|
||||
control12,
|
||||
]
|
||||
|
||||
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
||||
db_res_samples = tuple(
|
||||
[
|
||||
control1,
|
||||
control2,
|
||||
control3,
|
||||
control4,
|
||||
control5,
|
||||
control6,
|
||||
control7,
|
||||
control8,
|
||||
control9,
|
||||
control10,
|
||||
control11,
|
||||
control12,
|
||||
control1 * scale1,
|
||||
control2 * scale2,
|
||||
control3 * scale3,
|
||||
control4 * scale4,
|
||||
control5 * scale5,
|
||||
control6 * scale6,
|
||||
control7 * scale7,
|
||||
control8 * scale8,
|
||||
control9 * scale9,
|
||||
control10 * scale10,
|
||||
control11 * scale11,
|
||||
control12 * scale12,
|
||||
]
|
||||
)
|
||||
mb_res_samples = control13
|
||||
mb_res_samples = control13 * scale13
|
||||
latents = torch.cat([latent] * 2)
|
||||
unet_out = self.unet.forward(
|
||||
latents,
|
||||
@@ -416,6 +628,16 @@ class SharkifyStableDiffusionModel:
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
|
||||
inputs = tuple(self.inputs["unet"])
|
||||
model_name = "stencil_unet"
|
||||
if use_large:
|
||||
pad = (0, 0) * (len(inputs[2].shape) - 2)
|
||||
pad = pad + (0, 512 - inputs[2].shape[1])
|
||||
inputs = (
|
||||
inputs[:2]
|
||||
+ (torch.nn.functional.pad(inputs[2], pad),)
|
||||
+ inputs[3:]
|
||||
)
|
||||
model_name = "stencil_unet_512"
|
||||
input_mask = [
|
||||
True,
|
||||
True,
|
||||
@@ -434,33 +656,48 @@ class SharkifyStableDiffusionModel:
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
]
|
||||
shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
|
||||
unet,
|
||||
inputs,
|
||||
extended_model_name=self.model_name["stencil_unet"],
|
||||
extended_model_name=self.model_name[model_name],
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
model_name="stencil_unet",
|
||||
model_name=model_name,
|
||||
precision=self.precision,
|
||||
return_mlir=self.return_mlir,
|
||||
)
|
||||
return shark_controlled_unet, controlled_unet_mlir
|
||||
|
||||
def get_control_net(self, use_large=False):
|
||||
def get_control_net(self, stencil_id, use_large=False):
|
||||
stencil_id = get_stencil_model_id(stencil_id)
|
||||
adapter_id, base_model_safe_id, ext_model_name = (None, None, None)
|
||||
print(f"Importing ControlNet adapter from {stencil_id}")
|
||||
|
||||
class StencilControlNetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self, model_id=self.use_stencil, low_cpu_mem_usage=False
|
||||
):
|
||||
def __init__(self, model_id=stencil_id, low_cpu_mem_usage=False):
|
||||
super().__init__()
|
||||
self.cnet = ControlNetModel.from_pretrained(
|
||||
model_id,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.in_channels = self.cnet.in_channels
|
||||
self.in_channels = self.cnet.config.in_channels
|
||||
self.train(False)
|
||||
|
||||
def forward(
|
||||
@@ -469,6 +706,19 @@ class SharkifyStableDiffusionModel:
|
||||
timestep,
|
||||
text_embedding,
|
||||
stencil_image_input,
|
||||
acc1,
|
||||
acc2,
|
||||
acc3,
|
||||
acc4,
|
||||
acc5,
|
||||
acc6,
|
||||
acc7,
|
||||
acc8,
|
||||
acc9,
|
||||
acc10,
|
||||
acc11,
|
||||
acc12,
|
||||
acc13,
|
||||
):
|
||||
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
||||
# TODO: guidance NOT NEEDED change in `get_input_info` later
|
||||
@@ -490,6 +740,20 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
return tuple(
|
||||
list(down_block_res_samples) + [mid_block_res_sample]
|
||||
) + (
|
||||
acc1 + down_block_res_samples[0],
|
||||
acc2 + down_block_res_samples[1],
|
||||
acc3 + down_block_res_samples[2],
|
||||
acc4 + down_block_res_samples[3],
|
||||
acc5 + down_block_res_samples[4],
|
||||
acc6 + down_block_res_samples[5],
|
||||
acc7 + down_block_res_samples[6],
|
||||
acc8 + down_block_res_samples[7],
|
||||
acc9 + down_block_res_samples[8],
|
||||
acc10 + down_block_res_samples[9],
|
||||
acc11 + down_block_res_samples[10],
|
||||
acc12 + down_block_res_samples[11],
|
||||
acc13 + mid_block_res_sample,
|
||||
)
|
||||
|
||||
scnet = StencilControlNetModel(
|
||||
@@ -497,7 +761,23 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
|
||||
inputs = tuple(self.inputs["stencil_adaptor"])
|
||||
inputs = tuple(self.inputs["stencil_adapter"])
|
||||
model_name = "stencil_adapter_512" if use_large else "stencil_adapter"
|
||||
ext_model_name = self.model_name[model_name]
|
||||
if isinstance(ext_model_name, list):
|
||||
for i in ext_model_name:
|
||||
if stencil_id.split("_")[-1] in i:
|
||||
desired_name = i
|
||||
print(f"Multi-CN: compiling model {i}")
|
||||
else:
|
||||
continue
|
||||
if desired_name:
|
||||
ext_model_name = desired_name
|
||||
else:
|
||||
raise Exception(
|
||||
f"Could not find extended configuration for {stencil_id}"
|
||||
)
|
||||
|
||||
if use_large:
|
||||
pad = (0, 0) * (len(inputs[2].shape) - 2)
|
||||
pad = pad + (0, 512 - inputs[2].shape[1])
|
||||
@@ -505,21 +785,15 @@ class SharkifyStableDiffusionModel:
|
||||
inputs[0],
|
||||
inputs[1],
|
||||
torch.nn.functional.pad(inputs[2], pad),
|
||||
inputs[3],
|
||||
*inputs[3:],
|
||||
)
|
||||
save_dir = os.path.join(
|
||||
self.sharktank_dir, self.model_name["stencil_adaptor_512"]
|
||||
)
|
||||
else:
|
||||
save_dir = os.path.join(
|
||||
self.sharktank_dir, self.model_name["stencil_adaptor"]
|
||||
)
|
||||
input_mask = [True, True, True, True]
|
||||
model_name = "stencil_adaptor" if use_large else "stencil_adaptor_512"
|
||||
save_dir = os.path.join(self.sharktank_dir, ext_model_name)
|
||||
input_mask = [True, True, True, True] + ([True] * 13)
|
||||
|
||||
shark_cnet, cnet_mlir = compile_through_fx(
|
||||
scnet,
|
||||
inputs,
|
||||
extended_model_name=self.model_name[model_name],
|
||||
extended_model_name=ext_model_name,
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
@@ -676,6 +950,101 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
return shark_unet, unet_mlir
|
||||
|
||||
def get_unet_sdxl(self):
|
||||
# TODO: Remove this after convergence with shark_tank. This should just be part of
|
||||
# opt_params.py.
|
||||
shark_module_or_none = process_vmfb_ir_sdxl(
|
||||
self.model_name["unet"], "unet", args.device, self.precision
|
||||
)
|
||||
if shark_module_or_none[0]:
|
||||
return shark_module_or_none
|
||||
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_id=self.model_id,
|
||||
low_cpu_mem_usage=False,
|
||||
):
|
||||
super().__init__()
|
||||
try:
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="unet",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
variant="fp16",
|
||||
)
|
||||
except:
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="unet",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if (
|
||||
args.attention_slicing is not None
|
||||
and args.attention_slicing != "none"
|
||||
):
|
||||
if args.attention_slicing.isdigit():
|
||||
self.unet.set_attention_slice(
|
||||
int(args.attention_slicing)
|
||||
)
|
||||
else:
|
||||
self.unet.set_attention_slice(args.attention_slicing)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latent,
|
||||
timestep,
|
||||
prompt_embeds,
|
||||
text_embeds,
|
||||
time_ids,
|
||||
guidance_scale,
|
||||
):
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": text_embeds,
|
||||
"time_ids": time_ids,
|
||||
}
|
||||
noise_pred = self.unet.forward(
|
||||
latent,
|
||||
timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=None,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
return noise_pred
|
||||
|
||||
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
inputs = tuple(self.inputs["unet"])
|
||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
|
||||
input_mask = [True, True, True, True, True, True]
|
||||
if self.debug:
|
||||
os.makedirs(
|
||||
save_dir,
|
||||
exist_ok=True,
|
||||
)
|
||||
shark_unet, unet_mlir = compile_through_fx(
|
||||
unet,
|
||||
inputs,
|
||||
extended_model_name=self.model_name["unet"],
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
debug=self.debug,
|
||||
generate_vmfb=self.generate_vmfb,
|
||||
save_dir=save_dir,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
model_name="unet",
|
||||
precision=self.precision,
|
||||
return_mlir=self.return_mlir,
|
||||
)
|
||||
return shark_unet, unet_mlir
|
||||
|
||||
def get_clip(self):
|
||||
class CLIPText(torch.nn.Module):
|
||||
def __init__(
|
||||
@@ -699,8 +1068,11 @@ class SharkifyStableDiffusionModel:
|
||||
return self.text_encoder(input)[0]
|
||||
|
||||
clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["clip"])
|
||||
save_dir = ""
|
||||
if self.debug:
|
||||
save_dir = os.path.join(
|
||||
self.sharktank_dir, self.model_name["clip"]
|
||||
)
|
||||
os.makedirs(
|
||||
save_dir,
|
||||
exist_ok=True,
|
||||
@@ -720,6 +1092,78 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
return shark_clip, clip_mlir
|
||||
|
||||
def get_clip_sdxl(self, clip_index=1):
|
||||
if clip_index == 1:
|
||||
extended_model_name = self.model_name["clip"]
|
||||
model_name = "clip"
|
||||
else:
|
||||
extended_model_name = self.model_name["clip2"]
|
||||
model_name = "clip2"
|
||||
# TODO: Remove this after convergence with shark_tank. This should just be part of
|
||||
# opt_params.py.
|
||||
shark_module_or_none = process_vmfb_ir_sdxl(
|
||||
extended_model_name, f"clip", args.device, self.precision
|
||||
)
|
||||
if shark_module_or_none[0]:
|
||||
return shark_module_or_none
|
||||
|
||||
class CLIPText(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_id=self.model_id,
|
||||
low_cpu_mem_usage=False,
|
||||
clip_index=1,
|
||||
):
|
||||
super().__init__()
|
||||
if clip_index == 1:
|
||||
self.text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="text_encoder",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
else:
|
||||
self.text_encoder = (
|
||||
CLIPTextModelWithProjection.from_pretrained(
|
||||
model_id,
|
||||
subfolder="text_encoder_2",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
prompt_embeds = self.text_encoder(
|
||||
input,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
clip_model = CLIPText(
|
||||
low_cpu_mem_usage=self.low_cpu_mem_usage, clip_index=clip_index
|
||||
)
|
||||
save_dir = os.path.join(self.sharktank_dir, extended_model_name)
|
||||
if self.debug:
|
||||
os.makedirs(
|
||||
save_dir,
|
||||
exist_ok=True,
|
||||
)
|
||||
shark_clip, clip_mlir = compile_through_fx(
|
||||
clip_model,
|
||||
tuple(self.inputs["clip"]),
|
||||
extended_model_name=extended_model_name,
|
||||
debug=self.debug,
|
||||
generate_vmfb=self.generate_vmfb,
|
||||
save_dir=save_dir,
|
||||
extra_args=get_opt_flags("clip", precision="fp32"),
|
||||
base_model_id=self.base_model_id,
|
||||
model_name="clip",
|
||||
precision=self.precision,
|
||||
return_mlir=self.return_mlir,
|
||||
)
|
||||
return shark_clip, clip_mlir
|
||||
|
||||
def process_custom_vae(self):
|
||||
custom_vae = self.custom_vae.lower()
|
||||
if not custom_vae.endswith((".ckpt", ".safetensors")):
|
||||
@@ -752,7 +1196,9 @@ class SharkifyStableDiffusionModel:
|
||||
}
|
||||
return vae_dict
|
||||
|
||||
def compile_unet_variants(self, model, use_large=False):
|
||||
def compile_unet_variants(self, model, use_large=False, base_model=""):
|
||||
if self.is_sdxl:
|
||||
return self.get_unet_sdxl()
|
||||
if model == "unet":
|
||||
if self.is_upscaler:
|
||||
return self.get_unet_upscaler(use_large=use_large)
|
||||
@@ -766,7 +1212,7 @@ class SharkifyStableDiffusionModel:
|
||||
else:
|
||||
return self.get_unet(use_large=use_large)
|
||||
else:
|
||||
return self.get_controlled_unet()
|
||||
return self.get_controlled_unet(use_large=use_large)
|
||||
|
||||
def vae_encode(self):
|
||||
try:
|
||||
@@ -794,9 +1240,28 @@ class SharkifyStableDiffusionModel:
|
||||
except Exception as e:
|
||||
sys.exit(e)
|
||||
|
||||
def sdxl_clip(self):
|
||||
try:
|
||||
self.inputs["clip"] = self.get_input_info_for(
|
||||
base_models["sdxl_clip"]
|
||||
)
|
||||
compiled_clip, clip_mlir = self.get_clip_sdxl(clip_index=1)
|
||||
compiled_clip2, clip_mlir2 = self.get_clip_sdxl(clip_index=2)
|
||||
|
||||
check_compilation(compiled_clip, "Clip")
|
||||
check_compilation(compiled_clip, "Clip2")
|
||||
if self.return_mlir:
|
||||
return clip_mlir, clip_mlir2
|
||||
return compiled_clip, compiled_clip2
|
||||
except Exception as e:
|
||||
sys.exit(e)
|
||||
|
||||
def unet(self, use_large=False):
|
||||
try:
|
||||
model = "stencil_unet" if self.use_stencil is not None else "unet"
|
||||
stencil_count = 0
|
||||
for stencil in self.stencils:
|
||||
stencil_count += 1
|
||||
model = "stencil_unet" if stencil_count > 0 else "unet"
|
||||
compiled_unet = None
|
||||
unet_inputs = base_models[model]
|
||||
|
||||
@@ -805,7 +1270,7 @@ class SharkifyStableDiffusionModel:
|
||||
unet_inputs[self.base_model_id]
|
||||
)
|
||||
compiled_unet, unet_mlir = self.compile_unet_variants(
|
||||
model, use_large=use_large
|
||||
model, use_large=use_large, base_model=self.base_model_id
|
||||
)
|
||||
else:
|
||||
for model_id in unet_inputs:
|
||||
@@ -816,7 +1281,7 @@ class SharkifyStableDiffusionModel:
|
||||
|
||||
try:
|
||||
compiled_unet, unet_mlir = self.compile_unet_variants(
|
||||
model, use_large=use_large
|
||||
model, use_large=use_large, base_model=model_id
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
@@ -855,7 +1320,10 @@ class SharkifyStableDiffusionModel:
|
||||
is_base_vae = self.base_vae
|
||||
if self.is_upscaler:
|
||||
self.base_vae = True
|
||||
compiled_vae, vae_mlir = self.get_vae()
|
||||
if self.is_sdxl:
|
||||
compiled_vae, vae_mlir = self.get_vae_sdxl()
|
||||
else:
|
||||
compiled_vae, vae_mlir = self.get_vae()
|
||||
self.base_vae = is_base_vae
|
||||
|
||||
check_compilation(compiled_vae, "Vae")
|
||||
@@ -865,18 +1333,18 @@ class SharkifyStableDiffusionModel:
|
||||
except Exception as e:
|
||||
sys.exit(e)
|
||||
|
||||
def controlnet(self, use_large=False):
|
||||
def controlnet(self, stencil_id, use_large=False):
|
||||
try:
|
||||
self.inputs["stencil_adaptor"] = self.get_input_info_for(
|
||||
base_models["stencil_adaptor"]
|
||||
self.inputs["stencil_adapter"] = self.get_input_info_for(
|
||||
base_models["stencil_adapter"]
|
||||
)
|
||||
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net(
|
||||
use_large=use_large
|
||||
compiled_stencil_adapter, controlnet_mlir = self.get_control_net(
|
||||
stencil_id, use_large=use_large
|
||||
)
|
||||
|
||||
check_compilation(compiled_stencil_adaptor, "Stencil")
|
||||
check_compilation(compiled_stencil_adapter, "Stencil")
|
||||
if self.return_mlir:
|
||||
return controlnet_mlir
|
||||
return compiled_stencil_adaptor
|
||||
return compiled_stencil_adapter
|
||||
except Exception as e:
|
||||
sys.exit(e)
|
||||
|
||||
@@ -123,8 +123,11 @@ def get_clip():
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
def get_tokenizer():
|
||||
def get_tokenizer(subfolder="tokenizer", hf_model_id=None):
|
||||
if hf_model_id is not None:
|
||||
args.hf_model_id = hf_model_id
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
args.hf_model_id, subfolder="tokenizer"
|
||||
args.hf_model_id, subfolder=subfolder
|
||||
)
|
||||
return tokenizer
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
|
||||
Text2ImagePipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img_sdxl import (
|
||||
Text2ImageSDXLPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
|
||||
Image2ImagePipeline,
|
||||
)
|
||||
|
||||
@@ -29,6 +29,10 @@ from apps.stable_diffusion.src.models import (
|
||||
SharkifyStableDiffusionModel,
|
||||
get_vae_encode,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
resamplers,
|
||||
resampler_list,
|
||||
)
|
||||
|
||||
|
||||
class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
@@ -84,13 +88,21 @@ class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
num_inference_steps,
|
||||
strength,
|
||||
dtype,
|
||||
resample_type,
|
||||
):
|
||||
# Pre process image -> get image encoded -> process latents
|
||||
|
||||
# TODO: process with variable HxW combos
|
||||
|
||||
# Pre process image
|
||||
image = image.resize((width, height))
|
||||
# Pre-process image
|
||||
resample_type = (
|
||||
resamplers[resample_type]
|
||||
if resample_type in resampler_list
|
||||
# Fallback to Lanczos
|
||||
else Image.Resampling.LANCZOS
|
||||
)
|
||||
|
||||
image = image.resize((width, height), resample=resample_type)
|
||||
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
|
||||
image_arr = image_arr / 255.0
|
||||
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
|
||||
@@ -146,7 +158,10 @@ class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
use_stencil,
|
||||
stencils,
|
||||
images,
|
||||
resample_type,
|
||||
control_mode,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
@@ -186,6 +201,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
num_inference_steps=num_inference_steps,
|
||||
strength=strength,
|
||||
dtype=dtype,
|
||||
resample_type=resample_type,
|
||||
)
|
||||
|
||||
# Get Image latents
|
||||
|
||||
@@ -55,28 +55,47 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
ondemand: bool,
|
||||
controlnet_names: list[str],
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
self.controlnet = None
|
||||
self.controlnet_512 = None
|
||||
self.controlnet = [None] * len(controlnet_names)
|
||||
self.controlnet_512 = [None] * len(controlnet_names)
|
||||
self.controlnet_id = [str] * len(controlnet_names)
|
||||
self.controlnet_512_id = [str] * len(controlnet_names)
|
||||
self.controlnet_names = controlnet_names
|
||||
|
||||
def load_controlnet(self):
|
||||
if self.controlnet is not None:
|
||||
def load_controlnet(self, index, model_name):
|
||||
if model_name is None:
|
||||
return
|
||||
self.controlnet = self.sd_model.controlnet()
|
||||
|
||||
def unload_controlnet(self):
|
||||
del self.controlnet
|
||||
self.controlnet = None
|
||||
|
||||
def load_controlnet_512(self):
|
||||
if self.controlnet_512 is not None:
|
||||
if (
|
||||
self.controlnet[index] is not None
|
||||
and self.controlnet_id[index] is not None
|
||||
and self.controlnet_id[index] == model_name
|
||||
):
|
||||
return
|
||||
self.controlnet_512 = self.sd_model.controlnet(use_large=True)
|
||||
self.controlnet_id[index] = model_name
|
||||
self.controlnet[index] = self.sd_model.controlnet(model_name)
|
||||
|
||||
def unload_controlnet_512(self):
|
||||
del self.controlnet_512
|
||||
self.controlnet_512 = None
|
||||
def unload_controlnet(self, index):
|
||||
del self.controlnet[index]
|
||||
self.controlnet_id[index] = None
|
||||
self.controlnet[index] = None
|
||||
|
||||
def load_controlnet_512(self, index, model_name):
|
||||
if (
|
||||
self.controlnet_512[index] is not None
|
||||
and self.controlnet_512_id[index] == model_name
|
||||
):
|
||||
return
|
||||
self.controlnet_512_id[index] = model_name
|
||||
self.controlnet_512[index] = self.sd_model.controlnet(
|
||||
model_name, use_large=True
|
||||
)
|
||||
|
||||
def unload_controlnet_512(self, index):
|
||||
del self.controlnet_512[index]
|
||||
self.controlnet_512_id[index] = None
|
||||
self.controlnet_512[index] = None
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
@@ -111,8 +130,9 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
total_timesteps,
|
||||
dtype,
|
||||
cpu_scheduling,
|
||||
controlnet_hint=None,
|
||||
stencil_hints=[None],
|
||||
controlnet_conditioning_scale: float = 1.0,
|
||||
control_mode="Balanced", # Prompt, Balanced, or Controlnet
|
||||
mask=None,
|
||||
masked_image_latents=None,
|
||||
return_all_latents=False,
|
||||
@@ -121,12 +141,18 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
latent_history = [latents]
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
assert control_mode in ["Prompt", "Balanced", "Controlnet"]
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
self.load_unet()
|
||||
self.load_controlnet()
|
||||
else:
|
||||
self.load_unet_512()
|
||||
self.load_controlnet_512()
|
||||
|
||||
for i, name in enumerate(self.controlnet_names):
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
self.load_controlnet(i, name)
|
||||
else:
|
||||
self.load_controlnet_512(i, name)
|
||||
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
timestep = torch.tensor([t]).to(dtype)
|
||||
@@ -149,55 +175,168 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
).to(dtype)
|
||||
else:
|
||||
latent_model_input_1 = latent_model_input
|
||||
if text_embeddings.shapes[1] <= self.model_max_length:
|
||||
control = self.controlnet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
else:
|
||||
control = self.controlnet_512(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
|
||||
# Multicontrolnet
|
||||
width = latent_model_input_1.shape[2]
|
||||
height = latent_model_input_1.shape[3]
|
||||
dtype = latent_model_input_1.dtype
|
||||
control_acc = (
|
||||
[torch.zeros((2, 320, height, width), dtype=dtype)] * 3
|
||||
+ [
|
||||
torch.zeros(
|
||||
(2, 320, int(height / 2), int(width / 2)), dtype=dtype
|
||||
)
|
||||
]
|
||||
+ [
|
||||
torch.zeros(
|
||||
(2, 640, int(height / 2), int(width / 2)), dtype=dtype
|
||||
)
|
||||
]
|
||||
* 2
|
||||
+ [
|
||||
torch.zeros(
|
||||
(2, 640, int(height / 4), int(width / 4)), dtype=dtype
|
||||
)
|
||||
]
|
||||
+ [
|
||||
torch.zeros(
|
||||
(2, 1280, int(height / 4), int(width / 4)), dtype=dtype
|
||||
)
|
||||
]
|
||||
* 2
|
||||
+ [
|
||||
torch.zeros(
|
||||
(2, 1280, int(height / 8), int(width / 8)), dtype=dtype
|
||||
)
|
||||
]
|
||||
* 4
|
||||
)
|
||||
for i, controlnet_hint in enumerate(stencil_hints):
|
||||
if controlnet_hint is None:
|
||||
continue
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
control = self.controlnet[i](
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
*control_acc,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
else:
|
||||
control = self.controlnet_512[i](
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
*control_acc,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
control_acc = control[13:]
|
||||
control = control[:13]
|
||||
|
||||
timestep = timestep.detach().numpy()
|
||||
# Profiling Unet.
|
||||
profile_device = start_profiling(file_path="unet.rdc")
|
||||
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
control[0],
|
||||
control[1],
|
||||
control[2],
|
||||
control[3],
|
||||
control[4],
|
||||
control[5],
|
||||
control[6],
|
||||
control[7],
|
||||
control[8],
|
||||
control[9],
|
||||
control[10],
|
||||
control[11],
|
||||
control[12],
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
|
||||
dtype = latents.dtype
|
||||
if control_mode == "Balanced":
|
||||
control_scale = [
|
||||
torch.tensor(1.0, dtype=dtype) for _ in range(len(control))
|
||||
]
|
||||
elif control_mode == "Prompt":
|
||||
control_scale = [
|
||||
torch.tensor(0.825**x, dtype=dtype)
|
||||
for x in range(len(control))
|
||||
]
|
||||
elif control_mode == "Controlnet":
|
||||
control_scale = [
|
||||
torch.tensor(float(guidance_scale), dtype=dtype)
|
||||
for _ in range(len(control))
|
||||
]
|
||||
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
control[0],
|
||||
control[1],
|
||||
control[2],
|
||||
control[3],
|
||||
control[4],
|
||||
control[5],
|
||||
control[6],
|
||||
control[7],
|
||||
control[8],
|
||||
control[9],
|
||||
control[10],
|
||||
control[11],
|
||||
control[12],
|
||||
control_scale[0],
|
||||
control_scale[1],
|
||||
control_scale[2],
|
||||
control_scale[3],
|
||||
control_scale[4],
|
||||
control_scale[5],
|
||||
control_scale[6],
|
||||
control_scale[7],
|
||||
control_scale[8],
|
||||
control_scale[9],
|
||||
control_scale[10],
|
||||
control_scale[11],
|
||||
control_scale[12],
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
else:
|
||||
print(self.unet_512)
|
||||
noise_pred = self.unet_512(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
control[0],
|
||||
control[1],
|
||||
control[2],
|
||||
control[3],
|
||||
control[4],
|
||||
control[5],
|
||||
control[6],
|
||||
control[7],
|
||||
control[8],
|
||||
control[9],
|
||||
control[10],
|
||||
control[11],
|
||||
control[12],
|
||||
control_scale[0],
|
||||
control_scale[1],
|
||||
control_scale[2],
|
||||
control_scale[3],
|
||||
control_scale[4],
|
||||
control_scale[5],
|
||||
control_scale[6],
|
||||
control_scale[7],
|
||||
control_scale[8],
|
||||
control_scale[9],
|
||||
control_scale[10],
|
||||
control_scale[11],
|
||||
control_scale[12],
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
end_profiling(profile_device)
|
||||
|
||||
if cpu_scheduling:
|
||||
@@ -218,8 +357,9 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
self.unload_unet_512()
|
||||
self.unload_controlnet()
|
||||
self.unload_controlnet_512()
|
||||
for i in range(len(self.controlnet_names)):
|
||||
self.unload_controlnet(i)
|
||||
self.unload_controlnet_512(i)
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
@@ -245,13 +385,29 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
use_stencil,
|
||||
stencils,
|
||||
stencil_images,
|
||||
resample_type,
|
||||
control_mode,
|
||||
):
|
||||
# Control Embedding check & conversion
|
||||
# TODO: 1. Change `num_images_per_prompt`.
|
||||
controlnet_hint = controlnet_hint_conversion(
|
||||
image, use_stencil, height, width, dtype, num_images_per_prompt=1
|
||||
)
|
||||
# controlnet_hint = controlnet_hint_conversion(
|
||||
# image, use_stencil, height, width, dtype, num_images_per_prompt=1
|
||||
# )
|
||||
stencil_hints = []
|
||||
for i, stencil in enumerate(stencils):
|
||||
image = stencil_images[i]
|
||||
stencil_hints.append(
|
||||
controlnet_hint_conversion(
|
||||
image,
|
||||
stencil,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
num_images_per_prompt=1,
|
||||
)
|
||||
)
|
||||
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
@@ -299,7 +455,8 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
total_timesteps=final_timesteps,
|
||||
dtype=dtype,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
controlnet_hint=controlnet_hint,
|
||||
control_mode=control_mode,
|
||||
stencil_hints=stencil_hints,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images
|
||||
|
||||
@@ -18,7 +18,10 @@ from diffusers import (
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.schedulers import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,220 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from typing import Union
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Text2ImageSDXLPipeline(StableDiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
],
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
ondemand: bool,
|
||||
is_fp32_vae: bool,
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
self.is_fp32_vae = is_fp32_vae
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
generator,
|
||||
num_inference_steps,
|
||||
dtype,
|
||||
):
|
||||
latents = torch.randn(
|
||||
(
|
||||
batch_size,
|
||||
4,
|
||||
height // 8,
|
||||
width // 8,
|
||||
),
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
).to(dtype)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
self.scheduler.is_scale_input_called = True
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def _get_add_time_ids(
|
||||
self, original_size, crops_coords_top_left, target_size, dtype
|
||||
):
|
||||
add_time_ids = list(
|
||||
original_size + crops_coords_top_left + target_size
|
||||
)
|
||||
|
||||
# self.unet.config.addition_time_embed_dim IS 256.
|
||||
# self.text_encoder_2.config.projection_dim IS 1280.
|
||||
passed_add_embed_dim = 256 * len(add_time_ids) + 1280
|
||||
expected_add_embed_dim = 2816
|
||||
# self.unet.add_embedding.linear_1.in_features IS 2816.
|
||||
|
||||
if expected_add_embed_dim != passed_add_embed_dim:
|
||||
raise ValueError(
|
||||
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
||||
)
|
||||
|
||||
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
||||
return add_time_ids
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
prompts,
|
||||
neg_prompts,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
max_length,
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if isinstance(neg_prompts, str):
|
||||
neg_prompts = [neg_prompts]
|
||||
|
||||
prompts = prompts * batch_size
|
||||
neg_prompts = neg_prompts * batch_size
|
||||
|
||||
# seed generator to create the inital latent noise. Also handle out of range seeds.
|
||||
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Get initial latents.
|
||||
init_latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get text embeddings.
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt_sdxl(
|
||||
prompt=prompts,
|
||||
num_images_per_prompt=1,
|
||||
do_classifier_free_guidance=True,
|
||||
negative_prompt=neg_prompts,
|
||||
)
|
||||
|
||||
# Prepare timesteps.
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# Prepare added time ids & embeddings.
|
||||
original_size = (height, width)
|
||||
target_size = (height, width)
|
||||
crops_coords_top_left = (0, 0)
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
)
|
||||
|
||||
prompt_embeds = torch.cat(
|
||||
[negative_prompt_embeds, prompt_embeds], dim=0
|
||||
)
|
||||
add_text_embeds = torch.cat(
|
||||
[negative_pooled_prompt_embeds, add_text_embeds], dim=0
|
||||
)
|
||||
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds
|
||||
add_text_embeds = add_text_embeds.to(dtype)
|
||||
add_time_ids = add_time_ids.repeat(batch_size * 1, 1)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
guidance_scale = torch.tensor(guidance_scale).to(dtype)
|
||||
prompt_embeds = prompt_embeds.to(dtype)
|
||||
add_time_ids = add_time_ids.to(dtype)
|
||||
|
||||
# Get Image latents.
|
||||
latents = self.produce_img_latents_sdxl(
|
||||
init_latents,
|
||||
timesteps,
|
||||
add_text_embeds,
|
||||
add_time_ids,
|
||||
prompt_embeds,
|
||||
cpu_scheduling,
|
||||
guidance_scale,
|
||||
dtype,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images.
|
||||
all_imgs = []
|
||||
self.load_vae()
|
||||
for i in range(0, latents.shape[0], batch_size):
|
||||
imgs = self.decode_latents_sdxl(
|
||||
latents[i : i + batch_size], is_fp32_vae=self.is_fp32_vae
|
||||
)
|
||||
all_imgs.extend(imgs)
|
||||
if self.ondemand:
|
||||
self.unload_vae()
|
||||
|
||||
return all_imgs
|
||||
@@ -20,7 +20,10 @@ from diffusers import (
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from shark.shark_inference import SharkInference
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.schedulers import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.models import (
|
||||
SharkifyStableDiffusionModel,
|
||||
get_vae,
|
||||
@@ -33,6 +36,8 @@ from apps.stable_diffusion.src.utils import (
|
||||
end_profiling,
|
||||
)
|
||||
import sys
|
||||
import gc
|
||||
from typing import List, Optional
|
||||
|
||||
SD_STATE_IDLE = "idle"
|
||||
SD_STATE_CANCEL = "cancel"
|
||||
@@ -50,6 +55,7 @@ class StableDiffusionPipeline:
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
@@ -60,20 +66,23 @@ class StableDiffusionPipeline:
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
ondemand: bool,
|
||||
is_f32_vae: bool = False,
|
||||
):
|
||||
self.vae = None
|
||||
self.text_encoder = None
|
||||
self.text_encoder_2 = None
|
||||
self.unet = None
|
||||
self.unet_512 = None
|
||||
self.model_max_length = 77
|
||||
self.scheduler = scheduler
|
||||
# TODO: Implement using logging python utility.
|
||||
self.log = ""
|
||||
self.status = SD_STATE_IDLE
|
||||
self.sd_model = sd_model
|
||||
self.scheduler = scheduler
|
||||
self.import_mlir = import_mlir
|
||||
self.use_lora = use_lora
|
||||
self.ondemand = ondemand
|
||||
self.is_f32_vae = is_f32_vae
|
||||
# TODO: Find a better workaround for fetching base_model_id early
|
||||
# enough for CLIPTokenizer.
|
||||
try:
|
||||
@@ -106,6 +115,34 @@ class StableDiffusionPipeline:
|
||||
del self.text_encoder
|
||||
self.text_encoder = None
|
||||
|
||||
def load_clip_sdxl(self):
|
||||
if self.text_encoder and self.text_encoder_2:
|
||||
return
|
||||
|
||||
if self.import_mlir or self.use_lora:
|
||||
if not self.import_mlir:
|
||||
print(
|
||||
"Warning: LoRA provided but import_mlir not specified. "
|
||||
"Importing MLIR anyways."
|
||||
)
|
||||
self.text_encoder, self.text_encoder_2 = self.sd_model.sdxl_clip()
|
||||
else:
|
||||
try:
|
||||
# TODO: Fix this for SDXL
|
||||
self.text_encoder = get_clip()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
(
|
||||
self.text_encoder,
|
||||
self.text_encoder_2,
|
||||
) = self.sd_model.sdxl_clip()
|
||||
|
||||
def unload_clip_sdxl(self):
|
||||
del self.text_encoder, self.text_encoder_2
|
||||
self.text_encoder = None
|
||||
self.text_encoder_2 = None
|
||||
|
||||
def load_unet(self):
|
||||
if self.unet is not None:
|
||||
return
|
||||
@@ -159,6 +196,182 @@ class StableDiffusionPipeline:
|
||||
def unload_vae(self):
|
||||
del self.vae
|
||||
self.vae = None
|
||||
gc.collect()
|
||||
|
||||
def encode_prompt_sdxl(
|
||||
self,
|
||||
prompt: str,
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[str] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
hf_model_id: Optional[
|
||||
str
|
||||
] = "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
):
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# Define tokenizers and text encoders
|
||||
self.tokenizer_2 = get_tokenizer("tokenizer_2", hf_model_id)
|
||||
self.load_clip_sdxl()
|
||||
tokenizers = (
|
||||
[self.tokenizer, self.tokenizer_2]
|
||||
if self.tokenizer is not None
|
||||
else [self.tokenizer_2]
|
||||
)
|
||||
text_encoders = (
|
||||
[self.text_encoder, self.text_encoder_2]
|
||||
if self.text_encoder is not None
|
||||
else [self.text_encoder_2]
|
||||
)
|
||||
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
prompt_embeds_list = []
|
||||
prompts = [prompt, prompt]
|
||||
for prompt, tokenizer, text_encoder in zip(
|
||||
prompts, tokenizers, text_encoders
|
||||
):
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = tokenizer(
|
||||
prompt, padding="longest", return_tensors="pt"
|
||||
).input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[
|
||||
-1
|
||||
] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = tokenizer.batch_decode(
|
||||
untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
print(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
text_encoder_output = text_encoder("forward", (text_input_ids,))
|
||||
prompt_embeds = torch.from_numpy(text_encoder_output[0])
|
||||
pooled_prompt_embeds = torch.from_numpy(text_encoder_output[1])
|
||||
|
||||
prompt_embeds_list.append(prompt_embeds)
|
||||
|
||||
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
zero_out_negative_prompt = (
|
||||
negative_prompt is None
|
||||
and self.config.force_zeros_for_empty_prompt
|
||||
)
|
||||
if (
|
||||
do_classifier_free_guidance
|
||||
and negative_prompt_embeds is None
|
||||
and zero_out_negative_prompt
|
||||
):
|
||||
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
||||
negative_pooled_prompt_embeds = torch.zeros_like(
|
||||
pooled_prompt_embeds
|
||||
)
|
||||
elif do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt_2 = negative_prompt
|
||||
|
||||
uncond_tokens: List[str]
|
||||
if prompt is not None and type(prompt) is not type(
|
||||
negative_prompt
|
||||
):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt, negative_prompt_2]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = [negative_prompt, negative_prompt_2]
|
||||
|
||||
negative_prompt_embeds_list = []
|
||||
for negative_prompt, tokenizer, text_encoder in zip(
|
||||
uncond_tokens, tokenizers, text_encoders
|
||||
):
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = tokenizer(
|
||||
negative_prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_encoder_output = text_encoder(
|
||||
"forward", (uncond_input.input_ids,)
|
||||
)
|
||||
negative_prompt_embeds = torch.from_numpy(
|
||||
text_encoder_output[0]
|
||||
)
|
||||
negative_pooled_prompt_embeds = torch.from_numpy(
|
||||
text_encoder_output[1]
|
||||
)
|
||||
|
||||
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
||||
|
||||
negative_prompt_embeds = torch.concat(
|
||||
negative_prompt_embeds_list, dim=-1
|
||||
)
|
||||
|
||||
if self.ondemand:
|
||||
self.unload_clip_sdxl()
|
||||
gc.collect()
|
||||
|
||||
# TODO: Look into dtype for text_encoder_2!
|
||||
prompt_embeds = prompt_embeds.to(dtype=torch.float16)
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(
|
||||
bs_embed * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=torch.float32)
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(
|
||||
1, num_images_per_prompt, 1
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(
|
||||
1, num_images_per_prompt
|
||||
).view(bs_embed * num_images_per_prompt, -1)
|
||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
|
||||
1, num_images_per_prompt
|
||||
).view(bs_embed * num_images_per_prompt, -1)
|
||||
|
||||
return (
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
)
|
||||
|
||||
def encode_prompts(self, prompts, neg_prompts, max_length):
|
||||
# Tokenize text and get embeddings
|
||||
@@ -186,6 +399,7 @@ class StableDiffusionPipeline:
|
||||
clip_inf_time = (time.time() - clip_inf_start) * 1000
|
||||
if self.ondemand:
|
||||
self.unload_clip()
|
||||
gc.collect()
|
||||
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
|
||||
|
||||
return text_embeddings
|
||||
@@ -298,6 +512,8 @@ class StableDiffusionPipeline:
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
self.unload_unet_512()
|
||||
gc.collect()
|
||||
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
@@ -306,6 +522,96 @@ class StableDiffusionPipeline:
|
||||
all_latents = torch.cat(latent_history, dim=0)
|
||||
return all_latents
|
||||
|
||||
def produce_img_latents_sdxl(
|
||||
self,
|
||||
latents,
|
||||
total_timesteps,
|
||||
add_text_embeds,
|
||||
add_time_ids,
|
||||
prompt_embeds,
|
||||
cpu_scheduling,
|
||||
guidance_scale,
|
||||
dtype,
|
||||
mask=None,
|
||||
masked_image_latents=None,
|
||||
return_all_latents=False,
|
||||
):
|
||||
# return None
|
||||
self.status = SD_STATE_IDLE
|
||||
step_time_sum = 0
|
||||
extra_step_kwargs = {"generator": None}
|
||||
self.load_unet()
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
timestep = torch.tensor([t]).to(dtype).detach().numpy()
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
if isinstance(latents, np.ndarray):
|
||||
latents = torch.tensor(latents)
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
|
||||
latent_model_input = self.scheduler.scale_model_input(
|
||||
latent_model_input, t
|
||||
)
|
||||
if mask is not None and masked_image_latents is not None:
|
||||
latent_model_input = torch.cat(
|
||||
[
|
||||
torch.from_numpy(np.asarray(latent_model_input)),
|
||||
mask,
|
||||
masked_image_latents,
|
||||
],
|
||||
dim=1,
|
||||
).to(dtype)
|
||||
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
prompt_embeds,
|
||||
add_text_embeds,
|
||||
add_time_ids,
|
||||
guidance_scale,
|
||||
),
|
||||
send_to_host=True,
|
||||
)
|
||||
if not isinstance(latents, torch.Tensor):
|
||||
latents = torch.from_numpy(latents).to("cpu")
|
||||
noise_pred = torch.from_numpy(noise_pred).to("cpu")
|
||||
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
|
||||
)[0]
|
||||
latents = latents.detach().numpy()
|
||||
noise_pred = noise_pred.detach().numpy()
|
||||
|
||||
step_time = (time.time() - step_start_time) * 1000
|
||||
step_time_sum += step_time
|
||||
|
||||
if self.status == SD_STATE_CANCEL:
|
||||
break
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
gc.collect()
|
||||
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
return latents
|
||||
|
||||
def decode_latents_sdxl(self, latents, is_fp32_vae):
|
||||
# latents are in unet dtype here so switch if we want to use fp32
|
||||
if is_fp32_vae:
|
||||
print("Casting latents to float32 for VAE")
|
||||
latents = latents.to(torch.float32)
|
||||
images = self.vae("forward", (latents,))
|
||||
images = (torch.from_numpy(images) / 2 + 0.5).clamp(0, 1)
|
||||
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
images = (images * 255).round().astype("uint8")
|
||||
pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
|
||||
|
||||
return pil_images
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
@@ -338,7 +644,8 @@ class StableDiffusionPipeline:
|
||||
ondemand: bool,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
debug: bool = False,
|
||||
use_stencil: str = None,
|
||||
stencils: list[str] = [],
|
||||
# stencil_images: list[Image] = []
|
||||
use_lora: str = "",
|
||||
ddpm_scheduler: DDPMScheduler = None,
|
||||
use_quantize=None,
|
||||
@@ -355,6 +662,7 @@ class StableDiffusionPipeline:
|
||||
"OutpaintPipeline",
|
||||
]
|
||||
is_upscaler = cls.__name__ in ["UpscalerPipeline"]
|
||||
is_sdxl = cls.__name__ in ["Text2ImageSDXLPipeline"]
|
||||
|
||||
sd_model = SharkifyStableDiffusionModel(
|
||||
model_id,
|
||||
@@ -371,7 +679,8 @@ class StableDiffusionPipeline:
|
||||
debug=debug,
|
||||
is_inpaint=is_inpaint,
|
||||
is_upscaler=is_upscaler,
|
||||
use_stencil=use_stencil,
|
||||
is_sdxl=is_sdxl,
|
||||
stencils=stencils,
|
||||
use_lora=use_lora,
|
||||
use_quantize=use_quantize,
|
||||
)
|
||||
@@ -386,6 +695,21 @@ class StableDiffusionPipeline:
|
||||
ondemand,
|
||||
)
|
||||
|
||||
if cls.__name__ == "StencilPipeline":
|
||||
return cls(
|
||||
scheduler, sd_model, import_mlir, use_lora, ondemand, stencils
|
||||
)
|
||||
if cls.__name__ == "Text2ImageSDXLPipeline":
|
||||
is_fp32_vae = True if "16" not in custom_vae else False
|
||||
return cls(
|
||||
scheduler,
|
||||
sd_model,
|
||||
import_mlir,
|
||||
use_lora,
|
||||
ondemand,
|
||||
is_fp32_vae,
|
||||
)
|
||||
|
||||
return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
|
||||
# #####################################################
|
||||
@@ -498,9 +822,10 @@ class StableDiffusionPipeline:
|
||||
clip_inf_time = (time.time() - clip_inf_start) * 1000
|
||||
if self.ondemand:
|
||||
self.unload_clip()
|
||||
gc.collect()
|
||||
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
|
||||
|
||||
return text_embeddings.numpy()
|
||||
return text_embeddings.numpy().astype(np.float16)
|
||||
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers
|
||||
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import (
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from diffusers import (
|
||||
LCMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
DDPMScheduler,
|
||||
@@ -15,9 +16,21 @@ from diffusers import (
|
||||
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import (
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
)
|
||||
|
||||
|
||||
def get_schedulers(model_id):
|
||||
# TODO: Robust scheduler setup on pipeline creation -- if we don't
|
||||
# set batch_size here, the SHARK schedulers will
|
||||
# compile with batch size = 1 regardless of whether the model
|
||||
# outputs latents of a larger batch size, e.g. SDXL.
|
||||
# However, obviously, searching for whether the base model ID
|
||||
# contains "xl" is not very robust.
|
||||
|
||||
batch_size = 2 if "xl" in model_id.lower() else 1
|
||||
|
||||
schedulers = dict()
|
||||
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
|
||||
model_id,
|
||||
@@ -39,6 +52,10 @@ def get_schedulers(model_id):
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["LCMScheduler"] = LCMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers[
|
||||
"DPMSolverMultistep"
|
||||
] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
@@ -84,6 +101,12 @@ def get_schedulers(model_id):
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers[
|
||||
"SharkEulerAncestralDiscrete"
|
||||
] = SharkEulerAncestralDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers[
|
||||
"DPMSolverSinglestep"
|
||||
] = DPMSolverSinglestepScheduler.from_pretrained(
|
||||
@@ -100,5 +123,6 @@ def get_schedulers(model_id):
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["SharkEulerDiscrete"].compile()
|
||||
schedulers["SharkEulerDiscrete"].compile(batch_size)
|
||||
schedulers["SharkEulerAncestralDiscrete"].compile(batch_size)
|
||||
return schedulers
|
||||
|
||||
@@ -0,0 +1,251 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from diffusers import (
|
||||
EulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
compile_through_fx,
|
||||
get_shark_model,
|
||||
args,
|
||||
)
|
||||
import torch
|
||||
|
||||
|
||||
class SharkEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
prediction_type: str = "epsilon",
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
super().__init__(
|
||||
num_train_timesteps,
|
||||
beta_start,
|
||||
beta_end,
|
||||
beta_schedule,
|
||||
trained_betas,
|
||||
prediction_type,
|
||||
timestep_spacing,
|
||||
steps_offset,
|
||||
)
|
||||
# TODO: make it dynamic so we dont have to worry about batch size
|
||||
self.batch_size = None
|
||||
self.init_input_shape = None
|
||||
|
||||
def compile(self, batch_size=1):
|
||||
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
|
||||
device = args.device.split(":", 1)[0].strip()
|
||||
self.batch_size = batch_size
|
||||
|
||||
model_input = {
|
||||
"eulera": {
|
||||
"output": torch.randn(
|
||||
batch_size, 4, args.height // 8, args.width // 8
|
||||
),
|
||||
"latent": torch.randn(
|
||||
batch_size, 4, args.height // 8, args.width // 8
|
||||
),
|
||||
"sigma": torch.tensor(1).to(torch.float32),
|
||||
"sigma_from": torch.tensor(1).to(torch.float32),
|
||||
"sigma_to": torch.tensor(1).to(torch.float32),
|
||||
"noise": torch.randn(
|
||||
batch_size, 4, args.height // 8, args.width // 8
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
example_latent = model_input["eulera"]["latent"]
|
||||
example_output = model_input["eulera"]["output"]
|
||||
example_noise = model_input["eulera"]["noise"]
|
||||
if args.precision == "fp16":
|
||||
example_latent = example_latent.half()
|
||||
example_output = example_output.half()
|
||||
example_noise = example_noise.half()
|
||||
example_sigma = model_input["eulera"]["sigma"]
|
||||
example_sigma_from = model_input["eulera"]["sigma_from"]
|
||||
example_sigma_to = model_input["eulera"]["sigma_to"]
|
||||
|
||||
class ScalingModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, latent, sigma):
|
||||
return latent / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
class SchedulerStepEpsilonModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self, noise_pred, latent, sigma, sigma_from, sigma_to, noise
|
||||
):
|
||||
sigma_up = (
|
||||
sigma_to**2
|
||||
* (sigma_from**2 - sigma_to**2)
|
||||
/ sigma_from**2
|
||||
) ** 0.5
|
||||
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
|
||||
dt = sigma_down - sigma
|
||||
pred_original_sample = latent - sigma * noise_pred
|
||||
derivative = (latent - pred_original_sample) / sigma
|
||||
prev_sample = latent + derivative * dt
|
||||
return prev_sample + noise * sigma_up
|
||||
|
||||
class SchedulerStepVPredictionModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self, noise_pred, sigma, sigma_from, sigma_to, latent, noise
|
||||
):
|
||||
sigma_up = (
|
||||
sigma_to**2
|
||||
* (sigma_from**2 - sigma_to**2)
|
||||
/ sigma_from**2
|
||||
) ** 0.5
|
||||
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
|
||||
dt = sigma_down - sigma
|
||||
pred_original_sample = noise_pred * (
|
||||
-sigma / (sigma**2 + 1) ** 0.5
|
||||
) + (latent / (sigma**2 + 1))
|
||||
derivative = (latent - pred_original_sample) / sigma
|
||||
prev_sample = latent + derivative * dt
|
||||
return prev_sample + noise * sigma_up
|
||||
|
||||
iree_flags = []
|
||||
if len(args.iree_vulkan_target_triple) > 0:
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
|
||||
def _import(self):
|
||||
scaling_model = ScalingModel()
|
||||
self.scaling_model, _ = compile_through_fx(
|
||||
model=scaling_model,
|
||||
inputs=(example_latent, example_sigma),
|
||||
extended_model_name=f"euler_a_scale_model_input_{self.batch_size}_{args.height}_{args.width}_{device}_"
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
|
||||
pred_type_model_dict = {
|
||||
"epsilon": SchedulerStepEpsilonModel(),
|
||||
"v_prediction": SchedulerStepVPredictionModel(),
|
||||
}
|
||||
step_model = pred_type_model_dict[self.config.prediction_type]
|
||||
self.step_model, _ = compile_through_fx(
|
||||
step_model,
|
||||
(
|
||||
example_output,
|
||||
example_latent,
|
||||
example_sigma,
|
||||
example_sigma_from,
|
||||
example_sigma_to,
|
||||
example_noise,
|
||||
),
|
||||
extended_model_name=f"euler_a_step_{self.config.prediction_type}_{self.batch_size}_{args.height}_{args.width}_{device}_"
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
|
||||
if args.import_mlir:
|
||||
_import(self)
|
||||
|
||||
else:
|
||||
try:
|
||||
self.scaling_model = get_shark_model(
|
||||
SCHEDULER_BUCKET,
|
||||
"euler_a_scale_model_input_" + args.precision,
|
||||
iree_flags,
|
||||
)
|
||||
self.step_model = get_shark_model(
|
||||
SCHEDULER_BUCKET,
|
||||
"euler_a_step_"
|
||||
+ self.config.prediction_type
|
||||
+ args.precision,
|
||||
iree_flags,
|
||||
)
|
||||
except:
|
||||
print(
|
||||
"failed to download model, falling back and using import_mlir"
|
||||
)
|
||||
args.import_mlir = True
|
||||
_import(self)
|
||||
|
||||
def scale_model_input(self, sample, timestep):
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
sigma = self.sigmas[self.step_index]
|
||||
return self.scaling_model(
|
||||
"forward",
|
||||
(
|
||||
sample,
|
||||
sigma,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
|
||||
def step(
|
||||
self,
|
||||
noise_pred,
|
||||
timestep,
|
||||
latent,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: Optional[bool] = False,
|
||||
):
|
||||
step_inputs = []
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
|
||||
sigma_from = self.sigmas[self.step_index]
|
||||
sigma_to = self.sigmas[self.step_index + 1]
|
||||
noise = randn_tensor(
|
||||
torch.Size(noise_pred.shape),
|
||||
dtype=torch.float16,
|
||||
device="cpu",
|
||||
generator=generator,
|
||||
)
|
||||
step_inputs = [
|
||||
noise_pred,
|
||||
latent,
|
||||
sigma,
|
||||
sigma_from,
|
||||
sigma_to,
|
||||
noise,
|
||||
]
|
||||
# TODO: deal with dynamic inputs in turbine flow.
|
||||
# update step index since we're done with the variable and will return with compiled module output.
|
||||
self._step_index += 1
|
||||
|
||||
if noise_pred.shape[0] < self.batch_size:
|
||||
for i in [0, 1, 5]:
|
||||
try:
|
||||
step_inputs[i] = torch.tensor(step_inputs[i])
|
||||
except:
|
||||
step_inputs[i] = torch.tensor(step_inputs[i].to_host())
|
||||
step_inputs[i] = torch.cat(
|
||||
(step_inputs[i], step_inputs[i]), axis=0
|
||||
)
|
||||
return self.step_model(
|
||||
"forward",
|
||||
tuple(step_inputs),
|
||||
send_to_host=True,
|
||||
)
|
||||
|
||||
return self.step_model(
|
||||
"forward",
|
||||
tuple(step_inputs),
|
||||
send_to_host=False,
|
||||
)
|
||||
@@ -2,12 +2,9 @@ import sys
|
||||
import numpy as np
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from diffusers import (
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
compile_through_fx,
|
||||
@@ -27,6 +24,13 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
prediction_type: str = "epsilon",
|
||||
interpolation_type: str = "linear",
|
||||
use_karras_sigmas: bool = False,
|
||||
sigma_min: Optional[float] = None,
|
||||
sigma_max: Optional[float] = None,
|
||||
timestep_spacing: str = "linspace",
|
||||
timestep_type: str = "discrete",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
super().__init__(
|
||||
num_train_timesteps,
|
||||
@@ -35,20 +39,29 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
beta_schedule,
|
||||
trained_betas,
|
||||
prediction_type,
|
||||
interpolation_type,
|
||||
use_karras_sigmas,
|
||||
sigma_min,
|
||||
sigma_max,
|
||||
timestep_spacing,
|
||||
timestep_type,
|
||||
steps_offset,
|
||||
)
|
||||
# TODO: make it dynamic so we dont have to worry about batch size
|
||||
self.batch_size = None
|
||||
|
||||
def compile(self):
|
||||
def compile(self, batch_size=1):
|
||||
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
|
||||
BATCH_SIZE = args.batch_size
|
||||
device = args.device.split(":", 1)[0].strip()
|
||||
self.batch_size = batch_size
|
||||
|
||||
model_input = {
|
||||
"euler": {
|
||||
"latent": torch.randn(
|
||||
BATCH_SIZE, 4, args.height // 8, args.width // 8
|
||||
batch_size, 4, args.height // 8, args.width // 8
|
||||
),
|
||||
"output": torch.randn(
|
||||
BATCH_SIZE, 4, args.height // 8, args.width // 8
|
||||
batch_size, 4, args.height // 8, args.width // 8
|
||||
),
|
||||
"sigma": torch.tensor(1).to(torch.float32),
|
||||
"dt": torch.tensor(1).to(torch.float32),
|
||||
@@ -70,12 +83,32 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
def forward(self, latent, sigma):
|
||||
return latent / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
class SchedulerStepModel(torch.nn.Module):
|
||||
class SchedulerStepEpsilonModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, noise_pred, sigma_hat, latent, dt):
|
||||
pred_original_sample = latent - sigma_hat * noise_pred
|
||||
derivative = (latent - pred_original_sample) / sigma_hat
|
||||
return latent + derivative * dt
|
||||
|
||||
class SchedulerStepSampleModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, noise_pred, sigma_hat, latent, dt):
|
||||
pred_original_sample = noise_pred
|
||||
derivative = (latent - pred_original_sample) / sigma_hat
|
||||
return latent + derivative * dt
|
||||
|
||||
class SchedulerStepVPredictionModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, noise_pred, sigma, latent, dt):
|
||||
pred_original_sample = latent - sigma * noise_pred
|
||||
pred_original_sample = noise_pred * (
|
||||
-sigma / (sigma**2 + 1) ** 0.5
|
||||
) + (latent / (sigma**2 + 1))
|
||||
derivative = (latent - pred_original_sample) / sigma
|
||||
return latent + derivative * dt
|
||||
|
||||
@@ -84,25 +117,28 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
# Disable bindings fusion to work with moltenVK.
|
||||
if sys.platform == "darwin":
|
||||
iree_flags.append("-iree-stream-fuse-binding=false")
|
||||
|
||||
def _import(self):
|
||||
scaling_model = ScalingModel()
|
||||
self.scaling_model, _ = compile_through_fx(
|
||||
model=scaling_model,
|
||||
inputs=(example_latent, example_sigma),
|
||||
extended_model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
|
||||
extended_model_name=f"euler_scale_model_input_{self.batch_size}_{args.height}_{args.width}_{device}_"
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
|
||||
step_model = SchedulerStepModel()
|
||||
pred_type_model_dict = {
|
||||
"epsilon": SchedulerStepEpsilonModel(),
|
||||
"v_prediction": SchedulerStepVPredictionModel(),
|
||||
"sample": SchedulerStepSampleModel(),
|
||||
"original_sample": SchedulerStepSampleModel(),
|
||||
}
|
||||
step_model = pred_type_model_dict[self.config.prediction_type]
|
||||
self.step_model, _ = compile_through_fx(
|
||||
step_model,
|
||||
(example_output, example_sigma, example_latent, example_dt),
|
||||
extended_model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
|
||||
extended_model_name=f"euler_step_{self.config.prediction_type}_{self.batch_size}_{args.height}_{args.width}_{device}_"
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
@@ -112,6 +148,11 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
|
||||
else:
|
||||
try:
|
||||
step_model_type = (
|
||||
"sample"
|
||||
if "sample" in self.config.prediction_type
|
||||
else self.config.prediction_type
|
||||
)
|
||||
self.scaling_model = get_shark_model(
|
||||
SCHEDULER_BUCKET,
|
||||
"euler_scale_model_input_" + args.precision,
|
||||
@@ -119,7 +160,7 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
)
|
||||
self.step_model = get_shark_model(
|
||||
SCHEDULER_BUCKET,
|
||||
"euler_step_" + args.precision,
|
||||
"euler_step_" + step_model_type + args.precision,
|
||||
iree_flags,
|
||||
)
|
||||
except:
|
||||
@@ -141,15 +182,57 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
send_to_host=False,
|
||||
)
|
||||
|
||||
def step(self, noise_pred, timestep, latent):
|
||||
step_index = (self.timesteps == timestep).nonzero().item()
|
||||
sigma = self.sigmas[step_index]
|
||||
dt = self.sigmas[step_index + 1] - sigma
|
||||
def step(
|
||||
self,
|
||||
noise_pred,
|
||||
timestep,
|
||||
latent,
|
||||
s_churn: float = 0.0,
|
||||
s_tmin: float = 0.0,
|
||||
s_tmax: float = float("inf"),
|
||||
s_noise: float = 1.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: Optional[bool] = False,
|
||||
):
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
|
||||
gamma = (
|
||||
min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
|
||||
if s_tmin <= sigma <= s_tmax
|
||||
else 0.0
|
||||
)
|
||||
|
||||
sigma_hat = sigma * (gamma + 1)
|
||||
|
||||
noise_pred = (
|
||||
torch.from_numpy(noise_pred)
|
||||
if isinstance(noise_pred, np.ndarray)
|
||||
else noise_pred
|
||||
)
|
||||
|
||||
if gamma > 0:
|
||||
noise = randn_tensor(
|
||||
torch.Size(noise_pred.shape),
|
||||
dtype=torch.float16,
|
||||
device="cpu",
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
eps = noise * s_noise
|
||||
latent = latent + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
||||
|
||||
if self.config.prediction_type == "v_prediction":
|
||||
sigma_hat = sigma
|
||||
|
||||
dt = self.sigmas[self.step_index + 1] - sigma_hat
|
||||
return self.step_model(
|
||||
"forward",
|
||||
(
|
||||
noise_pred,
|
||||
sigma,
|
||||
sigma_hat,
|
||||
latent,
|
||||
dt,
|
||||
),
|
||||
|
||||
@@ -41,3 +41,8 @@ from apps.stable_diffusion.src.utils.utils import (
|
||||
resize_stencil,
|
||||
_compile_module,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.civitai import get_civitai_checkpoint
|
||||
from apps.stable_diffusion.src.utils.resamplers import (
|
||||
resamplers,
|
||||
resampler_list,
|
||||
)
|
||||
|
||||
42
apps/stable_diffusion/src/utils/civitai.py
Normal file
42
apps/stable_diffusion/src/utils/civitai.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import re
|
||||
import requests
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_civitai_checkpoint(url: str):
|
||||
with requests.get(url, allow_redirects=True, stream=True) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
# civitai api returns the filename in the content disposition
|
||||
base_filename = re.findall(
|
||||
'"([^"]*)"', response.headers["Content-Disposition"]
|
||||
)[0]
|
||||
destination_path = (
|
||||
Path.cwd() / (args.ckpt_dir or "models") / base_filename
|
||||
)
|
||||
|
||||
# we don't have this model downloaded yet
|
||||
if not destination_path.is_file():
|
||||
print(
|
||||
f"downloading civitai model from {url} to {destination_path}"
|
||||
)
|
||||
|
||||
size = int(response.headers["content-length"], 0)
|
||||
progress_bar = tqdm(total=size, unit="iB", unit_scale=True)
|
||||
|
||||
with open(destination_path, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=65536):
|
||||
f.write(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
|
||||
progress_bar.close()
|
||||
|
||||
# we already have this model downloaded
|
||||
else:
|
||||
print(f"civitai model already downloaded to {destination_path}")
|
||||
|
||||
response.close()
|
||||
return destination_path.as_posix()
|
||||
12
apps/stable_diffusion/src/utils/resamplers.py
Normal file
12
apps/stable_diffusion/src/utils/resamplers.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import PIL.Image as Image
|
||||
|
||||
resamplers = {
|
||||
"Lanczos": Image.Resampling.LANCZOS,
|
||||
"Nearest Neighbor": Image.Resampling.NEAREST,
|
||||
"Bilinear": Image.Resampling.BILINEAR,
|
||||
"Bicubic": Image.Resampling.BICUBIC,
|
||||
"Hamming": Image.Resampling.HAMMING,
|
||||
"Box": Image.Resampling.BOX,
|
||||
}
|
||||
|
||||
resampler_list = resamplers.keys()
|
||||
@@ -8,6 +8,15 @@
|
||||
"dtype":"i64"
|
||||
}
|
||||
},
|
||||
"sdxl_clip": {
|
||||
"token" : {
|
||||
"shape" : [
|
||||
"1*batch_size",
|
||||
"max_len"
|
||||
],
|
||||
"dtype":"i64"
|
||||
}
|
||||
},
|
||||
"vae_encode": {
|
||||
"image" : {
|
||||
"shape" : [
|
||||
@@ -179,9 +188,95 @@
|
||||
"shape": [2],
|
||||
"dtype": "i64"
|
||||
}
|
||||
},
|
||||
"stabilityai/sdxl-turbo": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
4,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"prompt_embeds": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
2048
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"text_embeds": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
1280
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"time_ids": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
6
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"stabilityai/stable-diffusion-xl-base-1.0": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
4,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"prompt_embeds": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
2048
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"text_embeds": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
1280
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"time_ids": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
6
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
}
|
||||
}
|
||||
},
|
||||
"stencil_adaptor": {
|
||||
"stencil_adapter": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
@@ -208,6 +303,58 @@
|
||||
"controlnet_hint": {
|
||||
"shape": [1, 3, "8*height", "8*width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc1": {
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc2": {
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc3": {
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc4": {
|
||||
"shape": [2, 320, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc5": {
|
||||
"shape": [2, 640, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc6": {
|
||||
"shape": [2, 640, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc7": {
|
||||
"shape": [2, 640, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc8": {
|
||||
"shape": [2, 1280, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc9": {
|
||||
"shape": [2, 1280, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc10": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc11": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc12": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc13": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"stencil_unet": {
|
||||
@@ -290,7 +437,59 @@
|
||||
"control13": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale1": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale2": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale3": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale4": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale5": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale6": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale7": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale8": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale9": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale10": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale11": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale12": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale13": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,12 +11,12 @@
|
||||
"untuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -28,7 +28,7 @@
|
||||
"specified_compilation_flags": {
|
||||
"cuda": [],
|
||||
"default_device": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
}
|
||||
},
|
||||
@@ -37,7 +37,7 @@
|
||||
"specified_compilation_flags": {
|
||||
"cuda": [],
|
||||
"default_device": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -45,12 +45,12 @@
|
||||
"untuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -59,24 +59,28 @@
|
||||
"tuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
|
||||
"--iree-opt-data-tiling=False"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
|
||||
"--iree-opt-data-tiling=False"
|
||||
]
|
||||
}
|
||||
},
|
||||
"untuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
|
||||
"--iree-opt-data-tiling=False"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
|
||||
"--iree-opt-data-tiling=False"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
[["A high tech solarpunk utopia in the Amazon rainforest"],
|
||||
["Astrophotography, the shark nebula, nebula with a tiny shark-like cloud in the middle in the middle, hubble telescope, vivid colors"],
|
||||
["A pikachu fine dining with a view to the Eiffel Tower"],
|
||||
["A mecha robot in a favela in expressionist style"],
|
||||
["an insect robot preparing a delicious meal"],
|
||||
|
||||
@@ -158,9 +158,9 @@ def load_lower_configs(base_model_id=None):
|
||||
f"{spec}.json"
|
||||
)
|
||||
|
||||
full_gs_url = config_bucket + config_name
|
||||
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
|
||||
print("Loading lowering config file from ", lowering_config_dir)
|
||||
full_gs_url = config_bucket + config_name
|
||||
download_public_file(full_gs_url, lowering_config_dir, True)
|
||||
return lowering_config_dir
|
||||
|
||||
@@ -203,8 +203,8 @@ def dump_after_mlir(input_mlir, use_winograd):
|
||||
if use_winograd:
|
||||
preprocess_flag = (
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module"
|
||||
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
|
||||
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
|
||||
"(func.func(iree-global-opt-detach-elementwise-from-named-ops,"
|
||||
"iree-global-opt-convert-1x1-filter-conv2d-to-matmul,"
|
||||
"iree-preprocessing-convert-conv2d-to-img2col,"
|
||||
"iree-preprocessing-pad-linalg-ops{pad-size=32},"
|
||||
"iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
@@ -212,8 +212,8 @@ def dump_after_mlir(input_mlir, use_winograd):
|
||||
else:
|
||||
preprocess_flag = (
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module"
|
||||
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
|
||||
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
|
||||
"(func.func(iree-global-opt-detach-elementwise-from-named-ops,"
|
||||
"iree-global-opt-convert-1x1-filter-conv2d-to-matmul,"
|
||||
"iree-preprocessing-convert-conv2d-to-img2col,"
|
||||
"iree-preprocessing-pad-linalg-ops{pad-size=32}))"
|
||||
)
|
||||
@@ -281,13 +281,9 @@ def sd_model_annotation(mlir_model, model_name, base_model_id=None):
|
||||
if "rdna2" not in args.iree_vulkan_target_triple.split("-")[0]:
|
||||
use_winograd = True
|
||||
winograd_config_dir = load_winograd_configs()
|
||||
winograd_model = annotate_with_winograd(
|
||||
tuned_model = annotate_with_winograd(
|
||||
mlir_model, winograd_config_dir, model_name
|
||||
)
|
||||
lowering_config_dir = load_lower_configs(base_model_id)
|
||||
tuned_model = annotate_with_lower_configs(
|
||||
winograd_model, lowering_config_dir, model_name, use_winograd
|
||||
)
|
||||
else:
|
||||
tuned_model = mlir_model
|
||||
else:
|
||||
|
||||
@@ -2,6 +2,8 @@ import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from apps.stable_diffusion.src.utils.resamplers import resampler_list
|
||||
|
||||
|
||||
def path_expand(s):
|
||||
return Path(s).expanduser().resolve()
|
||||
@@ -83,7 +85,7 @@ p.add_argument(
|
||||
"--height",
|
||||
type=int,
|
||||
default=512,
|
||||
choices=range(128, 769, 8),
|
||||
choices=range(128, 1025, 8),
|
||||
help="The height of the output image.",
|
||||
)
|
||||
|
||||
@@ -91,7 +93,7 @@ p.add_argument(
|
||||
"--width",
|
||||
type=int,
|
||||
default=512,
|
||||
choices=range(128, 769, 8),
|
||||
choices=range(128, 1025, 8),
|
||||
help="The width of the output image.",
|
||||
)
|
||||
|
||||
@@ -132,6 +134,47 @@ p.add_argument(
|
||||
"img2img.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_hiresfix",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Use Hires Fix to do higher resolution images, while trying to "
|
||||
"avoid the issues that come with it. This is accomplished by first "
|
||||
"generating an image using txt2img, then running it through img2img.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hiresfix_height",
|
||||
type=int,
|
||||
default=768,
|
||||
choices=range(128, 769, 8),
|
||||
help="The height of the Hires Fix image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hiresfix_width",
|
||||
type=int,
|
||||
default=768,
|
||||
choices=range(128, 769, 8),
|
||||
help="The width of the Hires Fix image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hiresfix_strength",
|
||||
type=float,
|
||||
default=0.6,
|
||||
help="The denoising strength to apply for the Hires Fix.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--resample_type",
|
||||
type=str,
|
||||
default="Nearest Neighbor",
|
||||
choices=resampler_list,
|
||||
help="The resample type to use when resizing an image before being run "
|
||||
"through stable diffusion.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# Stable Diffusion Training Params
|
||||
##############################################################################
|
||||
@@ -202,28 +245,30 @@ p.add_argument(
|
||||
"--left",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend left for outpainting.",
|
||||
help="If extend left for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--right",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend right for outpainting.",
|
||||
help="If extend right for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--up",
|
||||
"--top",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend top for outpainting.",
|
||||
help="If extend top for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--down",
|
||||
"--bottom",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend bottom for outpainting.",
|
||||
help="If extend bottom for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -255,7 +300,7 @@ p.add_argument(
|
||||
|
||||
p.add_argument(
|
||||
"--import_mlir",
|
||||
default=False,
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Imports the model from torch module to shark_module otherwise "
|
||||
"downloads the model from shark_tank.",
|
||||
@@ -278,7 +323,7 @@ p.add_argument(
|
||||
|
||||
p.add_argument(
|
||||
"--use_tuned",
|
||||
default=True,
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Download and use the tuned version of the model if available.",
|
||||
)
|
||||
@@ -371,10 +416,17 @@ p.add_argument(
|
||||
|
||||
p.add_argument(
|
||||
"--use_stencil",
|
||||
choices=["canny", "openpose", "scribble"],
|
||||
choices=["canny", "openpose", "scribble", "zoedepth"],
|
||||
help="Enable the stencil feature.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--control_mode",
|
||||
choices=["Prompt", "Balanced", "Controlnet"],
|
||||
default="Balanced",
|
||||
help="How Controlnet injection should be prioritized.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_lora",
|
||||
type=str,
|
||||
@@ -407,6 +459,21 @@ p.add_argument(
|
||||
help="Specify your own huggingface authentication tokens for models like Llama2.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--device_allocator_heap_key",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify heap key for device caching allocator."
|
||||
"Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count"
|
||||
"Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--autogen",
|
||||
type=bool,
|
||||
default="False",
|
||||
help="Only used for a gradio workaround.",
|
||||
)
|
||||
##############################################################################
|
||||
# IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
@@ -519,6 +586,14 @@ p.add_argument(
|
||||
"in shark importer. Does nothing if import_mlir is false (the default).",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--compile_debug",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag to toggle debug assert/verify flags for imported IR in the"
|
||||
"iree-compiler. Default to false.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--iree_constant_folding",
|
||||
default=True,
|
||||
@@ -526,6 +601,13 @@ p.add_argument(
|
||||
help="Controls constant folding in iree-compile for all SD models.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--data_tiling",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Controls data tiling in iree-compile for all SD models.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# Web UI flags
|
||||
##############################################################################
|
||||
@@ -574,6 +656,25 @@ p.add_argument(
|
||||
help="Flag for enabling rest API.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--api_accept_origin",
|
||||
action="append",
|
||||
type=str,
|
||||
help="An origin to be accepted by the REST api for Cross Origin"
|
||||
"Resource Sharing (CORS). Use multiple times for multiple origins, "
|
||||
'or use --api_accept_origin="*" to accept all origins. If no origins '
|
||||
"are set no CORS headers will be returned by the api. Use, for "
|
||||
"instance, if you need to access the REST api from Javascript running "
|
||||
"in a web browser.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--debug",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for enabling debugging log in WebUI.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_gallery",
|
||||
default=True,
|
||||
@@ -651,6 +752,18 @@ p.add_argument(
|
||||
help="Specifies whether the docuchat's web version is running or not.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# rocm Flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--iree_rocm_target_chip",
|
||||
type=str,
|
||||
default="",
|
||||
help="Add the rocm device architecture ex gfx1100, gfx90a, etc. Use `hipinfo` "
|
||||
"or `iree-run-module --dump_devices=rocm` or `hipinfo` to get desired arch name",
|
||||
)
|
||||
|
||||
args, unknown = p.parse_known_args()
|
||||
if args.import_debug:
|
||||
os.environ["IREE_SAVE_TEMPS"] = os.path.join(
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
from apps.stable_diffusion.src.utils.stencils.canny import CannyDetector
|
||||
from apps.stable_diffusion.src.utils.stencils.openpose import OpenposeDetector
|
||||
from apps.stable_diffusion.src.utils.stencils.zoe import ZoeDetector
|
||||
|
||||
@@ -1,14 +1,46 @@
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
import os
|
||||
from pathlib import Path
|
||||
import torchvision
|
||||
import time
|
||||
from apps.stable_diffusion.src.utils.stencils import (
|
||||
CannyDetector,
|
||||
OpenposeDetector,
|
||||
ZoeDetector,
|
||||
)
|
||||
|
||||
stencil = {}
|
||||
|
||||
|
||||
def save_img(img):
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generated_imgs_todays_subdir,
|
||||
)
|
||||
|
||||
subdir = Path(
|
||||
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
|
||||
)
|
||||
os.makedirs(subdir, exist_ok=True)
|
||||
if isinstance(img, Image.Image):
|
||||
img.save(
|
||||
os.path.join(
|
||||
subdir, "controlnet_" + str(int(time.time())) + ".png"
|
||||
)
|
||||
)
|
||||
elif isinstance(img, np.ndarray):
|
||||
img = Image.fromarray(img)
|
||||
img.save(os.path.join(subdir, str(int(time.time())) + ".png"))
|
||||
else:
|
||||
converter = torchvision.transforms.ToPILImage()
|
||||
for i in img:
|
||||
converter(i).save(
|
||||
os.path.join(subdir, str(int(time.time())) + ".png")
|
||||
)
|
||||
|
||||
|
||||
def HWC3(x):
|
||||
assert x.dtype == np.uint8
|
||||
if x.ndim == 2:
|
||||
@@ -47,10 +79,12 @@ def controlnet_hint_shaping(
|
||||
)
|
||||
return controlnet_hint
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Acceptble shape of `stencil` are any of ({channels}, {height}, {width}),"
|
||||
+ f" (1, {channels}, {height}, {width}) or ({num_images_per_prompt}, "
|
||||
+ f"{channels}, {height}, {width}) but is {controlnet_hint.shape}"
|
||||
return controlnet_hint_shaping(
|
||||
Image.fromarray(controlnet_hint.detach().numpy()),
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
num_images_per_prompt,
|
||||
)
|
||||
elif isinstance(controlnet_hint, np.ndarray):
|
||||
# np.ndarray: acceptable shape is any of hw, hwc, bhwc(b==1) or bhwc(b==num_images_per_promot)
|
||||
@@ -77,29 +111,36 @@ def controlnet_hint_shaping(
|
||||
) # b h w c -> b c h w
|
||||
return controlnet_hint
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Acceptble shape of `stencil` are any of ({width}, {channels}), "
|
||||
+ f"({height}, {width}, {channels}), "
|
||||
+ f"(1, {height}, {width}, {channels}) or "
|
||||
+ f"({num_images_per_prompt}, {channels}, {height}, {width}) but is {controlnet_hint.shape}"
|
||||
)
|
||||
elif isinstance(controlnet_hint, Image.Image):
|
||||
if controlnet_hint.size == (width, height):
|
||||
controlnet_hint = controlnet_hint.convert(
|
||||
"RGB"
|
||||
) # make sure 3 channel RGB format
|
||||
controlnet_hint = np.array(controlnet_hint) # to numpy
|
||||
controlnet_hint = controlnet_hint[:, :, ::-1] # RGB -> BGR
|
||||
return controlnet_hint_shaping(
|
||||
controlnet_hint, height, width, num_images_per_prompt
|
||||
Image.fromarray(controlnet_hint),
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
num_images_per_prompt,
|
||||
)
|
||||
|
||||
elif isinstance(controlnet_hint, Image.Image):
|
||||
controlnet_hint = controlnet_hint.convert(
|
||||
"RGB"
|
||||
) # make sure 3 channel RGB format
|
||||
if controlnet_hint.size == (width, height):
|
||||
controlnet_hint = np.array(controlnet_hint).astype(
|
||||
np.float16
|
||||
) # to numpy
|
||||
controlnet_hint = controlnet_hint[:, :, ::-1] # RGB -> BGR
|
||||
return
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Acceptable image size of `stencil` is ({width}, {height}) but is {controlnet_hint.size}"
|
||||
)
|
||||
(hint_w, hint_h) = controlnet_hint.size
|
||||
left = int((hint_w - width) / 2)
|
||||
right = left + height
|
||||
controlnet_hint = controlnet_hint.crop((left, 0, right, hint_h))
|
||||
controlnet_hint = controlnet_hint.resize((width, height))
|
||||
return controlnet_hint_shaping(
|
||||
controlnet_hint, height, width, dtype, num_images_per_prompt
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Acceptable type of `stencil` are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
|
||||
f"Acceptible controlnet input types are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
|
||||
)
|
||||
|
||||
|
||||
@@ -109,14 +150,23 @@ def controlnet_hint_conversion(
|
||||
controlnet_hint = None
|
||||
match use_stencil:
|
||||
case "canny":
|
||||
print("Detecting edge with canny")
|
||||
print(
|
||||
"Converting controlnet hint to edge detection mask with canny preprocessor."
|
||||
)
|
||||
controlnet_hint = hint_canny(image)
|
||||
case "openpose":
|
||||
print("Detecting human pose")
|
||||
print(
|
||||
"Detecting human pose in controlnet hint with openpose preprocessor."
|
||||
)
|
||||
controlnet_hint = hint_openpose(image)
|
||||
case "scribble":
|
||||
print("Working with scribble")
|
||||
print("Using your scribble as a controlnet hint.")
|
||||
controlnet_hint = hint_scribble(image)
|
||||
case "zoedepth":
|
||||
print(
|
||||
"Converting controlnet hint to a depth mapping with ZoeDepth."
|
||||
)
|
||||
controlnet_hint = hint_zoedepth(image)
|
||||
case _:
|
||||
return None
|
||||
controlnet_hint = controlnet_hint_shaping(
|
||||
@@ -127,7 +177,7 @@ def controlnet_hint_conversion(
|
||||
|
||||
stencil_to_model_id_map = {
|
||||
"canny": "lllyasviel/control_v11p_sd15_canny",
|
||||
"depth": "lllyasviel/control_v11p_sd15_depth",
|
||||
"zoedepth": "lllyasviel/control_v11f1p_sd15_depth",
|
||||
"hed": "lllyasviel/sd-controlnet-hed",
|
||||
"mlsd": "lllyasviel/control_v11p_sd15_mlsd",
|
||||
"normal": "lllyasviel/control_v11p_sd15_normalbae",
|
||||
@@ -157,6 +207,7 @@ def hint_canny(
|
||||
detected_map = stencil["canny"](
|
||||
input_image, low_threshold, high_threshold
|
||||
)
|
||||
save_img(detected_map)
|
||||
detected_map = HWC3(detected_map)
|
||||
return detected_map
|
||||
|
||||
@@ -172,6 +223,7 @@ def hint_openpose(
|
||||
stencil["openpose"] = OpenposeDetector()
|
||||
|
||||
detected_map, _ = stencil["openpose"](input_image)
|
||||
save_img(detected_map)
|
||||
detected_map = HWC3(detected_map)
|
||||
return detected_map
|
||||
|
||||
@@ -183,4 +235,19 @@ def hint_scribble(image: Image.Image):
|
||||
|
||||
detected_map = np.zeros_like(input_image, dtype=np.uint8)
|
||||
detected_map[np.min(input_image, axis=2) < 127] = 255
|
||||
save_img(detected_map)
|
||||
return detected_map
|
||||
|
||||
|
||||
# Stencil 4. Depth (Only Zoe Preprocessing)
|
||||
def hint_zoedepth(image: Image.Image):
|
||||
with torch.no_grad():
|
||||
input_image = np.array(image)
|
||||
|
||||
if not "depth" in stencil:
|
||||
stencil["depth"] = ZoeDetector()
|
||||
|
||||
detected_map = stencil["depth"](input_image)
|
||||
save_img(detected_map)
|
||||
detected_map = HWC3(detected_map)
|
||||
return detected_map
|
||||
|
||||
64
apps/stable_diffusion/src/utils/stencils/zoe/__init__.py
Normal file
64
apps/stable_diffusion/src/utils/stencils/zoe/__init__.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from pathlib import Path
|
||||
import requests
|
||||
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
remote_model_path = (
|
||||
"https://huggingface.co/lllyasviel/Annotators/resolve/main/ZoeD_M12_N.pt"
|
||||
)
|
||||
|
||||
|
||||
class ZoeDetector:
|
||||
def __init__(self):
|
||||
cwd = Path.cwd()
|
||||
ckpt_path = Path(cwd, "stencil_annotator")
|
||||
ckpt_path.mkdir(parents=True, exist_ok=True)
|
||||
modelpath = ckpt_path / "ZoeD_M12_N.pt"
|
||||
|
||||
with requests.get(remote_model_path, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
with open(modelpath, "wb") as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
model = torch.hub.load(
|
||||
"monorimet/ZoeDepth:torch_update",
|
||||
"ZoeD_N",
|
||||
pretrained=False,
|
||||
force_reload=False,
|
||||
)
|
||||
|
||||
# Hack to fix the ZoeDepth import issue
|
||||
model_keys = model.state_dict().keys()
|
||||
loaded_dict = torch.load(modelpath, map_location=model.device)["model"]
|
||||
loaded_keys = loaded_dict.keys()
|
||||
for key in loaded_keys - model_keys:
|
||||
loaded_dict.pop(key)
|
||||
|
||||
model.load_state_dict(loaded_dict)
|
||||
model.eval()
|
||||
self.model = model
|
||||
|
||||
def __call__(self, input_image):
|
||||
assert input_image.ndim == 3
|
||||
image_depth = input_image
|
||||
with torch.no_grad():
|
||||
image_depth = torch.from_numpy(image_depth).float()
|
||||
image_depth = image_depth / 255.0
|
||||
image_depth = rearrange(image_depth, "h w c -> 1 c h w")
|
||||
depth = self.model.infer(image_depth)
|
||||
|
||||
depth = depth[0, 0].cpu().numpy()
|
||||
|
||||
vmin = np.percentile(depth, 2)
|
||||
vmax = np.percentile(depth, 85)
|
||||
|
||||
depth -= vmin
|
||||
depth /= vmax - vmin
|
||||
depth = 1.0 - depth
|
||||
depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8)
|
||||
|
||||
return depth_image
|
||||
@@ -18,14 +18,14 @@ import tempfile
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
set_iree_vulkan_runtime_flags,
|
||||
get_vulkan_target_triple,
|
||||
get_iree_vulkan_runtime_flags,
|
||||
)
|
||||
from shark.iree_utils.metal_utils import get_metal_target_triple
|
||||
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
|
||||
from shark.iree_utils.gpu_utils import get_cuda_sm_cc, get_iree_rocm_args
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
from apps.stable_diffusion.src.utils.resources import opt_flags
|
||||
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
|
||||
@@ -78,7 +78,7 @@ def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
)
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
os.getcwd(), model_name, extra_args
|
||||
os.getcwd(), model_name, extra_args, debug=args.compile_debug
|
||||
)
|
||||
shark_module.load_module(path, extra_args=extra_args)
|
||||
else:
|
||||
@@ -118,7 +118,7 @@ def compile_through_fx(
|
||||
is_f16=False,
|
||||
f16_input_mask=None,
|
||||
use_tuned=False,
|
||||
save_dir=tempfile.gettempdir(),
|
||||
save_dir="",
|
||||
debug=False,
|
||||
generate_vmfb=True,
|
||||
extra_args=None,
|
||||
@@ -154,8 +154,8 @@ def compile_through_fx(
|
||||
f16_input_mask=f16_input_mask,
|
||||
debug=debug,
|
||||
model_name=extended_model_name,
|
||||
save_dir=save_dir,
|
||||
)
|
||||
|
||||
if use_tuned:
|
||||
if "vae" in extended_model_name.split("_")[0]:
|
||||
args.annotation_model = "vae"
|
||||
@@ -168,6 +168,14 @@ def compile_through_fx(
|
||||
mlir_module, extended_model_name, base_model_id
|
||||
)
|
||||
|
||||
if not os.path.isdir(save_dir):
|
||||
save_dir = ""
|
||||
|
||||
mlir_module = save_mlir(
|
||||
mlir_module,
|
||||
model_name=extended_model_name,
|
||||
dir=save_dir,
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device=args.device if device is None else device,
|
||||
@@ -179,17 +187,22 @@ def compile_through_fx(
|
||||
mlir_module,
|
||||
)
|
||||
|
||||
del mlir_module
|
||||
gc.collect()
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
|
||||
# TODO: This function should be device-agnostic and piped properly
|
||||
# to general runtime driver init.
|
||||
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
|
||||
if args.enable_rgp:
|
||||
vulkan_runtime_flags += [
|
||||
f"--enable_rgp=true",
|
||||
f"--vulkan_debug_utils=true",
|
||||
]
|
||||
if args.device_allocator_heap_key:
|
||||
vulkan_runtime_flags += [
|
||||
f"--device_allocator=caching:device_local={args.device_allocator_heap_key}",
|
||||
]
|
||||
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
|
||||
@@ -464,18 +477,38 @@ def get_available_devices():
|
||||
f"{device_name} => {driver_name.replace('local', 'cpu')}"
|
||||
)
|
||||
else:
|
||||
device_list.append(f"{device_name} => {driver_name}://{i}")
|
||||
# for drivers with single devices
|
||||
# let the default device be selected without any indexing
|
||||
if len(device_list_dict) == 1:
|
||||
device_list.append(f"{device_name} => {driver_name}")
|
||||
else:
|
||||
device_list.append(
|
||||
f"{device_name} => {driver_name}://{i}"
|
||||
)
|
||||
return device_list
|
||||
|
||||
set_iree_runtime_flags()
|
||||
|
||||
available_devices = []
|
||||
vulkan_devices = get_devices_by_name("vulkan")
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
get_all_vulkan_devices,
|
||||
)
|
||||
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
vulkan_devices = []
|
||||
id = 0
|
||||
for device in vulkaninfo_list:
|
||||
vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
|
||||
id += 1
|
||||
if id != 0:
|
||||
print(f"vulkan devices are available.")
|
||||
available_devices.extend(vulkan_devices)
|
||||
metal_devices = get_devices_by_name("metal")
|
||||
available_devices.extend(metal_devices)
|
||||
cuda_devices = get_devices_by_name("cuda")
|
||||
available_devices.extend(cuda_devices)
|
||||
rocm_devices = get_devices_by_name("rocm")
|
||||
available_devices.extend(rocm_devices)
|
||||
cpu_device = get_devices_by_name("cpu-sync")
|
||||
available_devices.extend(cpu_device)
|
||||
cpu_device = get_devices_by_name("cpu-task")
|
||||
@@ -499,16 +532,17 @@ def get_opt_flags(model, precision="fp16"):
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
|
||||
if "rocm" in args.device:
|
||||
rocm_args = get_iree_rocm_args()
|
||||
iree_flags.extend(rocm_args)
|
||||
print(iree_flags)
|
||||
if args.iree_constant_folding == False:
|
||||
iree_flags.append("--iree-opt-const-expr-hoisting=False")
|
||||
iree_flags.append(
|
||||
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
|
||||
)
|
||||
|
||||
# Disable bindings fusion to work with moltenVK.
|
||||
if sys.platform == "darwin":
|
||||
iree_flags.append("-iree-stream-fuse-binding=false")
|
||||
if args.data_tiling == False:
|
||||
iree_flags.append("--iree-opt-data-tiling=False")
|
||||
|
||||
if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
|
||||
iree_flags += opt_flags[model][is_tuned][precision][
|
||||
@@ -531,6 +565,10 @@ def get_opt_flags(model, precision="fp16"):
|
||||
iree_flags += opt_flags[model][is_tuned][precision][
|
||||
"specified_compilation_flags"
|
||||
][device]
|
||||
if "vae" not in model:
|
||||
# Due to lack of support for multi-reduce, we always collapse reduction
|
||||
# dims before dispatch formation right now.
|
||||
iree_flags += ["--iree-flow-collapse-reduction-dims"]
|
||||
return iree_flags
|
||||
|
||||
|
||||
@@ -572,7 +610,7 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
|
||||
)
|
||||
num_in_channels = 9 if is_inpaint else 4
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path=custom_weights,
|
||||
checkpoint_path_or_dict=custom_weights,
|
||||
extract_ema=extract_ema,
|
||||
from_safetensors=from_safetensors,
|
||||
num_in_channels=num_in_channels,
|
||||
@@ -779,11 +817,12 @@ def batch_seeds(
|
||||
seeds = seeds[:batch_count] + [-1] * (batch_count - len(seeds))
|
||||
|
||||
if repeatable:
|
||||
# set seed for the rng based on what we have so far
|
||||
saved_random_state = random_getstate()
|
||||
if all(seed < 0 for seed in seeds):
|
||||
seeds[0] = sanitize_seed(seeds[0])
|
||||
seed_random(str(seeds))
|
||||
|
||||
# set seed for the rng based on what we have so far
|
||||
saved_random_state = random_getstate()
|
||||
seed_random(str([n for n in seeds if n > -1]))
|
||||
|
||||
# generate any seeds that are unspecified
|
||||
seeds = [sanitize_seed(seed) for seed in seeds]
|
||||
@@ -822,6 +861,8 @@ def clear_all():
|
||||
elif os.name == "unix":
|
||||
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
|
||||
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
|
||||
if args.local_tank_cache != "":
|
||||
shutil.rmtree(args.local_tank_cache)
|
||||
|
||||
|
||||
def get_generated_imgs_path() -> Path:
|
||||
@@ -867,6 +908,13 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
pngInfo = PngImagePlugin.PngInfo()
|
||||
|
||||
if args.write_metadata_to_png:
|
||||
# Using a conditional expression caused problems, so setting a new
|
||||
# variable for now.
|
||||
if args.use_hiresfix:
|
||||
png_size_text = f"{args.hiresfix_width}x{args.hiresfix_height}"
|
||||
else:
|
||||
png_size_text = f"{args.width}x{args.height}"
|
||||
|
||||
pngInfo.add_text(
|
||||
"parameters",
|
||||
f"{args.prompts[0]}"
|
||||
@@ -875,7 +923,7 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
f"Sampler: {args.scheduler}, "
|
||||
f"CFG scale: {args.guidance_scale}, "
|
||||
f"Seed: {img_seed},"
|
||||
f"Size: {args.width}x{args.height}, "
|
||||
f"Size: {png_size_text}, "
|
||||
f"Model: {img_model}, "
|
||||
f"VAE: {img_vae}, "
|
||||
f"LoRA: {img_lora}",
|
||||
@@ -902,8 +950,10 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
"CFG_SCALE": args.guidance_scale,
|
||||
"PRECISION": args.precision,
|
||||
"STEPS": args.steps,
|
||||
"HEIGHT": args.height,
|
||||
"WIDTH": args.width,
|
||||
"HEIGHT": args.height
|
||||
if not args.use_hiresfix
|
||||
else args.hiresfix_height,
|
||||
"WIDTH": args.width if not args.use_hiresfix else args.hiresfix_width,
|
||||
"MAX_LENGTH": args.max_length,
|
||||
"OUTPUT": out_img_path,
|
||||
"VAE": img_vae,
|
||||
@@ -941,6 +991,10 @@ def get_generation_text_info(seeds, device):
|
||||
)
|
||||
text_output += (
|
||||
f"\nsize={args.height}x{args.width}, "
|
||||
if not args.use_hiresfix
|
||||
else f"\nsize={args.hiresfix_height}x{args.hiresfix_width}, "
|
||||
)
|
||||
text_output += (
|
||||
f"batch_count={args.batch_count}, "
|
||||
f"batch_size={args.batch_size}, "
|
||||
f"max_length={args.max_length}"
|
||||
|
||||
@@ -19,6 +19,9 @@ a = Analysis(
|
||||
win_private_assemblies=False,
|
||||
cipher=block_cipher,
|
||||
noarchive=False,
|
||||
module_collection_mode={
|
||||
'gradio': 'py', # Collect gradio package as source .py files
|
||||
},
|
||||
)
|
||||
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||
|
||||
|
||||
1
apps/stable_diffusion/web/api/__init__.py
Normal file
1
apps/stable_diffusion/web/api/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from apps.stable_diffusion.web.api.sdapi_v1 import sdapi
|
||||
579
apps/stable_diffusion/web/api/sdapi_v1.py
Normal file
579
apps/stable_diffusion/web/api/sdapi_v1.py
Normal file
@@ -0,0 +1,579 @@
|
||||
import os
|
||||
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel, Field, conlist, model_validator
|
||||
|
||||
from apps.stable_diffusion.web.api.utils import (
|
||||
frozen_args,
|
||||
sampler_aliases,
|
||||
encode_pil_to_base64,
|
||||
decode_base64_to_image,
|
||||
get_model_from_request,
|
||||
get_scheduler_from_request,
|
||||
get_lora_params,
|
||||
get_device,
|
||||
GenerationInputData,
|
||||
GenerationResponseData,
|
||||
)
|
||||
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_files,
|
||||
get_custom_model_pathfile,
|
||||
predefined_models,
|
||||
predefined_paint_models,
|
||||
predefined_upscaler_models,
|
||||
scheduler_list,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.txt2img_ui import txt2img_inf
|
||||
from apps.stable_diffusion.web.ui.img2img_ui import img2img_inf
|
||||
from apps.stable_diffusion.web.ui.inpaint_ui import inpaint_inf
|
||||
from apps.stable_diffusion.web.ui.outpaint_ui import outpaint_inf
|
||||
from apps.stable_diffusion.web.ui.upscaler_ui import upscaler_inf
|
||||
|
||||
sdapi = FastAPI()
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/sd-models (lists available models)
|
||||
class AppParam(str, Enum):
|
||||
txt2img = "txt2img"
|
||||
img2img = "img2img"
|
||||
inpaint = "inpaint"
|
||||
outpaint = "outpaint"
|
||||
upscaler = "upscaler"
|
||||
|
||||
|
||||
@sdapi.get(
|
||||
"/v1/sd-models",
|
||||
summary="lists available models",
|
||||
description=(
|
||||
"This is all the models that this server currently knows about.\n "
|
||||
"Models listed may still have a compilation and build pending that "
|
||||
"will be triggered the first time they are used."
|
||||
),
|
||||
)
|
||||
def sd_models_api(app: AppParam = frozen_args.app):
|
||||
match app:
|
||||
case "inpaint" | "outpaint":
|
||||
checkpoint_type = "inpainting"
|
||||
predefined = predefined_paint_models
|
||||
case "upscaler":
|
||||
checkpoint_type = "upscaler"
|
||||
predefined = predefined_upscaler_models
|
||||
case _:
|
||||
checkpoint_type = ""
|
||||
predefined = predefined_models
|
||||
|
||||
return [
|
||||
{
|
||||
"title": model_file,
|
||||
"model_name": model_file,
|
||||
"hash": None,
|
||||
"sha256": None,
|
||||
"filename": get_custom_model_pathfile(model_file),
|
||||
"config": None,
|
||||
}
|
||||
for model_file in get_custom_model_files(
|
||||
custom_checkpoint_type=checkpoint_type
|
||||
)
|
||||
] + [
|
||||
{
|
||||
"title": model,
|
||||
"model_name": model,
|
||||
"hash": None,
|
||||
"sha256": None,
|
||||
"filename": None,
|
||||
"config": None,
|
||||
}
|
||||
for model in predefined
|
||||
]
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/samplers (lists schedulers)
|
||||
@sdapi.get(
|
||||
"/v1/samplers",
|
||||
summary="lists available schedulers/samplers",
|
||||
description=(
|
||||
"These are all the Schedulers defined and available. Not "
|
||||
"every scheduler is compatible with all apis. Aliases are "
|
||||
"equivalent samplers in A1111 if they are known."
|
||||
),
|
||||
)
|
||||
def sd_samplers_api():
|
||||
reverse_sampler_aliases = defaultdict(list)
|
||||
for key, value in sampler_aliases.items():
|
||||
reverse_sampler_aliases[value].append(key)
|
||||
|
||||
return (
|
||||
{
|
||||
"name": scheduler,
|
||||
"aliases": reverse_sampler_aliases.get(scheduler, []),
|
||||
"options": {},
|
||||
}
|
||||
for scheduler in scheduler_list
|
||||
)
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/options (lists application level options)
|
||||
@sdapi.get(
|
||||
"/v1/options",
|
||||
summary="lists current settings of application level options",
|
||||
description=(
|
||||
"A subset of the command line arguments set at startup renamed "
|
||||
"to correspond to the A1111 naming. Only a small subset of A1111 "
|
||||
"options are returned."
|
||||
),
|
||||
)
|
||||
def options_api():
|
||||
# This is mostly just enough to support what Koboldcpp wants, with a
|
||||
# few other things that seemed obvious
|
||||
return {
|
||||
"samples_save": True,
|
||||
"samples_format": frozen_args.output_img_format,
|
||||
"sd_model_checkpoint": os.path.basename(frozen_args.ckpt_loc)
|
||||
if frozen_args.ckpt_loc
|
||||
else frozen_args.hf_model_id,
|
||||
"sd_lora": frozen_args.use_lora,
|
||||
"sd_vae": frozen_args.custom_vae or "Automatic",
|
||||
"enable_pnginfo": frozen_args.write_metadata_to_png,
|
||||
}
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/cmd-flags (lists command line argument settings)
|
||||
@sdapi.get(
|
||||
"/v1/cmd-flags",
|
||||
summary="lists the command line arguments value that were set on startup.",
|
||||
)
|
||||
def cmd_flags_api():
|
||||
return vars(frozen_args)
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/txt2img (Text to image)
|
||||
class ModelOverrideSettings(BaseModel):
|
||||
sd_model_checkpoint: str = get_model_from_request(
|
||||
fallback_model="stabilityai/stable-diffusion-2-1-base"
|
||||
)
|
||||
|
||||
|
||||
class Txt2ImgInputData(GenerationInputData):
|
||||
enable_hr: bool = frozen_args.use_hiresfix
|
||||
hr_resize_y: int = Field(
|
||||
default=frozen_args.hiresfix_height, ge=128, le=768, multiple_of=8
|
||||
)
|
||||
hr_resize_x: int = Field(
|
||||
default=frozen_args.hiresfix_width, ge=128, le=768, multiple_of=8
|
||||
)
|
||||
override_settings: ModelOverrideSettings = None
|
||||
|
||||
|
||||
@sdapi.post(
|
||||
"/v1/txt2img",
|
||||
summary="Does text to image generation",
|
||||
response_model=GenerationResponseData,
|
||||
)
|
||||
def txt2img_api(InputData: Txt2ImgInputData):
|
||||
model_id = get_model_from_request(
|
||||
InputData,
|
||||
fallback_model="stabilityai/stable-diffusion-2-1-base",
|
||||
)
|
||||
scheduler = get_scheduler_from_request(
|
||||
InputData, "txt2img_hires" if InputData.enable_hr else "txt2img"
|
||||
)
|
||||
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
|
||||
|
||||
print(
|
||||
f"Prompt: {InputData.prompt}, "
|
||||
f"Negative Prompt: {InputData.negative_prompt}, "
|
||||
f"Seed: {InputData.seed},"
|
||||
f"Model: {model_id}, "
|
||||
f"Scheduler: {scheduler}. "
|
||||
)
|
||||
|
||||
res = txt2img_inf(
|
||||
InputData.prompt,
|
||||
InputData.negative_prompt,
|
||||
InputData.height,
|
||||
InputData.width,
|
||||
InputData.steps,
|
||||
InputData.cfg_scale,
|
||||
InputData.seed,
|
||||
batch_count=InputData.n_iter,
|
||||
batch_size=1,
|
||||
scheduler=scheduler,
|
||||
model_id=model_id,
|
||||
custom_vae=frozen_args.custom_vae or "None",
|
||||
precision="fp16",
|
||||
device=get_device(frozen_args.device),
|
||||
max_length=frozen_args.max_length,
|
||||
save_metadata_to_json=frozen_args.save_metadata_to_json,
|
||||
save_metadata_to_png=frozen_args.write_metadata_to_png,
|
||||
lora_weights=lora_weights,
|
||||
lora_hf_id=lora_hf_id,
|
||||
ondemand=frozen_args.ondemand,
|
||||
repeatable_seeds=False,
|
||||
use_hiresfix=InputData.enable_hr,
|
||||
hiresfix_height=InputData.hr_resize_y,
|
||||
hiresfix_width=InputData.hr_resize_x,
|
||||
hiresfix_strength=frozen_args.hiresfix_strength,
|
||||
resample_type=frozen_args.resample_type,
|
||||
)
|
||||
|
||||
# Since we're not streaming we just want the last generator result
|
||||
for items_so_far in res:
|
||||
items = items_so_far
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(items[0]),
|
||||
"parameters": {},
|
||||
"info": items[1],
|
||||
}
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/img2img (Image to image)
|
||||
class StencilParam(str, Enum):
|
||||
canny = "canny"
|
||||
openpose = "openpose"
|
||||
scribble = "scribble"
|
||||
zoedepth = "zoedepth"
|
||||
|
||||
|
||||
class Img2ImgInputData(GenerationInputData):
|
||||
init_images: conlist(str, min_length=1, max_length=2)
|
||||
denoising_strength: float = frozen_args.strength
|
||||
use_stencil: StencilParam = frozen_args.use_stencil
|
||||
override_settings: ModelOverrideSettings = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_image_supplied_for_scribble_stencil(self) -> "Img2ImgInputData":
|
||||
if (
|
||||
self.use_stencil == StencilParam.scribble
|
||||
and len(self.init_images) < 2
|
||||
):
|
||||
raise ValueError(
|
||||
"a second image must be supplied for the controlnet:scribble stencil"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
@sdapi.post(
|
||||
"/v1/img2img",
|
||||
summary="Does image to image generation",
|
||||
response_model=GenerationResponseData,
|
||||
)
|
||||
def img2img_api(
|
||||
InputData: Img2ImgInputData,
|
||||
):
|
||||
model_id = get_model_from_request(
|
||||
InputData,
|
||||
fallback_model="stabilityai/stable-diffusion-2-1-base",
|
||||
)
|
||||
scheduler = get_scheduler_from_request(InputData, "img2img")
|
||||
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
|
||||
|
||||
init_image = decode_base64_to_image(InputData.init_images[0])
|
||||
mask_image = (
|
||||
decode_base64_to_image(InputData.init_images[1])
|
||||
if len(InputData.init_images) > 1
|
||||
else None
|
||||
)
|
||||
|
||||
print(
|
||||
f"Prompt: {InputData.prompt}, "
|
||||
f"Negative Prompt: {InputData.negative_prompt}, "
|
||||
f"Seed: {InputData.seed}, "
|
||||
f"Model: {model_id}, "
|
||||
f"Scheduler: {scheduler}."
|
||||
)
|
||||
|
||||
res = img2img_inf(
|
||||
InputData.prompt,
|
||||
InputData.negative_prompt,
|
||||
{"image": init_image, "mask": mask_image},
|
||||
InputData.height,
|
||||
InputData.width,
|
||||
InputData.steps,
|
||||
InputData.denoising_strength,
|
||||
InputData.cfg_scale,
|
||||
InputData.seed,
|
||||
batch_count=InputData.n_iter,
|
||||
batch_size=1,
|
||||
scheduler=scheduler,
|
||||
model_id=model_id,
|
||||
custom_vae=frozen_args.custom_vae or "None",
|
||||
precision="fp16",
|
||||
device=get_device(frozen_args.device),
|
||||
max_length=frozen_args.max_length,
|
||||
use_stencil=InputData.use_stencil,
|
||||
save_metadata_to_json=frozen_args.save_metadata_to_json,
|
||||
save_metadata_to_png=frozen_args.write_metadata_to_png,
|
||||
lora_weights=lora_weights,
|
||||
lora_hf_id=lora_hf_id,
|
||||
ondemand=frozen_args.ondemand,
|
||||
repeatable_seeds=False,
|
||||
resample_type=frozen_args.resample_type,
|
||||
)
|
||||
|
||||
# Since we're not streaming we just want the last generator result
|
||||
for items_so_far in res:
|
||||
items = items_so_far
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(items[0]),
|
||||
"parameters": {},
|
||||
"info": items[1],
|
||||
}
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/inpaint (Inpainting)
|
||||
class PaintModelOverideSettings(BaseModel):
|
||||
sd_model_checkpoint: str = get_model_from_request(
|
||||
checkpoint_type="inpainting",
|
||||
fallback_model="stabilityai/stable-diffusion-2-inpainting",
|
||||
)
|
||||
|
||||
|
||||
class InpaintInputData(GenerationInputData):
|
||||
image: str = Field(description="Base64 encoded input image")
|
||||
mask: str = Field(description="Base64 encoded mask image")
|
||||
is_full_res: bool = False # Is this setting backwards in the UI?
|
||||
full_res_padding: int = Field(default=32, ge=0, le=256, multiple_of=4)
|
||||
denoising_strength: float = frozen_args.strength
|
||||
use_stencil: StencilParam = frozen_args.use_stencil
|
||||
override_settings: PaintModelOverideSettings = None
|
||||
|
||||
|
||||
@sdapi.post(
|
||||
"/v1/inpaint",
|
||||
summary="Does inpainting generation on an image",
|
||||
response_model=GenerationResponseData,
|
||||
)
|
||||
def inpaint_api(
|
||||
InputData: InpaintInputData,
|
||||
):
|
||||
model_id = get_model_from_request(
|
||||
InputData,
|
||||
checkpoint_type="inpainting",
|
||||
fallback_model="stabilityai/stable-diffusion-2-inpainting",
|
||||
)
|
||||
scheduler = get_scheduler_from_request(InputData, "inpaint")
|
||||
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
|
||||
|
||||
init_image = decode_base64_to_image(InputData.image)
|
||||
mask = decode_base64_to_image(InputData.mask)
|
||||
|
||||
print(
|
||||
f"Prompt: {InputData.prompt}, "
|
||||
f'Negative Prompt: {InputData.negative_prompt}", '
|
||||
f'Seed: {InputData.seed}", '
|
||||
f"Model: {model_id}, "
|
||||
f"Scheduler: {scheduler}."
|
||||
)
|
||||
|
||||
res = inpaint_inf(
|
||||
InputData.prompt,
|
||||
InputData.negative_prompt,
|
||||
{"image": init_image, "mask": mask},
|
||||
InputData.height,
|
||||
InputData.width,
|
||||
InputData.is_full_res,
|
||||
InputData.full_res_padding,
|
||||
InputData.steps,
|
||||
InputData.cfg_scale,
|
||||
InputData.seed,
|
||||
batch_count=InputData.n_iter,
|
||||
batch_size=1,
|
||||
scheduler=scheduler,
|
||||
model_id=model_id,
|
||||
custom_vae=frozen_args.custom_vae or "None",
|
||||
precision="fp16",
|
||||
device=get_device(frozen_args.device),
|
||||
max_length=frozen_args.max_length,
|
||||
save_metadata_to_json=frozen_args.save_metadata_to_json,
|
||||
save_metadata_to_png=frozen_args.write_metadata_to_png,
|
||||
lora_weights=lora_weights,
|
||||
lora_hf_id=lora_hf_id,
|
||||
ondemand=frozen_args.ondemand,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
# Since we're not streaming we just want the last generator result
|
||||
for items_so_far in res:
|
||||
items = items_so_far
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(items[0]),
|
||||
"parameters": {},
|
||||
"info": items[1],
|
||||
}
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/outpaint (Outpainting)
|
||||
class DirectionParam(str, Enum):
|
||||
left = "left"
|
||||
right = "right"
|
||||
up = "up"
|
||||
down = "down"
|
||||
|
||||
|
||||
class OutpaintInputData(GenerationInputData):
|
||||
init_images: list[str]
|
||||
pixels: int = Field(
|
||||
default=frozen_args.pixels, ge=8, le=256, multiple_of=8
|
||||
)
|
||||
mask_blur: int = Field(default=frozen_args.mask_blur, ge=0, le=64)
|
||||
directions: set[DirectionParam] = [
|
||||
direction
|
||||
for direction in ["left", "right", "up", "down"]
|
||||
if vars(frozen_args)[direction]
|
||||
]
|
||||
noise_q: float = frozen_args.noise_q
|
||||
color_variation: float = frozen_args.color_variation
|
||||
override_settings: PaintModelOverideSettings = None
|
||||
|
||||
|
||||
@sdapi.post(
|
||||
"/v1/outpaint",
|
||||
summary="Does outpainting generation on an image",
|
||||
response_model=GenerationResponseData,
|
||||
)
|
||||
def outpaint_api(
|
||||
InputData: OutpaintInputData,
|
||||
):
|
||||
model_id = get_model_from_request(
|
||||
InputData,
|
||||
checkpoint_type="inpainting",
|
||||
fallback_model="stabilityai/stable-diffusion-2-inpainting",
|
||||
)
|
||||
scheduler = get_scheduler_from_request(InputData, "outpaint")
|
||||
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
|
||||
|
||||
init_image = decode_base64_to_image(InputData.init_images[0])
|
||||
|
||||
print(
|
||||
f"Prompt: {InputData.prompt}, "
|
||||
f"Negative Prompt: {InputData.negative_prompt}, "
|
||||
f"Seed: {InputData.seed}, "
|
||||
f"Model: {model_id}, "
|
||||
f"Scheduler: {scheduler}."
|
||||
)
|
||||
|
||||
res = outpaint_inf(
|
||||
InputData.prompt,
|
||||
InputData.negative_prompt,
|
||||
init_image,
|
||||
InputData.pixels,
|
||||
InputData.mask_blur,
|
||||
InputData.directions,
|
||||
InputData.noise_q,
|
||||
InputData.color_variation,
|
||||
InputData.height,
|
||||
InputData.width,
|
||||
InputData.steps,
|
||||
InputData.cfg_scale,
|
||||
InputData.seed,
|
||||
batch_count=InputData.n_iter,
|
||||
batch_size=1,
|
||||
scheduler=scheduler,
|
||||
model_id=model_id,
|
||||
custom_vae=frozen_args.custom_vae or "None",
|
||||
precision="fp16",
|
||||
device=get_device(frozen_args.device),
|
||||
max_length=frozen_args.max_length,
|
||||
save_metadata_to_json=frozen_args.save_metadata_to_json,
|
||||
save_metadata_to_png=frozen_args.write_metadata_to_png,
|
||||
lora_weights=lora_weights,
|
||||
lora_hf_id=lora_hf_id,
|
||||
ondemand=frozen_args.ondemand,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
# Since we're not streaming we just want the last generator result
|
||||
for items_so_far in res:
|
||||
items = items_so_far
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(items[0]),
|
||||
"parameters": {},
|
||||
"info": items[1],
|
||||
}
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/upscaler (Upscaling)
|
||||
class UpscalerModelOverideSettings(BaseModel):
|
||||
sd_model_checkpoint: str = get_model_from_request(
|
||||
checkpoint_type="upscaler",
|
||||
fallback_model="stabilityai/stable-diffusion-x4-upscaler",
|
||||
)
|
||||
|
||||
|
||||
class UpscalerInputData(GenerationInputData):
|
||||
init_images: list[str] = Field(
|
||||
description="Base64 encoded image to upscale"
|
||||
)
|
||||
noise_level: int = frozen_args.noise_level
|
||||
override_settings: UpscalerModelOverideSettings = None
|
||||
|
||||
|
||||
@sdapi.post(
|
||||
"/v1/upscaler",
|
||||
summary="Does image upscaling",
|
||||
response_model=GenerationResponseData,
|
||||
)
|
||||
def upscaler_api(
|
||||
InputData: UpscalerInputData,
|
||||
):
|
||||
model_id = get_model_from_request(
|
||||
InputData,
|
||||
checkpoint_type="upscaler",
|
||||
fallback_model="stabilityai/stable-diffusion-x4-upscaler",
|
||||
)
|
||||
scheduler = get_scheduler_from_request(InputData, "upscaler")
|
||||
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
|
||||
|
||||
init_image = decode_base64_to_image(InputData.init_images[0])
|
||||
|
||||
print(
|
||||
f"Prompt: {InputData.prompt}, "
|
||||
f"Negative Prompt: {InputData.negative_prompt}, "
|
||||
f"Seed: {InputData.seed}, "
|
||||
f"Model: {model_id}, "
|
||||
f"Scheduler: {scheduler}."
|
||||
)
|
||||
|
||||
res = upscaler_inf(
|
||||
InputData.prompt,
|
||||
InputData.negative_prompt,
|
||||
init_image,
|
||||
InputData.height,
|
||||
InputData.width,
|
||||
InputData.steps,
|
||||
InputData.noise_level,
|
||||
InputData.cfg_scale,
|
||||
InputData.seed,
|
||||
batch_count=InputData.n_iter,
|
||||
batch_size=1,
|
||||
scheduler=scheduler,
|
||||
model_id=model_id,
|
||||
custom_vae=frozen_args.custom_vae or "None",
|
||||
precision="fp16",
|
||||
device=get_device(frozen_args.device),
|
||||
max_length=frozen_args.max_length,
|
||||
save_metadata_to_json=frozen_args.save_metadata_to_json,
|
||||
save_metadata_to_png=frozen_args.write_metadata_to_png,
|
||||
lora_weights=lora_weights,
|
||||
lora_hf_id=lora_hf_id,
|
||||
ondemand=frozen_args.ondemand,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
# Since we're not streaming we just want the last generator result
|
||||
for items_so_far in res:
|
||||
items = items_so_far
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(items[0]),
|
||||
"parameters": {},
|
||||
"info": items[1],
|
||||
}
|
||||
211
apps/stable_diffusion/web/api/utils.py
Normal file
211
apps/stable_diffusion/web/api/utils.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import base64
|
||||
import pickle
|
||||
|
||||
from argparse import Namespace
|
||||
from fastapi.exceptions import HTTPException
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from apps.stable_diffusion.src import args
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
get_custom_model_files,
|
||||
predefined_models,
|
||||
predefined_paint_models,
|
||||
predefined_upscaler_models,
|
||||
scheduler_list,
|
||||
scheduler_list_cpu_only,
|
||||
)
|
||||
|
||||
|
||||
# Probably overly cautious, but try to ensure we only use the starting
|
||||
# args in each api call, as the code does `args.<whatever> = <changed_value>`
|
||||
# in lots of places and in testing, it seemed to me, these changes leaked
|
||||
# into subsequent api calls.
|
||||
|
||||
# Roundtripping through pickle for deepcopy, there is probably a better way
|
||||
frozen_args = Namespace(**(pickle.loads(pickle.dumps(vars(args)))))
|
||||
|
||||
# an attempt to map some of the A1111 sampler names to scheduler names
|
||||
# https://github.com/huggingface/diffusers/issues/4167 is where the
|
||||
# (not so obvious) ones come from
|
||||
sampler_aliases = {
|
||||
# a1111/onnx (these point to diffusers classes in A1111)
|
||||
"pndm": "PNDM",
|
||||
"heun": "HeunDiscrete",
|
||||
"ddim": "DDIM",
|
||||
"ddpm": "DDPM",
|
||||
"euler": "EulerDiscrete",
|
||||
"euler-ancestral": "EulerAncestralDiscrete",
|
||||
"dpm": "DPMSolverMultistep",
|
||||
# a1111/k_diffusion (the obvious ones)
|
||||
"Euler a": "EulerAncestralDiscrete",
|
||||
"Euler": "EulerDiscrete",
|
||||
"LMS": "LMSDiscrete",
|
||||
"Heun": "HeunDiscrete",
|
||||
# a1111/k_diffusion (not so obvious)
|
||||
"DPM++ 2M": "DPMSolverMultistep",
|
||||
"DPM++ 2M Karras": "DPMSolverMultistepKarras",
|
||||
"DPM++ 2M SDE": "DPMSolverMultistep++",
|
||||
"DPM++ 2M SDE Karras": "DPMSolverMultistepKarras++",
|
||||
"DPM2": "KDPM2Discrete",
|
||||
"DPM2 a": "KDPM2AncestralDiscrete",
|
||||
}
|
||||
|
||||
allowed_schedulers = {
|
||||
"txt2img": {
|
||||
"schedulers": scheduler_list,
|
||||
"fallback": "SharkEulerDiscrete",
|
||||
},
|
||||
"txt2img_hires": {
|
||||
"schedulers": scheduler_list_cpu_only,
|
||||
"fallback": "DEISMultistep",
|
||||
},
|
||||
"img2img": {
|
||||
"schedulers": scheduler_list_cpu_only,
|
||||
"fallback": "EulerDiscrete",
|
||||
},
|
||||
"inpaint": {
|
||||
"schedulers": scheduler_list_cpu_only,
|
||||
"fallback": "DDIM",
|
||||
},
|
||||
"outpaint": {
|
||||
"schedulers": scheduler_list_cpu_only,
|
||||
"fallback": "DDIM",
|
||||
},
|
||||
"upscaler": {
|
||||
"schedulers": scheduler_list_cpu_only,
|
||||
"fallback": "DDIM",
|
||||
},
|
||||
}
|
||||
|
||||
# base pydantic model for sd generation apis
|
||||
|
||||
|
||||
class GenerationInputData(BaseModel):
|
||||
prompt: str = ""
|
||||
negative_prompt: str = ""
|
||||
hf_model_id: str | None = None
|
||||
height: int = Field(
|
||||
default=frozen_args.height, ge=128, le=768, multiple_of=8
|
||||
)
|
||||
width: int = Field(
|
||||
default=frozen_args.width, ge=128, le=768, multiple_of=8
|
||||
)
|
||||
sampler_name: str = frozen_args.scheduler
|
||||
cfg_scale: float = Field(default=frozen_args.guidance_scale, ge=1)
|
||||
steps: int = Field(default=frozen_args.steps, ge=1, le=100)
|
||||
seed: int = frozen_args.seed
|
||||
n_iter: int = Field(default=frozen_args.batch_count)
|
||||
|
||||
|
||||
class GenerationResponseData(BaseModel):
|
||||
images: list[str] = Field(description="Generated images, Base64 encoded")
|
||||
properties: dict = {}
|
||||
info: str
|
||||
|
||||
|
||||
# image encoding/decoding
|
||||
|
||||
|
||||
def encode_pil_to_base64(images: list[Image.Image]):
|
||||
encoded_imgs = []
|
||||
for image in images:
|
||||
with BytesIO() as output_bytes:
|
||||
if frozen_args.output_img_format.lower() == "png":
|
||||
image.save(output_bytes, format="PNG")
|
||||
|
||||
elif frozen_args.output_img_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Invalid image format"
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
encoded_imgs.append(base64.b64encode(bytes_data))
|
||||
return encoded_imgs
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding: str):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
|
||||
try:
|
||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
return image
|
||||
except Exception as err:
|
||||
print(err)
|
||||
raise HTTPException(status_code=400, detail="Invalid encoded image")
|
||||
|
||||
|
||||
# get valid sd models/vaes/schedulers etc.
|
||||
|
||||
|
||||
def get_predefined_models(custom_checkpoint_type: str):
|
||||
match custom_checkpoint_type:
|
||||
case "inpainting":
|
||||
return predefined_paint_models
|
||||
case "upscaler":
|
||||
return predefined_upscaler_models
|
||||
case _:
|
||||
return predefined_models
|
||||
|
||||
|
||||
def get_model_from_request(
|
||||
request_data=None,
|
||||
checkpoint_type: str = "",
|
||||
fallback_model: str = "",
|
||||
):
|
||||
model = None
|
||||
if request_data:
|
||||
if request_data.hf_model_id:
|
||||
model = request_data.hf_model_id
|
||||
elif request_data.override_settings:
|
||||
model = request_data.override_settings.sd_model_checkpoint
|
||||
|
||||
# if the request didn't specify a model try the command line args
|
||||
result = model or frozen_args.ckpt_loc or frozen_args.hf_model_id
|
||||
|
||||
# make sure whatever we have is a valid model for the checkpoint type
|
||||
if result in get_custom_model_files(
|
||||
custom_checkpoint_type=checkpoint_type
|
||||
) + get_predefined_models(checkpoint_type):
|
||||
return result
|
||||
# if not return what was specified as the fallback
|
||||
else:
|
||||
return fallback_model
|
||||
|
||||
|
||||
def get_scheduler_from_request(
|
||||
request_data: GenerationInputData, operation: str
|
||||
):
|
||||
allowed = allowed_schedulers[operation]
|
||||
|
||||
requested = request_data.sampler_name
|
||||
requested = sampler_aliases.get(requested, requested)
|
||||
|
||||
return (
|
||||
requested
|
||||
if requested in allowed["schedulers"]
|
||||
else allowed["fallback"]
|
||||
)
|
||||
|
||||
|
||||
def get_lora_params(use_lora: str):
|
||||
# TODO: since the inference functions in the webui, which we are
|
||||
# still calling into for the api, jam these back together again before
|
||||
# handing them off to the pipeline, we should remove this nonsense
|
||||
# and unify their selection in the UI and command line args proper
|
||||
if use_lora in get_custom_model_files("lora"):
|
||||
return (use_lora, "")
|
||||
|
||||
return ("None", use_lora)
|
||||
|
||||
|
||||
def get_device(device_str: str):
|
||||
# first substring match in the list available devices, with first
|
||||
# device when none are matched
|
||||
return next(
|
||||
(device for device in available_devices if device_str in device),
|
||||
available_devices[0],
|
||||
)
|
||||
@@ -1,6 +1,8 @@
|
||||
from multiprocessing import Process, freeze_support
|
||||
from multiprocessing import freeze_support
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import apps.stable_diffusion.web.utils.app as app
|
||||
|
||||
if sys.platform == "darwin":
|
||||
# import before IREE to avoid torch-MLIR library issues
|
||||
@@ -20,78 +22,71 @@ if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
|
||||
def launch_app(address):
|
||||
from tkinter import Tk
|
||||
import webview
|
||||
|
||||
window = Tk()
|
||||
|
||||
# get screen width and height of display and make it more reasonably
|
||||
# sized as we aren't making it full-screen or maximized
|
||||
width = int(window.winfo_screenwidth() * 0.81)
|
||||
height = int(window.winfo_screenheight() * 0.91)
|
||||
webview.create_window(
|
||||
"SHARK AI Studio",
|
||||
url=address,
|
||||
width=width,
|
||||
height=height,
|
||||
text_select=True,
|
||||
)
|
||||
webview.start(private_mode=False, storage_path=os.getcwd())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.debug:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
# required to do multiprocessing in a pyinstaller freeze
|
||||
freeze_support()
|
||||
if args.api or "api" in args.ui.split(","):
|
||||
from apps.stable_diffusion.web.ui import (
|
||||
txt2img_api,
|
||||
img2img_api,
|
||||
upscaler_api,
|
||||
inpaint_api,
|
||||
outpaint_api,
|
||||
llm_chat_api,
|
||||
)
|
||||
from apps.stable_diffusion.web.api import sdapi
|
||||
|
||||
from fastapi import FastAPI, APIRouter
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import uvicorn
|
||||
|
||||
# init global sd pipeline and config
|
||||
global_obj._init()
|
||||
|
||||
app = FastAPI()
|
||||
app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
|
||||
api = FastAPI()
|
||||
api.mount("/sdapi/", sdapi)
|
||||
|
||||
# chat APIs needed for compatibility with multiple extensions using OpenAI API
|
||||
app.add_api_route(
|
||||
api.add_api_route(
|
||||
"/v1/chat/completions", llm_chat_api, methods=["post"]
|
||||
)
|
||||
app.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
|
||||
app.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
|
||||
app.add_api_route("/completions", llm_chat_api, methods=["post"])
|
||||
app.add_api_route(
|
||||
api.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
|
||||
api.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
|
||||
api.add_api_route("/completions", llm_chat_api, methods=["post"])
|
||||
api.add_api_route(
|
||||
"/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
|
||||
)
|
||||
app.include_router(APIRouter())
|
||||
uvicorn.run(app, host="0.0.0.0", port=args.server_port)
|
||||
api.include_router(APIRouter())
|
||||
|
||||
# deal with CORS requests if CORS accept origins are set
|
||||
if args.api_accept_origin:
|
||||
print(
|
||||
f"API Configured for CORS. Accepting origins: { args.api_accept_origin }"
|
||||
)
|
||||
api.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=args.api_accept_origin,
|
||||
allow_methods=["GET", "POST"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
else:
|
||||
print("API not configured for CORS")
|
||||
|
||||
uvicorn.run(api, host="0.0.0.0", port=args.server_port)
|
||||
sys.exit(0)
|
||||
|
||||
# Setup to use shark_tmp for gradio's temporary image files and clear any
|
||||
# existing temporary images there if they exist. Then we can import gradio.
|
||||
# It has to be in this order or gradio ignores what we've set up.
|
||||
from apps.stable_diffusion.web.utils.gradio_configs import (
|
||||
config_gradio_tmp_imgs_folder,
|
||||
from apps.stable_diffusion.web.utils.tmp_configs import (
|
||||
config_tmp,
|
||||
)
|
||||
|
||||
config_gradio_tmp_imgs_folder()
|
||||
config_tmp()
|
||||
import gradio as gr
|
||||
|
||||
# Create custom models folders if they don't exist
|
||||
from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
create_custom_models_folders,
|
||||
nodicon_loc,
|
||||
)
|
||||
|
||||
create_custom_models_folders()
|
||||
|
||||
@@ -102,12 +97,9 @@ if __name__ == "__main__":
|
||||
)
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
dark_theme = resource_path("ui/css/sd_dark_theme.css")
|
||||
|
||||
from apps.stable_diffusion.web.ui import (
|
||||
txt2img_web,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
txt2img_gallery,
|
||||
txt2img_png_info_img,
|
||||
txt2img_status,
|
||||
@@ -115,11 +107,20 @@ if __name__ == "__main__":
|
||||
txt2img_sendto_inpaint,
|
||||
txt2img_sendto_outpaint,
|
||||
txt2img_sendto_upscaler,
|
||||
# SDXL
|
||||
txt2img_sdxl_web,
|
||||
txt2img_sdxl_custom_model,
|
||||
txt2img_sdxl_gallery,
|
||||
txt2img_sdxl_png_info_img,
|
||||
txt2img_sdxl_status,
|
||||
txt2img_sdxl_sendto_img2img,
|
||||
txt2img_sdxl_sendto_inpaint,
|
||||
txt2img_sdxl_sendto_outpaint,
|
||||
txt2img_sdxl_sendto_upscaler,
|
||||
# h2ogpt_upload,
|
||||
# h2ogpt_web,
|
||||
img2img_web,
|
||||
img2img_custom_model,
|
||||
img2img_hf_model_id,
|
||||
img2img_gallery,
|
||||
img2img_init_image,
|
||||
img2img_status,
|
||||
@@ -128,7 +129,6 @@ if __name__ == "__main__":
|
||||
img2img_sendto_upscaler,
|
||||
inpaint_web,
|
||||
inpaint_custom_model,
|
||||
inpaint_hf_model_id,
|
||||
inpaint_gallery,
|
||||
inpaint_init_image,
|
||||
inpaint_status,
|
||||
@@ -137,7 +137,6 @@ if __name__ == "__main__":
|
||||
inpaint_sendto_upscaler,
|
||||
outpaint_web,
|
||||
outpaint_custom_model,
|
||||
outpaint_hf_model_id,
|
||||
outpaint_gallery,
|
||||
outpaint_init_image,
|
||||
outpaint_status,
|
||||
@@ -146,15 +145,14 @@ if __name__ == "__main__":
|
||||
outpaint_sendto_upscaler,
|
||||
upscaler_web,
|
||||
upscaler_custom_model,
|
||||
upscaler_hf_model_id,
|
||||
upscaler_gallery,
|
||||
upscaler_init_image,
|
||||
upscaler_status,
|
||||
upscaler_sendto_img2img,
|
||||
upscaler_sendto_inpaint,
|
||||
upscaler_sendto_outpaint,
|
||||
lora_train_web,
|
||||
model_web,
|
||||
# lora_train_web,
|
||||
# model_web,
|
||||
model_config_web,
|
||||
hf_models,
|
||||
modelmanager_sendto_txt2img,
|
||||
@@ -169,6 +167,7 @@ if __name__ == "__main__":
|
||||
outputgallery_watch,
|
||||
outputgallery_filename,
|
||||
outputgallery_sendto_txt2img,
|
||||
outputgallery_sendto_txt2img_sdxl,
|
||||
outputgallery_sendto_img2img,
|
||||
outputgallery_sendto_inpaint,
|
||||
outputgallery_sendto_outpaint,
|
||||
@@ -182,7 +181,7 @@ if __name__ == "__main__":
|
||||
button.click(
|
||||
lambda x: (
|
||||
x[0]["name"] if len(x) != 0 else None,
|
||||
gr.Tabs.update(selected=selectedid),
|
||||
gr.Tabs(selected=selectedid),
|
||||
),
|
||||
inputs,
|
||||
outputs,
|
||||
@@ -193,7 +192,7 @@ if __name__ == "__main__":
|
||||
lambda x: (
|
||||
"None",
|
||||
x,
|
||||
gr.Tabs.update(selected=selectedid),
|
||||
gr.Tabs(selected=selectedid),
|
||||
),
|
||||
inputs,
|
||||
outputs,
|
||||
@@ -203,14 +202,16 @@ if __name__ == "__main__":
|
||||
button.click(
|
||||
lambda x: (
|
||||
x,
|
||||
gr.Tabs.update(selected=selectedid),
|
||||
gr.Tabs(selected=selectedid),
|
||||
),
|
||||
inputs,
|
||||
outputs,
|
||||
)
|
||||
|
||||
dark_theme = resource_path("ui/css/sd_dark_theme.css")
|
||||
|
||||
with gr.Blocks(
|
||||
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
|
||||
css=dark_theme, analytics_enabled=False, title="SHARK AI Studio"
|
||||
) as sd_web:
|
||||
with gr.Tabs() as tabs:
|
||||
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
|
||||
@@ -245,24 +246,36 @@ if __name__ == "__main__":
|
||||
inpaint_status,
|
||||
outpaint_status,
|
||||
upscaler_status,
|
||||
txt2img_sdxl_status,
|
||||
]
|
||||
)
|
||||
with gr.TabItem(label="Model Manager", id=6):
|
||||
model_web.render()
|
||||
with gr.TabItem(label="LoRA Training (Experimental)", id=7):
|
||||
lora_train_web.render()
|
||||
with gr.TabItem(label="Chat Bot (Experimental)", id=8):
|
||||
# with gr.TabItem(label="Model Manager", id=6):
|
||||
# model_web.render()
|
||||
# with gr.TabItem(label="LoRA Training (Experimental)", id=7):
|
||||
# lora_train_web.render()
|
||||
with gr.TabItem(label="Chat Bot", id=8):
|
||||
stablelm_chat.render()
|
||||
with gr.TabItem(
|
||||
label="Generate Sharding Config (Experimental)", id=9
|
||||
):
|
||||
model_config_web.render()
|
||||
with gr.TabItem(label="MultiModal (Experimental)", id=10):
|
||||
minigpt4_web.render()
|
||||
# with gr.TabItem(
|
||||
# label="Generate Sharding Config (Experimental)", id=9
|
||||
# ):
|
||||
# model_config_web.render()
|
||||
# with gr.TabItem(label="MultiModal (Experimental)", id=10):
|
||||
# minigpt4_web.render()
|
||||
# with gr.TabItem(label="DocuChat Upload", id=11):
|
||||
# h2ogpt_upload.render()
|
||||
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
|
||||
# h2ogpt_web.render()
|
||||
with gr.TabItem(label="Text-to-Image (SDXL)", id=13):
|
||||
txt2img_sdxl_web.render()
|
||||
|
||||
actual_port = app.usable_port()
|
||||
if actual_port != args.server_port:
|
||||
sd_web.load(
|
||||
fn=lambda: gr.Info(
|
||||
f"Port {args.server_port} is in use by another application. "
|
||||
f"Shark is running on port {actual_port} instead."
|
||||
)
|
||||
)
|
||||
|
||||
# send to buttons
|
||||
register_button_click(
|
||||
@@ -392,46 +405,48 @@ if __name__ == "__main__":
|
||||
[outputgallery_filename],
|
||||
[upscaler_init_image, tabs],
|
||||
)
|
||||
register_outputgallery_button(
|
||||
outputgallery_sendto_txt2img_sdxl,
|
||||
0,
|
||||
[outputgallery_filename],
|
||||
[txt2img_sdxl_png_info_img, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_txt2img,
|
||||
0,
|
||||
[hf_models],
|
||||
[txt2img_custom_model, txt2img_hf_model_id, tabs],
|
||||
[txt2img_custom_model, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_img2img,
|
||||
1,
|
||||
[hf_models],
|
||||
[img2img_custom_model, img2img_hf_model_id, tabs],
|
||||
[img2img_custom_model, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_inpaint,
|
||||
2,
|
||||
[hf_models],
|
||||
[inpaint_custom_model, inpaint_hf_model_id, tabs],
|
||||
[inpaint_custom_model, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_outpaint,
|
||||
3,
|
||||
[hf_models],
|
||||
[outpaint_custom_model, outpaint_hf_model_id, tabs],
|
||||
[outpaint_custom_model, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_upscaler,
|
||||
4,
|
||||
[hf_models],
|
||||
[upscaler_custom_model, upscaler_hf_model_id, tabs],
|
||||
[upscaler_custom_model, tabs],
|
||||
)
|
||||
|
||||
sd_web.queue()
|
||||
if args.ui == "app":
|
||||
t = Process(
|
||||
target=launch_app, args=[f"http://localhost:{args.server_port}"]
|
||||
)
|
||||
t.start()
|
||||
sd_web.launch(
|
||||
share=args.share,
|
||||
inbrowser=args.ui == "web",
|
||||
inbrowser=not app.launch(actual_port),
|
||||
server_name="0.0.0.0",
|
||||
server_port=args.server_port,
|
||||
server_port=actual_port,
|
||||
favicon_path=nodicon_loc,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from apps.stable_diffusion.web.ui.txt2img_ui import (
|
||||
txt2img_inf,
|
||||
txt2img_api,
|
||||
txt2img_web,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
txt2img_gallery,
|
||||
txt2img_png_info_img,
|
||||
txt2img_status,
|
||||
@@ -12,12 +10,22 @@ from apps.stable_diffusion.web.ui.txt2img_ui import (
|
||||
txt2img_sendto_outpaint,
|
||||
txt2img_sendto_upscaler,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.txt2img_sdxl_ui import (
|
||||
txt2img_sdxl_inf,
|
||||
txt2img_sdxl_web,
|
||||
txt2img_sdxl_custom_model,
|
||||
txt2img_sdxl_gallery,
|
||||
txt2img_sdxl_status,
|
||||
txt2img_sdxl_png_info_img,
|
||||
txt2img_sdxl_sendto_img2img,
|
||||
txt2img_sdxl_sendto_inpaint,
|
||||
txt2img_sdxl_sendto_outpaint,
|
||||
txt2img_sdxl_sendto_upscaler,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.img2img_ui import (
|
||||
img2img_inf,
|
||||
img2img_api,
|
||||
img2img_web,
|
||||
img2img_custom_model,
|
||||
img2img_hf_model_id,
|
||||
img2img_gallery,
|
||||
img2img_init_image,
|
||||
img2img_status,
|
||||
@@ -27,10 +35,8 @@ from apps.stable_diffusion.web.ui.img2img_ui import (
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.inpaint_ui import (
|
||||
inpaint_inf,
|
||||
inpaint_api,
|
||||
inpaint_web,
|
||||
inpaint_custom_model,
|
||||
inpaint_hf_model_id,
|
||||
inpaint_gallery,
|
||||
inpaint_init_image,
|
||||
inpaint_status,
|
||||
@@ -40,10 +46,8 @@ from apps.stable_diffusion.web.ui.inpaint_ui import (
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.outpaint_ui import (
|
||||
outpaint_inf,
|
||||
outpaint_api,
|
||||
outpaint_web,
|
||||
outpaint_custom_model,
|
||||
outpaint_hf_model_id,
|
||||
outpaint_gallery,
|
||||
outpaint_init_image,
|
||||
outpaint_status,
|
||||
@@ -53,10 +57,8 @@ from apps.stable_diffusion.web.ui.outpaint_ui import (
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.upscaler_ui import (
|
||||
upscaler_inf,
|
||||
upscaler_api,
|
||||
upscaler_web,
|
||||
upscaler_custom_model,
|
||||
upscaler_hf_model_id,
|
||||
upscaler_gallery,
|
||||
upscaler_init_image,
|
||||
upscaler_status,
|
||||
@@ -86,6 +88,7 @@ from apps.stable_diffusion.web.ui.outputgallery_ui import (
|
||||
outputgallery_watch,
|
||||
outputgallery_filename,
|
||||
outputgallery_sendto_txt2img,
|
||||
outputgallery_sendto_txt2img_sdxl,
|
||||
outputgallery_sendto_img2img,
|
||||
outputgallery_sendto_inpaint,
|
||||
outputgallery_sendto_outpaint,
|
||||
|
||||
55
apps/stable_diffusion/web/ui/common_ui_events.py
Normal file
55
apps/stable_diffusion/web/ui/common_ui_events.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
HSLHue,
|
||||
hsl_color,
|
||||
get_lora_metadata,
|
||||
)
|
||||
|
||||
|
||||
# Answers HTML to show the most frequent tags used when a LoRA was trained,
|
||||
# taken from the metadata of its .safetensors file.
|
||||
def lora_changed(lora_file):
|
||||
# tag frequency percentage, that gets maximum amount of the staring hue
|
||||
TAG_COLOR_THRESHOLD = 0.55
|
||||
# tag frequency percentage, above which a tag is displayed
|
||||
TAG_DISPLAY_THRESHOLD = 0.65
|
||||
# template for the html used to display a tag
|
||||
TAG_HTML_TEMPLATE = '<span class="lora-tag" style="border: 1px solid {color};">{tag}</span>'
|
||||
|
||||
if lora_file == "None":
|
||||
return ["<div><i>No LoRA selected</i></div>"]
|
||||
elif not lora_file.lower().endswith(".safetensors"):
|
||||
return [
|
||||
"<div><i>Only metadata queries for .safetensors files are currently supported</i></div>"
|
||||
]
|
||||
else:
|
||||
metadata = get_lora_metadata(lora_file)
|
||||
if metadata:
|
||||
frequencies = metadata["frequencies"]
|
||||
return [
|
||||
"".join(
|
||||
[
|
||||
f'<div class="lora-model">Trained against weights in: {metadata["model"]}</div>'
|
||||
]
|
||||
+ [
|
||||
TAG_HTML_TEMPLATE.format(
|
||||
color=hsl_color(
|
||||
(tag[1] - TAG_COLOR_THRESHOLD)
|
||||
/ (1 - TAG_COLOR_THRESHOLD),
|
||||
start=HSLHue.RED,
|
||||
end=HSLHue.GREEN,
|
||||
),
|
||||
tag=tag[0],
|
||||
)
|
||||
for tag in frequencies
|
||||
if tag[1] > TAG_DISPLAY_THRESHOLD
|
||||
],
|
||||
)
|
||||
]
|
||||
elif metadata is None:
|
||||
return [
|
||||
"<div><i>This LoRA does not publish tag frequency metadata</i></div>"
|
||||
]
|
||||
else:
|
||||
return [
|
||||
"<div><i>This LoRA has empty tag frequency metadata, or we could not parse it</i></div>"
|
||||
]
|
||||
@@ -105,6 +105,18 @@ body {
|
||||
background-color: var(--background-fill-primary);
|
||||
}
|
||||
|
||||
.generating.svelte-zlszon.svelte-zlszon {
|
||||
border: none;
|
||||
}
|
||||
|
||||
.generating {
|
||||
border: none !important;
|
||||
}
|
||||
|
||||
#chatbot {
|
||||
height: 100% !important;
|
||||
}
|
||||
|
||||
/* display in full width for desktop devices */
|
||||
@media (min-width: 1536px)
|
||||
{
|
||||
@@ -246,10 +258,39 @@ footer {
|
||||
background-color: var(--block-label-background-fill);
|
||||
}
|
||||
|
||||
/* lora tag pills */
|
||||
.lora-tags {
|
||||
border: 1px solid var(--border-color-primary);
|
||||
color: var(--block-info-text-color) !important;
|
||||
padding: var(--block-padding);
|
||||
}
|
||||
|
||||
.lora-tag {
|
||||
display: inline-block;
|
||||
height: 2em;
|
||||
color: rgb(212 212 212) !important;
|
||||
margin-right: 5pt;
|
||||
margin-bottom: 5pt;
|
||||
padding: 2pt 5pt;
|
||||
border-radius: 5pt;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.lora-model {
|
||||
margin-bottom: var(--spacing-lg);
|
||||
color: var(--block-info-text-color) !important;
|
||||
line-height: var(--line-sm);
|
||||
}
|
||||
|
||||
/* output gallery tab */
|
||||
.output_parameters_dataframe table.table {
|
||||
/* works around a gradio bug that always shows scrollbars */
|
||||
overflow: clip auto;
|
||||
}
|
||||
|
||||
.output_parameters_dataframe tbody td {
|
||||
font-size: small;
|
||||
line-height: var(--line-xs)
|
||||
line-height: var(--line-xs);
|
||||
}
|
||||
|
||||
.output_icon_button {
|
||||
|
||||
@@ -212,6 +212,7 @@ with gr.Blocks(title="DocuChat") as h2ogpt_web:
|
||||
else "Only CUDA Supported for now",
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
|
||||
@@ -3,10 +3,15 @@ import torch
|
||||
import time
|
||||
import gradio as gr
|
||||
import PIL
|
||||
from math import ceil
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
from gradio.components.image_editor import (
|
||||
Brush,
|
||||
Eraser,
|
||||
EditorData,
|
||||
EditorValue,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -16,6 +21,7 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
predefined_models,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Image2ImagePipeline,
|
||||
@@ -29,6 +35,12 @@ from apps.stable_diffusion.src import (
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generation_text_info,
|
||||
resampler_list,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.stencils import (
|
||||
CannyDetector,
|
||||
OpenposeDetector,
|
||||
ZoeDetector,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
import numpy as np
|
||||
@@ -54,19 +66,21 @@ def img2img_inf(
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
max_length: int,
|
||||
use_stencil: str,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
resample_type: str,
|
||||
control_mode: str,
|
||||
stencils: list,
|
||||
images: list,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
@@ -88,34 +102,39 @@ def img2img_inf(
|
||||
args.img_path = "not none"
|
||||
args.ondemand = ondemand
|
||||
|
||||
if image_dict is None:
|
||||
for i, stencil in enumerate(stencils):
|
||||
if images[i] is None and stencil is not None:
|
||||
return
|
||||
if images[i] is not None:
|
||||
if isinstance(images[i], dict):
|
||||
images[i] = images[i]["composite"]
|
||||
images[i] = images[i].convert("RGB")
|
||||
|
||||
if image_dict is None and images[0] is None:
|
||||
return None, "An Initial Image is required"
|
||||
if use_stencil == "scribble":
|
||||
image = image_dict["mask"].convert("RGB")
|
||||
elif isinstance(image_dict, PIL.Image.Image):
|
||||
if isinstance(image_dict, PIL.Image.Image):
|
||||
image = image_dict.convert("RGB")
|
||||
else:
|
||||
elif image_dict:
|
||||
image = image_dict["image"].convert("RGB")
|
||||
else:
|
||||
# TODO: enable t2i + controlnets
|
||||
image = None
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
args.custom_vae = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, "
|
||||
"both must not be empty.",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
|
||||
# .safetensor or .chkpt on the custom model path
|
||||
if model_id in get_custom_model_files():
|
||||
args.ckpt_loc = get_custom_model_pathfile(model_id)
|
||||
# civitai download
|
||||
elif "civitai" in model_id:
|
||||
args.ckpt_loc = model_id
|
||||
# either predefined or huggingface
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
args.hf_model_id = model_id
|
||||
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
@@ -126,10 +145,11 @@ def img2img_inf(
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
use_stencil = None if use_stencil == "None" else use_stencil
|
||||
args.use_stencil = use_stencil
|
||||
if use_stencil is not None:
|
||||
args.scheduler = "DDIM"
|
||||
stencil_count = 0
|
||||
for stencil in stencils:
|
||||
if stencil is not None:
|
||||
stencil_count += 1
|
||||
if stencil_count > 0:
|
||||
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
|
||||
image, width, height = resize_stencil(image)
|
||||
elif "Shark" in args.scheduler:
|
||||
@@ -153,7 +173,7 @@ def img2img_inf(
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=use_stencil,
|
||||
stencils=stencils,
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
@@ -175,12 +195,12 @@ def img2img_inf(
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-1-base"
|
||||
else "stabilityai/stable-diffusion-1-5-base"
|
||||
)
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(args.scheduler)
|
||||
|
||||
if use_stencil is not None:
|
||||
if stencil_count > 0:
|
||||
args.use_tuned = False
|
||||
global_obj.set_sd_obj(
|
||||
StencilPipeline.from_pretrained(
|
||||
@@ -197,7 +217,7 @@ def img2img_inf(
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
use_stencil=use_stencil,
|
||||
stencils=stencils,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
ondemand=args.ondemand,
|
||||
@@ -245,7 +265,7 @@ def img2img_inf(
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
ceil(steps / strength),
|
||||
strength,
|
||||
guidance_scale,
|
||||
seeds[current_batch],
|
||||
@@ -254,7 +274,10 @@ def img2img_inf(
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
use_stencil=use_stencil,
|
||||
stencils,
|
||||
images,
|
||||
resample_type=resample_type,
|
||||
control_mode=control_mode,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(
|
||||
@@ -274,93 +297,17 @@ def img2img_inf(
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Image-to-Image", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
), stencils, images
|
||||
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
|
||||
try:
|
||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
return image
|
||||
except Exception as err:
|
||||
print(err)
|
||||
raise HTTPException(status_code=500, detail="Invalid encoded image")
|
||||
|
||||
|
||||
def encode_pil_to_base64(images):
|
||||
encoded_imgs = []
|
||||
for image in images:
|
||||
with BytesIO() as output_bytes:
|
||||
if args.output_img_format.lower() == "png":
|
||||
image.save(output_bytes, format="PNG")
|
||||
|
||||
elif args.output_img_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Invalid image format"
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
encoded_imgs.append(base64.b64encode(bytes_data))
|
||||
return encoded_imgs
|
||||
|
||||
|
||||
# Img2Img Rest API.
|
||||
def img2img_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, '
|
||||
f'Negative Prompt: {InputData["negative_prompt"]}, '
|
||||
f'Seed: {InputData["seed"]}.'
|
||||
)
|
||||
init_image = decode_base64_to_image(InputData["init_images"][0])
|
||||
res = img2img_inf(
|
||||
InputData["prompt"],
|
||||
InputData["negative_prompt"],
|
||||
init_image,
|
||||
InputData["height"],
|
||||
InputData["width"],
|
||||
InputData["steps"],
|
||||
InputData["denoising_strength"],
|
||||
InputData["cfg_scale"],
|
||||
InputData["seed"],
|
||||
batch_count=1,
|
||||
batch_size=1,
|
||||
scheduler="EulerDiscrete",
|
||||
custom_model="None",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
max_length=64,
|
||||
use_stencil=InputData["use_stencil"]
|
||||
if "use_stencil" in InputData.keys()
|
||||
else "None",
|
||||
save_metadata_to_json=False,
|
||||
save_metadata_to_png=False,
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
# Converts generator type to subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
"info": res[1],
|
||||
}
|
||||
return generated_imgs, text_output, "", stencils, images
|
||||
|
||||
|
||||
with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
# Stencils
|
||||
# TODO: Add more stencils here
|
||||
STENCIL_COUNT = 2
|
||||
stencils = gr.State([None] * STENCIL_COUNT)
|
||||
images = gr.State([None] * STENCIL_COUNT)
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
@@ -378,31 +325,19 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
i2i_model_info = (str(get_custom_model_path())).replace(
|
||||
"\\", "\n\\"
|
||||
i2i_model_info = (
|
||||
f"Custom Model Path: {str(get_custom_model_path())}"
|
||||
)
|
||||
i2i_model_info = f"Custom Model Path: {i2i_model_info}"
|
||||
img2img_custom_model = gr.Dropdown(
|
||||
label=f"Models",
|
||||
info=i2i_model_info,
|
||||
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
)
|
||||
img2img_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown "
|
||||
"on the left and enter model ID here "
|
||||
"e.g: SG161222/Realistic_Vision_V1.3, "
|
||||
"https://civitai.com/api/download/models/15236",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL",
|
||||
lines=3,
|
||||
choices=get_custom_model_files() + predefined_models,
|
||||
allow_custom_value=True,
|
||||
scale=2,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
i2i_vae_info = (str(get_custom_model_path("vae"))).replace(
|
||||
@@ -417,6 +352,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
scale=1,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
@@ -432,72 +369,285 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
lines=2,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
|
||||
# TODO: make this import image prompt info if it exists
|
||||
img2img_init_image = gr.Image(
|
||||
label="Input Image",
|
||||
source="upload",
|
||||
tool="sketch",
|
||||
type="pil",
|
||||
height=300,
|
||||
height=512,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Accordion(label="Stencil Options", open=False):
|
||||
with gr.Row():
|
||||
use_stencil = gr.Dropdown(
|
||||
elem_id="stencil_model",
|
||||
label="Stencil model",
|
||||
value="None",
|
||||
choices=["None", "canny", "openpose", "scribble"],
|
||||
)
|
||||
with gr.Accordion(label="Multistencil Options", open=False):
|
||||
choices = [
|
||||
"None",
|
||||
"canny",
|
||||
"openpose",
|
||||
"scribble",
|
||||
"zoedepth",
|
||||
]
|
||||
|
||||
def show_canvas(choice):
|
||||
if choice == "scribble":
|
||||
return (
|
||||
gr.Slider.update(visible=True),
|
||||
gr.Slider.update(visible=True),
|
||||
gr.Button.update(visible=True),
|
||||
def cnet_preview(
|
||||
model, input_image, index, stencils, images
|
||||
):
|
||||
images[index] = input_image
|
||||
stencils[index] = model
|
||||
match model:
|
||||
case "canny":
|
||||
canny = CannyDetector()
|
||||
result = canny(
|
||||
np.array(input_image["composite"]),
|
||||
100,
|
||||
200,
|
||||
)
|
||||
return (
|
||||
Image.fromarray(result),
|
||||
stencils,
|
||||
images,
|
||||
)
|
||||
case "openpose":
|
||||
openpose = OpenposeDetector()
|
||||
result = openpose(
|
||||
np.array(input_image["composite"])
|
||||
)
|
||||
print(result)
|
||||
# TODO: This is just an empty canvas, need to draw the candidates (which are in result[1])
|
||||
return (
|
||||
Image.fromarray(result[0]),
|
||||
stencils,
|
||||
images,
|
||||
)
|
||||
case "zoedepth":
|
||||
zoedepth = ZoeDetector()
|
||||
result = zoedepth(
|
||||
np.array(input_image["composite"])
|
||||
)
|
||||
return (
|
||||
Image.fromarray(result),
|
||||
stencils,
|
||||
images,
|
||||
)
|
||||
case "scribble":
|
||||
return (
|
||||
input_image["composite"],
|
||||
stencils,
|
||||
images,
|
||||
)
|
||||
case _:
|
||||
return (None, stencils, images)
|
||||
|
||||
def create_canvas(width, height):
|
||||
data = Image.fromarray(
|
||||
np.zeros(
|
||||
shape=(height, width, 3),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
+ 255
|
||||
)
|
||||
img_dict = {
|
||||
"background": data,
|
||||
"layers": [data],
|
||||
"composite": None,
|
||||
}
|
||||
return EditorValue(img_dict)
|
||||
|
||||
def update_cn_input(model, width, height):
|
||||
if model == "scribble":
|
||||
return [
|
||||
gr.ImageEditor(
|
||||
visible=True,
|
||||
interactive=True,
|
||||
show_label=False,
|
||||
image_mode="RGB",
|
||||
type="pil",
|
||||
value=create_canvas(width, height),
|
||||
brush=Brush(
|
||||
colors=["#000000"], color_mode="fixed"
|
||||
),
|
||||
),
|
||||
gr.Image(
|
||||
visible=True,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
show_download_button=False,
|
||||
),
|
||||
gr.Slider(visible=True),
|
||||
gr.Slider(visible=True),
|
||||
gr.Button(visible=True),
|
||||
]
|
||||
else:
|
||||
return (
|
||||
gr.Slider.update(visible=False),
|
||||
gr.Slider.update(visible=False),
|
||||
gr.Button.update(visible=False),
|
||||
)
|
||||
|
||||
def create_canvas(w, h):
|
||||
return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
|
||||
return [
|
||||
gr.ImageEditor(
|
||||
visible=True,
|
||||
image_mode="RGB",
|
||||
type="pil",
|
||||
interactive=True,
|
||||
value=None,
|
||||
),
|
||||
gr.Image(
|
||||
visible=True,
|
||||
show_label=False,
|
||||
interactive=True,
|
||||
show_download_button=False,
|
||||
),
|
||||
gr.Slider(visible=False),
|
||||
gr.Slider(visible=False),
|
||||
gr.Button(visible=False),
|
||||
]
|
||||
|
||||
with gr.Row():
|
||||
canvas_width = gr.Slider(
|
||||
label="Canvas Width",
|
||||
minimum=256,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=1,
|
||||
with gr.Column():
|
||||
cnet_1 = gr.Button(
|
||||
value="Generate controlnet input"
|
||||
)
|
||||
cnet_1_model = gr.Dropdown(
|
||||
label="Controlnet 1",
|
||||
value="None",
|
||||
choices=choices,
|
||||
)
|
||||
canvas_width = gr.Slider(
|
||||
label="Canvas Width",
|
||||
minimum=256,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=1,
|
||||
visible=False,
|
||||
)
|
||||
canvas_height = gr.Slider(
|
||||
label="Canvas Height",
|
||||
minimum=256,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=1,
|
||||
visible=False,
|
||||
)
|
||||
make_canvas = gr.Button(
|
||||
value="Make Canvas!",
|
||||
visible=False,
|
||||
)
|
||||
|
||||
cnet_1_image = gr.ImageEditor(
|
||||
visible=False,
|
||||
image_mode="RGB",
|
||||
interactive=True,
|
||||
show_label=False,
|
||||
type="pil",
|
||||
)
|
||||
canvas_height = gr.Slider(
|
||||
label="Canvas Height",
|
||||
minimum=256,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=1,
|
||||
cnet_1_output = gr.Image(
|
||||
visible=True, show_label=False
|
||||
)
|
||||
|
||||
cnet_1_model.input(
|
||||
update_cn_input,
|
||||
[cnet_1_model, canvas_width, canvas_height],
|
||||
[
|
||||
cnet_1_image,
|
||||
cnet_1_output,
|
||||
canvas_width,
|
||||
canvas_height,
|
||||
make_canvas,
|
||||
],
|
||||
)
|
||||
make_canvas.click(
|
||||
update_cn_input,
|
||||
[cnet_1_model, canvas_width, canvas_height],
|
||||
[
|
||||
cnet_1_image,
|
||||
cnet_1_output,
|
||||
canvas_width,
|
||||
canvas_height,
|
||||
make_canvas,
|
||||
],
|
||||
)
|
||||
cnet_1.click(
|
||||
fn=(
|
||||
lambda a, b, s, i: cnet_preview(a, b, 0, s, i)
|
||||
),
|
||||
inputs=[
|
||||
cnet_1_model,
|
||||
cnet_1_image,
|
||||
stencils,
|
||||
images,
|
||||
],
|
||||
outputs=[cnet_1_output, stencils, images],
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
cnet_2 = gr.Button(
|
||||
value="Generate controlnet input"
|
||||
)
|
||||
cnet_2_model = gr.Dropdown(
|
||||
label="Controlnet 2",
|
||||
value="None",
|
||||
choices=choices,
|
||||
)
|
||||
canvas_width = gr.Slider(
|
||||
label="Canvas Width",
|
||||
minimum=256,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=1,
|
||||
visible=False,
|
||||
)
|
||||
canvas_height = gr.Slider(
|
||||
label="Canvas Height",
|
||||
minimum=256,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=1,
|
||||
visible=False,
|
||||
)
|
||||
make_canvas = gr.Button(
|
||||
value="Make Canvas!",
|
||||
visible=False,
|
||||
)
|
||||
cnet_2_image = gr.ImageEditor(
|
||||
visible=False,
|
||||
image_mode="RGB",
|
||||
interactive=True,
|
||||
show_label=False,
|
||||
type="pil",
|
||||
)
|
||||
create_button = gr.Button(
|
||||
label="Start",
|
||||
value="Open drawing canvas!",
|
||||
visible=False,
|
||||
)
|
||||
create_button.click(
|
||||
fn=create_canvas,
|
||||
inputs=[canvas_width, canvas_height],
|
||||
outputs=[img2img_init_image],
|
||||
)
|
||||
use_stencil.change(
|
||||
fn=show_canvas,
|
||||
inputs=use_stencil,
|
||||
outputs=[canvas_width, canvas_height, create_button],
|
||||
cnet_2_output = gr.Image(
|
||||
visible=True, show_label=False
|
||||
)
|
||||
cnet_2_model.select(
|
||||
update_cn_input,
|
||||
[cnet_2_model, canvas_width, canvas_height],
|
||||
[
|
||||
cnet_2_image,
|
||||
cnet_2_output,
|
||||
canvas_width,
|
||||
canvas_height,
|
||||
make_canvas,
|
||||
],
|
||||
)
|
||||
make_canvas.click(
|
||||
update_cn_input,
|
||||
[cnet_2_model, canvas_width, canvas_height],
|
||||
[
|
||||
cnet_2_image,
|
||||
cnet_2_output,
|
||||
canvas_width,
|
||||
canvas_height,
|
||||
make_canvas,
|
||||
],
|
||||
)
|
||||
cnet_2.click(
|
||||
fn=(
|
||||
lambda a, b, s, i: cnet_preview(a, b, 1, s, i)
|
||||
),
|
||||
inputs=[
|
||||
cnet_2_model,
|
||||
cnet_2_image,
|
||||
stencils,
|
||||
images,
|
||||
],
|
||||
outputs=[cnet_2_output, stencils, images],
|
||||
)
|
||||
control_mode = gr.Radio(
|
||||
choices=["Prompt", "Balanced", "Controlnet"],
|
||||
value="Balanced",
|
||||
label="Control Mode",
|
||||
)
|
||||
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
@@ -508,6 +658,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
).replace("\\", "\n\\")
|
||||
i2i_lora_info = f"LoRA Path: {i2i_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
allow_custom_value=True,
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=i2i_lora_info,
|
||||
elem_id="lora_weights",
|
||||
@@ -524,6 +675,11 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
with gr.Row():
|
||||
lora_tags = gr.HTML(
|
||||
value="<div><i>No LoRA selected</i></div>",
|
||||
elem_classes="lora-tags",
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
@@ -531,6 +687,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
label="Scheduler",
|
||||
value="EulerDiscrete",
|
||||
choices=scheduler_list_cpu_only,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -550,15 +707,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
width = gr.Slider(
|
||||
384, 768, value=args.width, step=8, label="Width"
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=args.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
value=args.max_length,
|
||||
@@ -581,11 +729,26 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
step=0.01,
|
||||
label="Denoising Strength",
|
||||
)
|
||||
resample_type = gr.Dropdown(
|
||||
value=args.resample_type,
|
||||
choices=resampler_list,
|
||||
label="Resample Type",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
ondemand = gr.Checkbox(
|
||||
value=args.ondemand,
|
||||
label="Low VRAM",
|
||||
interactive=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=args.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
guidance_scale = gr.Slider(
|
||||
@@ -629,17 +792,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -651,13 +805,26 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
object_fit="contain",
|
||||
)
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at "
|
||||
value=f"{i2i_model_info}\n"
|
||||
f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
lines=2,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
img2img_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
blank_thing_for_row = None
|
||||
with gr.Row():
|
||||
img2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
img2img_sendto_outpaint = gr.Button(
|
||||
@@ -683,20 +850,28 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
batch_size,
|
||||
scheduler,
|
||||
img2img_custom_model,
|
||||
img2img_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
max_length,
|
||||
use_stencil,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
resample_type,
|
||||
control_mode,
|
||||
stencils,
|
||||
images,
|
||||
],
|
||||
outputs=[
|
||||
img2img_gallery,
|
||||
std_output,
|
||||
img2img_status,
|
||||
stencils,
|
||||
images,
|
||||
],
|
||||
outputs=[img2img_gallery, std_output, img2img_status],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
)
|
||||
|
||||
@@ -715,3 +890,10 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
lora_weights.change(
|
||||
fn=lora_changed,
|
||||
inputs=[lora_weights],
|
||||
outputs=[lora_tags],
|
||||
queue=True,
|
||||
)
|
||||
|
||||
@@ -4,9 +4,7 @@ import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -16,6 +14,7 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
predefined_paint_models,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
InpaintPipeline,
|
||||
@@ -53,8 +52,7 @@ def inpaint_inf(
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
@@ -89,21 +87,17 @@ def inpaint_inf(
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
args.custom_vae = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, "
|
||||
"both must not be empty.",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
|
||||
# .safetensor or .chkpt on the custom model path
|
||||
if model_id in get_custom_model_files(custom_checkpoint_type="inpainting"):
|
||||
args.ckpt_loc = get_custom_model_pathfile(model_id)
|
||||
# civitai download
|
||||
elif "civitai" in model_id:
|
||||
args.ckpt_loc = model_id
|
||||
# either predefined or huggingface
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
args.hf_model_id = model_id
|
||||
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
@@ -128,7 +122,7 @@ def inpaint_inf(
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
stencils=[],
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
@@ -228,86 +222,6 @@ def inpaint_inf(
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
|
||||
try:
|
||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
return image
|
||||
except Exception as err:
|
||||
print(err)
|
||||
raise HTTPException(status_code=500, detail="Invalid encoded image")
|
||||
|
||||
|
||||
def encode_pil_to_base64(images):
|
||||
encoded_imgs = []
|
||||
for image in images:
|
||||
with BytesIO() as output_bytes:
|
||||
if args.output_img_format.lower() == "png":
|
||||
image.save(output_bytes, format="PNG")
|
||||
|
||||
elif args.output_img_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Invalid image format"
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
encoded_imgs.append(base64.b64encode(bytes_data))
|
||||
return encoded_imgs
|
||||
|
||||
|
||||
# Inpaint Rest API.
|
||||
def inpaint_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, '
|
||||
f'Negative Prompt: {InputData["negative_prompt"]}, '
|
||||
f'Seed: {InputData["seed"]}.'
|
||||
)
|
||||
init_image = decode_base64_to_image(InputData["image"])
|
||||
mask = decode_base64_to_image(InputData["mask"])
|
||||
res = inpaint_inf(
|
||||
InputData["prompt"],
|
||||
InputData["negative_prompt"],
|
||||
{"image": init_image, "mask": mask},
|
||||
InputData["height"],
|
||||
InputData["width"],
|
||||
InputData["is_full_res"],
|
||||
InputData["full_res_padding"],
|
||||
InputData["steps"],
|
||||
InputData["cfg_scale"],
|
||||
InputData["seed"],
|
||||
batch_count=1,
|
||||
batch_size=1,
|
||||
scheduler="EulerDiscrete",
|
||||
custom_model="None",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
max_length=64,
|
||||
save_metadata_to_json=False,
|
||||
save_metadata_to_png=False,
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
# Converts generator type to subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
"info": res[1],
|
||||
}
|
||||
|
||||
|
||||
with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
@@ -327,34 +241,21 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
inpaint_model_info = (
|
||||
str(get_custom_model_path())
|
||||
).replace("\\", "\n\\")
|
||||
inpaint_model_info = (
|
||||
f"Custom Model Path: {inpaint_model_info}"
|
||||
f"Custom Model Path: {str(get_custom_model_path())}"
|
||||
)
|
||||
inpaint_custom_model = gr.Dropdown(
|
||||
label=f"Models",
|
||||
info=inpaint_model_info,
|
||||
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files(
|
||||
choices=get_custom_model_files(
|
||||
custom_checkpoint_type="inpainting"
|
||||
)
|
||||
+ predefined_paint_models,
|
||||
)
|
||||
inpaint_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown "
|
||||
"on the left and enter model ID here "
|
||||
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
|
||||
"https://civitai.com/api/download/models/3433",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL",
|
||||
lines=3,
|
||||
allow_custom_value=True,
|
||||
scale=2,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
inpaint_vae_info = (
|
||||
@@ -369,6 +270,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
scale=1,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
@@ -387,8 +290,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
|
||||
inpaint_init_image = gr.Image(
|
||||
label="Masked Image",
|
||||
source="upload",
|
||||
tool="sketch",
|
||||
sources="upload",
|
||||
type="pil",
|
||||
height=350,
|
||||
)
|
||||
@@ -406,6 +308,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
@@ -417,6 +320,11 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
with gr.Row():
|
||||
lora_tags = gr.HTML(
|
||||
value="<div><i>No LoRA selected</i></div>",
|
||||
elem_classes="lora-tags",
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
@@ -424,6 +332,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
label="Scheduler",
|
||||
value="EulerDiscrete",
|
||||
choices=scheduler_list_cpu_only,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -527,17 +436,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -549,14 +449,26 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
object_fit="contain",
|
||||
)
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at "
|
||||
value=f"{inpaint_model_info}\n"
|
||||
"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
lines=2,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
inpaint_status = gr.Textbox(visible=False)
|
||||
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
blank_thing_for_row = None
|
||||
with gr.Row():
|
||||
inpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
inpaint_sendto_outpaint = gr.Button(
|
||||
@@ -583,7 +495,6 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
batch_size,
|
||||
scheduler,
|
||||
inpaint_custom_model,
|
||||
inpaint_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
@@ -613,3 +524,10 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
lora_weights.change(
|
||||
fn=lora_changed,
|
||||
inputs=[lora_weights],
|
||||
outputs=[lora_tags],
|
||||
queue=True,
|
||||
)
|
||||
|
||||
BIN
apps/stable_diffusion/web/ui/logos/nod-icon.png
Normal file
BIN
apps/stable_diffusion/web/ui/logos/nod-icon.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 16 KiB |
@@ -50,6 +50,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
@@ -73,6 +74,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
@@ -105,6 +107,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
label="Scheduler",
|
||||
value=args.scheduler,
|
||||
choices=scheduler_list,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
@@ -177,6 +180,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
|
||||
@@ -109,7 +109,7 @@ with gr.Blocks() as minigpt4_web:
|
||||
gr.Markdown(description)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=0.5):
|
||||
with gr.Column():
|
||||
image = gr.Image(type="pil")
|
||||
upload_button = gr.Button(
|
||||
value="Upload & Start Chat",
|
||||
@@ -143,6 +143,7 @@ with gr.Blocks() as minigpt4_web:
|
||||
# else "Only CUDA Supported for now",
|
||||
choices=["cuda"],
|
||||
interactive=False,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
with gr.Column():
|
||||
|
||||
@@ -98,12 +98,12 @@ with gr.Blocks() as model_web:
|
||||
choices=None,
|
||||
value=None,
|
||||
visible=False,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
# TODO: select and SendTo
|
||||
civit_models = gr.Gallery(
|
||||
label="Civitai Model Gallery",
|
||||
value=None,
|
||||
interactive=True,
|
||||
visible=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -3,9 +3,8 @@ import torch
|
||||
import time
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -53,8 +52,7 @@ def outpaint_inf(
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
@@ -88,21 +86,17 @@ def outpaint_inf(
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
args.custom_vae = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, "
|
||||
"both must not be empty.",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
|
||||
# .safetensor or .chkpt on the custom model path
|
||||
if model_id in get_custom_model_files(custom_checkpoint_type="inpainting"):
|
||||
args.ckpt_loc = get_custom_model_pathfile(model_id)
|
||||
# civitai download
|
||||
elif "civitai" in model_id:
|
||||
args.ckpt_loc = model_id
|
||||
# either predefined or huggingface
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
args.hf_model_id = model_id
|
||||
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
@@ -127,7 +121,7 @@ def outpaint_inf(
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
stencils=[],
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
@@ -233,88 +227,6 @@ def outpaint_inf(
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
|
||||
try:
|
||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
return image
|
||||
except Exception as err:
|
||||
print(err)
|
||||
raise HTTPException(status_code=500, detail="Invalid encoded image")
|
||||
|
||||
|
||||
def encode_pil_to_base64(images):
|
||||
encoded_imgs = []
|
||||
for image in images:
|
||||
with BytesIO() as output_bytes:
|
||||
if args.output_img_format.lower() == "png":
|
||||
image.save(output_bytes, format="PNG")
|
||||
|
||||
elif args.output_img_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Invalid image format"
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
encoded_imgs.append(base64.b64encode(bytes_data))
|
||||
return encoded_imgs
|
||||
|
||||
|
||||
# Inpaint Rest API.
|
||||
def outpaint_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, '
|
||||
f'Negative Prompt: {InputData["negative_prompt"]}, '
|
||||
f'Seed: {InputData["seed"]}'
|
||||
)
|
||||
init_image = decode_base64_to_image(InputData["init_images"][0])
|
||||
res = outpaint_inf(
|
||||
InputData["prompt"],
|
||||
InputData["negative_prompt"],
|
||||
init_image,
|
||||
InputData["pixels"],
|
||||
InputData["mask_blur"],
|
||||
InputData["directions"],
|
||||
InputData["noise_q"],
|
||||
InputData["color_variation"],
|
||||
InputData["height"],
|
||||
InputData["width"],
|
||||
InputData["steps"],
|
||||
InputData["cfg_scale"],
|
||||
InputData["seed"],
|
||||
batch_count=1,
|
||||
batch_size=1,
|
||||
scheduler="EulerDiscrete",
|
||||
custom_model="None",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
max_length=64,
|
||||
save_metadata_to_json=False,
|
||||
save_metadata_to_png=False,
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
# Convert Generator to Subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
"info": res[1],
|
||||
}
|
||||
|
||||
|
||||
with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
@@ -332,36 +244,22 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
outpaint_model_info = (
|
||||
str(get_custom_model_path())
|
||||
).replace("\\", "\n\\")
|
||||
outpaint_model_info = (
|
||||
f"Custom Model Path: {outpaint_model_info}"
|
||||
f"Custom Model Path: {str(get_custom_model_path())}"
|
||||
)
|
||||
outpaint_custom_model = gr.Dropdown(
|
||||
label=f"Models",
|
||||
info=outpaint_model_info,
|
||||
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files(
|
||||
choices=get_custom_model_files(
|
||||
custom_checkpoint_type="inpainting"
|
||||
)
|
||||
+ predefined_paint_models,
|
||||
)
|
||||
outpaint_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown "
|
||||
"on the left and enter model ID here "
|
||||
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
|
||||
"https://civitai.com/api/download/models/3433",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL",
|
||||
lines=3,
|
||||
allow_custom_value=True,
|
||||
scale=2,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
outpaint_vae_info = (
|
||||
@@ -376,8 +274,9 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
scale=1,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
@@ -411,6 +310,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
@@ -422,6 +322,11 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
with gr.Row():
|
||||
lora_tags = gr.HTML(
|
||||
value="<div><i>No LoRA selected</i></div>",
|
||||
elem_classes="lora-tags",
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
@@ -429,6 +334,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
label="Scheduler",
|
||||
value="EulerDiscrete",
|
||||
choices=scheduler_list_cpu_only,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -555,17 +461,8 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -577,13 +474,26 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
object_fit="contain",
|
||||
)
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at "
|
||||
value=f"{outpaint_model_info}\n"
|
||||
f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
lines=2,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
outpaint_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
blank_thing_for_row = None
|
||||
with gr.Row():
|
||||
outpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
outpaint_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
@@ -611,7 +521,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
batch_size,
|
||||
scheduler,
|
||||
outpaint_custom_model,
|
||||
outpaint_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
@@ -641,3 +550,10 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
lora_weights.change(
|
||||
fn=lora_changed,
|
||||
inputs=[lora_weights],
|
||||
outputs=[lora_tags],
|
||||
queue=True,
|
||||
)
|
||||
|
||||
@@ -91,11 +91,11 @@ with gr.Blocks() as outputgallery_web:
|
||||
value=gallery_files.value,
|
||||
visible=False,
|
||||
show_label=True,
|
||||
columns=2,
|
||||
columns=4,
|
||||
)
|
||||
|
||||
with gr.Column(scale=4):
|
||||
with gr.Box():
|
||||
with gr.Group():
|
||||
with gr.Row():
|
||||
with gr.Column(
|
||||
scale=15,
|
||||
@@ -109,6 +109,7 @@ with gr.Blocks() as outputgallery_web:
|
||||
value="",
|
||||
interactive=True,
|
||||
elem_classes="dropdown_no_container",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Column(
|
||||
scale=1,
|
||||
@@ -151,6 +152,7 @@ with gr.Blocks() as outputgallery_web:
|
||||
wrap=True,
|
||||
elem_classes="output_parameters_dataframe",
|
||||
value=[["Status", "No image selected"]],
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Accordion(label="Send To", open=True):
|
||||
@@ -161,6 +163,12 @@ with gr.Blocks() as outputgallery_web:
|
||||
elem_classes="outputgallery_sendto",
|
||||
size="sm",
|
||||
)
|
||||
outputgallery_sendto_txt2img_sdxl = gr.Button(
|
||||
value="Txt2Img XL",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
size="sm",
|
||||
)
|
||||
|
||||
outputgallery_sendto_img2img = gr.Button(
|
||||
value="Img2Img",
|
||||
@@ -194,15 +202,18 @@ with gr.Blocks() as outputgallery_web:
|
||||
|
||||
def on_clear_gallery():
|
||||
return [
|
||||
gr.Gallery.update(
|
||||
gr.Gallery(
|
||||
value=[],
|
||||
visible=False,
|
||||
),
|
||||
gr.Image.update(
|
||||
gr.Image(
|
||||
visible=True,
|
||||
),
|
||||
]
|
||||
|
||||
def on_image_columns_change(columns):
|
||||
return gr.Gallery(columns=columns)
|
||||
|
||||
def on_select_subdir(subdir) -> list:
|
||||
# evt.value is the subdirectory name
|
||||
new_images = outputgallery_filenames(subdir)
|
||||
@@ -211,12 +222,12 @@ with gr.Blocks() as outputgallery_web:
|
||||
)
|
||||
return [
|
||||
new_images,
|
||||
gr.Gallery.update(
|
||||
gr.Gallery(
|
||||
value=new_images,
|
||||
label=new_label,
|
||||
visible=len(new_images) > 0,
|
||||
),
|
||||
gr.Image.update(
|
||||
gr.Image(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
@@ -250,16 +261,16 @@ with gr.Blocks() as outputgallery_web:
|
||||
)
|
||||
|
||||
return [
|
||||
gr.Dropdown.update(
|
||||
gr.Dropdown(
|
||||
choices=refreshed_subdirs,
|
||||
value=new_subdir,
|
||||
),
|
||||
refreshed_subdirs,
|
||||
new_images,
|
||||
gr.Gallery.update(
|
||||
gr.Gallery(
|
||||
value=new_images, label=new_label, visible=len(new_images) > 0
|
||||
),
|
||||
gr.Image.update(
|
||||
gr.Image(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
@@ -285,12 +296,12 @@ with gr.Blocks() as outputgallery_web:
|
||||
|
||||
return [
|
||||
new_images,
|
||||
gr.Gallery.update(
|
||||
gr.Gallery(
|
||||
value=new_images,
|
||||
label=new_label,
|
||||
visible=len(new_images) > 0,
|
||||
),
|
||||
gr.Image.update(
|
||||
gr.Image(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
@@ -328,12 +339,12 @@ with gr.Blocks() as outputgallery_web:
|
||||
return [
|
||||
# disable or enable each of the sendto button based on whether
|
||||
# an image is selected
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button(interactive=exists),
|
||||
gr.Button(interactive=exists),
|
||||
gr.Button(interactive=exists),
|
||||
gr.Button(interactive=exists),
|
||||
gr.Button(interactive=exists),
|
||||
gr.Button(interactive=exists),
|
||||
]
|
||||
|
||||
# The time first our tab is selected we need to do an initial refresh
|
||||
@@ -364,53 +375,6 @@ with gr.Blocks() as outputgallery_web:
|
||||
gr.update(),
|
||||
)
|
||||
|
||||
# Unfortunately as of gradio 3.34.0 gr.update against Galleries doesn't
|
||||
# support things set with .style, nor the elem_classes kwarg, so we have
|
||||
# to directly set things up via JavaScript if we want the client to take
|
||||
# notice of our changes to the number of columns after it decides to put
|
||||
# them back to the original number when we change something
|
||||
def js_set_columns_in_browser(timeout_length):
|
||||
return f"""
|
||||
(new_cols) => {{
|
||||
setTimeout(() => {{
|
||||
required_style = "auto ".repeat(new_cols).trim();
|
||||
gallery = document.querySelector('#outputgallery_gallery .grid-container');
|
||||
if (gallery) {{
|
||||
gallery.style.gridTemplateColumns = required_style
|
||||
}}
|
||||
}}, {timeout_length});
|
||||
return []; // prevents console error from gradio
|
||||
}}
|
||||
"""
|
||||
|
||||
# --- Wire handlers up to the actions
|
||||
|
||||
# Many actions reset the number of columns shown in the gallery on the
|
||||
# browser end, so we have to set them back to what we think they should
|
||||
# be after the initial action.
|
||||
#
|
||||
# None of the actions on this tab trigger inference, and we want the
|
||||
# user to be able to do them whilst other tabs have ongoing inference
|
||||
# running. Waiting in the queue behind inference jobs would mean the UI
|
||||
# can't fully respond until the inference tasks complete,
|
||||
# hence queue=False on all of these.
|
||||
set_gallery_columns_immediate = dict(
|
||||
fn=None,
|
||||
inputs=[image_columns],
|
||||
# gradio blanks the UI on Chrome on Linux on gallery select if
|
||||
# I don't put an output here
|
||||
outputs=[dev_null],
|
||||
_js=js_set_columns_in_browser(0),
|
||||
queue=False,
|
||||
)
|
||||
|
||||
# setting columns after selecting a gallery item needs a real
|
||||
# timeout length for the number of columns to actually be applied.
|
||||
# Not really sure why, maybe something has to finish animating?
|
||||
set_gallery_columns_delayed = dict(
|
||||
set_gallery_columns_immediate, _js=js_set_columns_in_browser(250)
|
||||
)
|
||||
|
||||
# clearing images when we need to completely change what's in the
|
||||
# gallery avoids current images being shown replacing piecemeal and
|
||||
# prevents weirdness and errors if the user selects an image during the
|
||||
@@ -422,38 +386,42 @@ with gr.Blocks() as outputgallery_web:
|
||||
queue=False,
|
||||
)
|
||||
|
||||
image_columns.change(**set_gallery_columns_immediate)
|
||||
|
||||
subdirectories.select(**clear_gallery).then(
|
||||
on_select_subdir,
|
||||
[subdirectories],
|
||||
[gallery_files, gallery, logo],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
)
|
||||
|
||||
open_subdir.click(
|
||||
on_open_subdir, inputs=[subdirectories], queue=False
|
||||
).then(**set_gallery_columns_immediate)
|
||||
open_subdir.click(on_open_subdir, inputs=[subdirectories], queue=False)
|
||||
|
||||
refresh.click(**clear_gallery).then(
|
||||
on_refresh,
|
||||
[subdirectories],
|
||||
[subdirectories, subdirectory_paths, gallery_files, gallery, logo],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
)
|
||||
|
||||
image_columns.change(
|
||||
fn=on_image_columns_change,
|
||||
inputs=[image_columns],
|
||||
outputs=[gallery],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
gallery.select(
|
||||
on_select_image,
|
||||
[gallery_files],
|
||||
[outputgallery_filename, image_parameters],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_delayed)
|
||||
)
|
||||
|
||||
outputgallery_filename.change(
|
||||
on_outputgallery_filename_change,
|
||||
[outputgallery_filename],
|
||||
[
|
||||
outputgallery_sendto_txt2img,
|
||||
outputgallery_sendto_txt2img_sdxl,
|
||||
outputgallery_sendto_img2img,
|
||||
outputgallery_sendto_inpaint,
|
||||
outputgallery_sendto_outpaint,
|
||||
@@ -476,7 +444,7 @@ with gr.Blocks() as outputgallery_web:
|
||||
open_subdir,
|
||||
],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
)
|
||||
|
||||
# We should have been passed a list of components on other tabs that update
|
||||
# when a new image has generated on that tab, so set things up so the user
|
||||
@@ -488,4 +456,4 @@ with gr.Blocks() as outputgallery_web:
|
||||
inputs=[subdirectories, subdirectory_paths, component],
|
||||
outputs=[gallery_files, gallery, logo],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
)
|
||||
|
||||
@@ -6,9 +6,10 @@ from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
from shark.iree_utils.compile_utils import clean_device_info
|
||||
from datetime import datetime as dt
|
||||
import json
|
||||
import time
|
||||
import sys
|
||||
|
||||
|
||||
def user(message, history):
|
||||
@@ -24,88 +25,81 @@ past_key_values = None
|
||||
|
||||
model_map = {
|
||||
"llama2_7b": "meta-llama/Llama-2-7b-chat-hf",
|
||||
"llama2_13b": "meta-llama/Llama-2-13b-chat-hf",
|
||||
"llama2_70b": "meta-llama/Llama-2-70b-chat-hf",
|
||||
"codegen": "Salesforce/codegen25-7b-multi",
|
||||
"vicuna1p3": "lmsys/vicuna-7b-v1.3",
|
||||
"vicuna": "TheBloke/vicuna-7B-1.1-HF",
|
||||
"vicuna4": "TheBloke/vicuna-7B-1.1-HF",
|
||||
"StableLM": "stabilityai/stablelm-tuned-alpha-3b",
|
||||
}
|
||||
|
||||
# NOTE: Each `model_name` should have its own start message
|
||||
start_message = {
|
||||
"llama2_7b": (
|
||||
"System: You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
|
||||
"content. Please ensure that your responses are socially unbiased and positive "
|
||||
"in nature. If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. If you don't know the "
|
||||
"answer to a question, please don't share false information."
|
||||
"You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
|
||||
"illegal content. Please ensure that your responses are socially "
|
||||
"unbiased and positive in nature. If a question does not make any "
|
||||
"sense, or is not factually coherent, explain why instead of "
|
||||
"answering something not correct. If you don't know the answer "
|
||||
"to a question, please don't share false information."
|
||||
),
|
||||
"llama2_13b": (
|
||||
"You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
|
||||
"illegal content. Please ensure that your responses are socially "
|
||||
"unbiased and positive in nature. If a question does not make any "
|
||||
"sense, or is not factually coherent, explain why instead of "
|
||||
"answering something not correct. If you don't know the answer "
|
||||
"to a question, please don't share false information."
|
||||
),
|
||||
"llama2_70b": (
|
||||
"System: You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
|
||||
"content. Please ensure that your responses are socially unbiased and positive "
|
||||
"in nature. If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. If you don't know the "
|
||||
"answer to a question, please don't share false information."
|
||||
),
|
||||
"StableLM": (
|
||||
"<|SYSTEM|># StableLM Tuned (Alpha version)"
|
||||
"\n- StableLM is a helpful and harmless open-source AI language model "
|
||||
"developed by StabilityAI."
|
||||
"\n- StableLM is excited to be able to help the user, but will refuse "
|
||||
"to do anything that could be considered harmful to the user."
|
||||
"\n- StableLM is more than just an information source, StableLM is also "
|
||||
"able to write poetry, short stories, and make jokes."
|
||||
"\n- StableLM will refuse to participate in anything that "
|
||||
"could harm a human."
|
||||
"You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
|
||||
"illegal content. Please ensure that your responses are socially "
|
||||
"unbiased and positive in nature. If a question does not make any "
|
||||
"sense, or is not factually coherent, explain why instead of "
|
||||
"answering something not correct. If you don't know the answer "
|
||||
"to a question, please don't share false information."
|
||||
),
|
||||
"vicuna": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
"A chat between a curious user and an artificial intelligence "
|
||||
"assistant. The assistant gives helpful, detailed, and "
|
||||
"polite answers to the user's questions.\n"
|
||||
),
|
||||
"vicuna4": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
),
|
||||
"vicuna1p3": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
),
|
||||
"codegen": "",
|
||||
}
|
||||
|
||||
|
||||
def create_prompt(model_name, history):
|
||||
system_message = start_message[model_name]
|
||||
def create_prompt(model_name, history, prompt_prefix):
|
||||
system_message = ""
|
||||
if prompt_prefix:
|
||||
system_message = start_message[model_name]
|
||||
|
||||
if model_name in [
|
||||
"StableLM",
|
||||
"vicuna",
|
||||
"vicuna4",
|
||||
"vicuna1p3",
|
||||
"llama2_7b",
|
||||
"llama2_70b",
|
||||
]:
|
||||
if "llama2" in model_name:
|
||||
B_INST, E_INST = "[INST]", "[/INST]"
|
||||
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
||||
conversation = "".join(
|
||||
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
|
||||
)
|
||||
if prompt_prefix:
|
||||
msg = f"{B_INST} {B_SYS}{system_message}{E_SYS}{history[0][0]} {E_INST} {history[0][1]} {conversation}"
|
||||
else:
|
||||
msg = f"{B_INST} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
|
||||
elif model_name in ["vicuna"]:
|
||||
conversation = "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
else:
|
||||
conversation = "".join(
|
||||
["".join([item[0], item[1]]) for item in history]
|
||||
)
|
||||
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
return msg
|
||||
|
||||
|
||||
@@ -144,145 +138,150 @@ model_vmfb_key = ""
|
||||
|
||||
# TODO: Make chat reusable for UI and API
|
||||
def chat(
|
||||
curr_system_message,
|
||||
prompt_prefix,
|
||||
history,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
cli=False,
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
global past_key_values
|
||||
global model_vmfb_key
|
||||
|
||||
global vicuna_model
|
||||
|
||||
model_name, model_path = list(map(str.strip, model.split("=>")))
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
elif "sync" in device:
|
||||
device = "cpu-sync"
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device = "vulkan"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
device, device_id = clean_device_info(device)
|
||||
|
||||
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{precision}"
|
||||
if model_name in [
|
||||
"vicuna",
|
||||
"vicuna4",
|
||||
"vicuna1p3",
|
||||
"codegen",
|
||||
"llama2_7b",
|
||||
"llama2_70b",
|
||||
]:
|
||||
from apps.language_models.scripts.vicuna import ShardedVicuna
|
||||
from apps.language_models.scripts.vicuna import UnshardedVicuna
|
||||
from apps.stable_diffusion.src import args
|
||||
from apps.language_models.scripts.vicuna import ShardedVicuna
|
||||
from apps.language_models.scripts.vicuna import UnshardedVicuna
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
if new_model_vmfb_key != model_vmfb_key:
|
||||
model_vmfb_key = new_model_vmfb_key
|
||||
max_toks = 128 if model_name == "codegen" else 512
|
||||
|
||||
# get iree flags that need to be overridden, from commandline args
|
||||
_extra_args = []
|
||||
# vulkan target triple
|
||||
if args.iree_vulkan_target_triple != "":
|
||||
_extra_args.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
|
||||
if model_name == "vicuna4":
|
||||
vicuna_model = ShardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
compressed=True,
|
||||
extra_args_cmd=_extra_args,
|
||||
)
|
||||
else:
|
||||
# if config_file is None:
|
||||
vicuna_model = UnshardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
hf_auth_token=args.hf_auth_token,
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
extra_args_cmd=_extra_args,
|
||||
)
|
||||
# else:
|
||||
# if config_file is not None:
|
||||
# config_file = open(config_file)
|
||||
# config_json = json.load(config_file)
|
||||
# config_file.close()
|
||||
# else:
|
||||
# config_json = get_default_config()
|
||||
# vicuna_model = ShardedVicuna(
|
||||
# model_name,
|
||||
# device=device,
|
||||
# precision=precision,
|
||||
# config_json=config_json,
|
||||
# )
|
||||
|
||||
prompt = create_prompt(model_name, history)
|
||||
|
||||
partial_text = ""
|
||||
count = 0
|
||||
start_time = time.time()
|
||||
for text, msg in progress.tqdm(
|
||||
vicuna_model.generate(prompt, cli=cli),
|
||||
desc="generating response",
|
||||
):
|
||||
count += 1
|
||||
if "formatted" in msg:
|
||||
history[-1][1] = text
|
||||
end_time = time.time()
|
||||
tokens_per_sec = count / (end_time - start_time)
|
||||
yield history, str(
|
||||
format(tokens_per_sec, ".2f")
|
||||
) + " tokens/sec"
|
||||
else:
|
||||
partial_text += text + " "
|
||||
history[-1][1] = partial_text
|
||||
yield history, ""
|
||||
|
||||
return history, ""
|
||||
|
||||
# else Model is StableLM
|
||||
global sharkModel
|
||||
from apps.language_models.src.pipelines.stablelm_pipeline import (
|
||||
SharkStableLM,
|
||||
)
|
||||
|
||||
if new_model_vmfb_key != model_vmfb_key:
|
||||
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}"
|
||||
if vicuna_model is None or new_model_vmfb_key != model_vmfb_key:
|
||||
model_vmfb_key = new_model_vmfb_key
|
||||
# max_new_tokens=512
|
||||
shark_slm = SharkStableLM(
|
||||
model_name
|
||||
) # pass elements from UI as required
|
||||
max_toks = 128 if model_name == "codegen" else 512
|
||||
|
||||
# Construct the input message string for the model by concatenating the
|
||||
# current system message and conversation history
|
||||
if len(curr_system_message.split()) > 160:
|
||||
print("clearing context")
|
||||
prompt = create_prompt(model_name, history)
|
||||
generate_kwargs = dict(prompt=prompt)
|
||||
# get iree flags that need to be overridden, from commandline args
|
||||
_extra_args = []
|
||||
# vulkan target triple
|
||||
vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
get_all_vulkan_devices,
|
||||
get_vulkan_target_triple,
|
||||
)
|
||||
|
||||
words_list = shark_slm.generate(**generate_kwargs)
|
||||
_extra_args = _extra_args + [
|
||||
"--iree-global-opt-enable-quantized-matmul-reassociation",
|
||||
"--iree-llvmcpu-enable-quantized-matmul-reassociation",
|
||||
"--iree-opt-const-eval=false",
|
||||
"--iree-opt-data-tiling=false",
|
||||
]
|
||||
|
||||
if device == "vulkan":
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
if vulkan_target_triple == "":
|
||||
# We already have the device_id extracted via WebUI, so we directly use
|
||||
# that to find the target triple.
|
||||
vulkan_target_triple = get_vulkan_target_triple(
|
||||
vulkaninfo_list[device_id]
|
||||
)
|
||||
_extra_args.append(
|
||||
f"-iree-vulkan-target-triple={vulkan_target_triple}"
|
||||
)
|
||||
if "rdna" in vulkan_target_triple:
|
||||
flags_to_add = [
|
||||
"--iree-spirv-index-bits=64",
|
||||
]
|
||||
_extra_args = _extra_args + flags_to_add
|
||||
|
||||
if device_id is None:
|
||||
id = 0
|
||||
for device in vulkaninfo_list:
|
||||
target_triple = get_vulkan_target_triple(
|
||||
vulkaninfo_list[id]
|
||||
)
|
||||
if target_triple == vulkan_target_triple:
|
||||
device_id = id
|
||||
break
|
||||
id += 1
|
||||
|
||||
assert (
|
||||
device_id
|
||||
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
|
||||
print(f"Will use vulkan target triple : {vulkan_target_triple}")
|
||||
|
||||
elif "rocm" in device:
|
||||
# add iree rocm flags
|
||||
if args.iree_rocm_target_chip != "":
|
||||
_extra_args.append(
|
||||
f"--iree-rocm-target-chip={args.iree_rocm_target_chip}"
|
||||
)
|
||||
print(f"extra args = {_extra_args}")
|
||||
|
||||
if model_name == "vicuna4":
|
||||
vicuna_model = ShardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
compressed=True,
|
||||
extra_args_cmd=_extra_args,
|
||||
)
|
||||
else:
|
||||
# if config_file is None:
|
||||
vicuna_model = UnshardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
hf_auth_token=args.hf_auth_token,
|
||||
device=device,
|
||||
vulkan_target_triple=vulkan_target_triple,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
download_vmfb=download_vmfb,
|
||||
load_mlir_from_shark_tank=True,
|
||||
extra_args_cmd=_extra_args,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
if vicuna_model is None:
|
||||
sys.exit("Unable to instantiate the model object, exiting.")
|
||||
|
||||
prompt = create_prompt(model_name, history, prompt_prefix)
|
||||
|
||||
partial_text = ""
|
||||
for new_text in words_list:
|
||||
partial_text += new_text
|
||||
history[-1][1] = partial_text
|
||||
# Yield an empty string to clean up the message textbox and the updated
|
||||
# conversation history
|
||||
yield history
|
||||
return words_list
|
||||
token_count = 0
|
||||
total_time_ms = 0.001 # In order to avoid divide by zero error
|
||||
prefill_time = 0
|
||||
is_first = True
|
||||
# for text, msg, exec_time in progress.tqdm(
|
||||
# vicuna_model.generate(prompt, cli=cli),
|
||||
# desc="generating response",
|
||||
# ):
|
||||
for text, msg, exec_time in vicuna_model.generate(prompt, cli=cli):
|
||||
if msg is None:
|
||||
if is_first:
|
||||
prefill_time = exec_time
|
||||
is_first = False
|
||||
else:
|
||||
total_time_ms += exec_time
|
||||
token_count += 1
|
||||
partial_text += text + " "
|
||||
history[-1][1] = partial_text
|
||||
yield history, f"Prefill: {prefill_time:.2f}"
|
||||
elif "formatted" in msg:
|
||||
history[-1][1] = text
|
||||
tokens_per_sec = (token_count / total_time_ms) * 1000
|
||||
yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
|
||||
else:
|
||||
sys.exit(
|
||||
"unexpected message from the vicuna generate call, exiting."
|
||||
)
|
||||
|
||||
return history, ""
|
||||
|
||||
|
||||
def llm_chat_api(InputData: dict):
|
||||
@@ -318,17 +317,9 @@ def llm_chat_api(InputData: dict):
|
||||
UnshardedVicuna,
|
||||
)
|
||||
|
||||
device_id = None
|
||||
if vicuna_model == 0:
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
elif "sync" in device:
|
||||
device = "cpu-sync"
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device = "vulkan"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
device, device_id = clean_device_info(device)
|
||||
|
||||
vicuna_model = UnshardedVicuna(
|
||||
model_name,
|
||||
@@ -336,6 +327,9 @@ def llm_chat_api(InputData: dict):
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
download_vmfb=True,
|
||||
load_mlir_from_shark_tank=True,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
# TODO: add role dict for different models
|
||||
@@ -398,47 +392,59 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
)
|
||||
model = gr.Dropdown(
|
||||
label="Select Model",
|
||||
value=model_choices[4],
|
||||
value=model_choices[0],
|
||||
choices=model_choices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
supported_devices = available_devices
|
||||
enabled = len(supported_devices) > 0
|
||||
# show cpu-task device first in list for chatbot
|
||||
supported_devices = supported_devices[-1:] + supported_devices[:-1]
|
||||
supported_devices = [x for x in supported_devices if "sync" not in x]
|
||||
# print(supported_devices)
|
||||
devices = gr.Dropdown(
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=supported_devices[0]
|
||||
if enabled
|
||||
else "Only CUDA Supported for now",
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
# multiselect=True,
|
||||
allow_custom_value=True,
|
||||
# multiselect=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="int8",
|
||||
value="int4",
|
||||
choices=[
|
||||
"int4",
|
||||
"int8",
|
||||
"fp16",
|
||||
],
|
||||
visible=True,
|
||||
visible=False,
|
||||
)
|
||||
tokens_time = gr.Textbox(label="Tokens generated per second")
|
||||
with gr.Column():
|
||||
download_vmfb = gr.Checkbox(
|
||||
label="Download vmfb from Shark tank if available",
|
||||
value=True,
|
||||
interactive=True,
|
||||
)
|
||||
prompt_prefix = gr.Checkbox(
|
||||
label="Add System Prompt",
|
||||
value=False,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Row(visible=False):
|
||||
with gr.Group():
|
||||
config_file = gr.File(
|
||||
label="Upload sharding configuration", visible=False
|
||||
)
|
||||
json_view_button = gr.Button(label="View as JSON", visible=False)
|
||||
json_view = gr.JSON(interactive=True, visible=False)
|
||||
json_view_button = gr.Button(value="View as JSON", visible=False)
|
||||
json_view = gr.JSON(visible=False)
|
||||
json_view_button.click(
|
||||
fn=view_json_file, inputs=[config_file], outputs=[json_view]
|
||||
)
|
||||
chatbot = gr.Chatbot(height=500)
|
||||
chatbot = gr.Chatbot(elem_id="chatbot")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
msg = gr.Textbox(
|
||||
@@ -453,24 +459,47 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
submit = gr.Button("Submit", interactive=enabled)
|
||||
stop = gr.Button("Stop", interactive=enabled)
|
||||
clear = gr.Button("Clear", interactive=enabled)
|
||||
system_msg = gr.Textbox(
|
||||
start_message, label="System Message", interactive=False, visible=False
|
||||
)
|
||||
|
||||
submit_event = msg.submit(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
fn=user,
|
||||
inputs=[msg, chatbot],
|
||||
outputs=[msg, chatbot],
|
||||
show_progress=False,
|
||||
queue=False,
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model, devices, precision, config_file],
|
||||
inputs=[
|
||||
prompt_prefix,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
show_progress=False,
|
||||
queue=True,
|
||||
)
|
||||
submit_click_event = submit.click(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
fn=user,
|
||||
inputs=[msg, chatbot],
|
||||
outputs=[msg, chatbot],
|
||||
show_progress=False,
|
||||
queue=False,
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model, devices, precision, config_file],
|
||||
inputs=[
|
||||
prompt_prefix,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
show_progress=False,
|
||||
queue=True,
|
||||
)
|
||||
stop.click(
|
||||
|
||||
649
apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py
Normal file
649
apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py
Normal file
@@ -0,0 +1,649 @@
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from math import ceil
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
get_custom_model_path,
|
||||
get_custom_model_files,
|
||||
scheduler_list,
|
||||
predefined_sdxl_models,
|
||||
cancel_sd,
|
||||
set_model_default_configs,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
|
||||
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Text2ImageSDXLPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
save_output_img,
|
||||
prompt_examples,
|
||||
Image2ImagePipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generation_text_info,
|
||||
)
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_iree_metal_target_platform = args.iree_metal_target_platform
|
||||
init_use_tuned = args.use_tuned
|
||||
init_import_mlir = args.import_mlir
|
||||
|
||||
|
||||
def txt2img_sdxl_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: str | int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
if precision != "fp16":
|
||||
print("currently we support fp16 for SDXL")
|
||||
precision = "fp16"
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
args.steps = steps
|
||||
args.scheduler = scheduler
|
||||
args.ondemand = ondemand
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
args.custom_vae = ""
|
||||
|
||||
# .safetensor or .chkpt on the custom model path
|
||||
if model_id in get_custom_model_files():
|
||||
args.ckpt_loc = get_custom_model_pathfile(model_id)
|
||||
# civitai download
|
||||
elif "civitai" in model_id:
|
||||
args.ckpt_loc = model_id
|
||||
# either predefined or huggingface
|
||||
else:
|
||||
args.hf_model_id = model_id
|
||||
|
||||
if custom_vae:
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
new_config_obj = Config(
|
||||
"txt2img_sdxl",
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
precision,
|
||||
batch_size,
|
||||
max_length,
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
stencils=None,
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
not global_obj.get_sd_obj()
|
||||
or global_obj.get_cfg_obj() != new_config_obj
|
||||
):
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.precision = precision
|
||||
args.batch_count = batch_count
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
args.width = width
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
|
||||
args.iree_metal_target_platform = init_iree_metal_target_platform
|
||||
args.use_tuned = init_use_tuned
|
||||
args.import_mlir = init_import_mlir
|
||||
args.img_path = None
|
||||
set_init_device_flags()
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
)
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||
if global_obj.get_cfg_obj().ondemand:
|
||||
print("Running txt2img in memory efficient mode.")
|
||||
global_obj.set_sd_obj(
|
||||
Text2ImageSDXLPipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
model_id=args.hf_model_id,
|
||||
ckpt_loc=args.ckpt_loc,
|
||||
precision=precision,
|
||||
max_length=max_length,
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
use_tuned=args.use_tuned,
|
||||
custom_vae=args.custom_vae,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
use_quantize=args.use_quantize,
|
||||
ondemand=global_obj.get_cfg_obj().ondemand,
|
||||
)
|
||||
)
|
||||
|
||||
global_obj.set_sd_scheduler(scheduler)
|
||||
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
text_output = ""
|
||||
try:
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
except TypeError as error:
|
||||
raise gr.Error(str(error)) from None
|
||||
|
||||
for current_batch in range(batch_count):
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(
|
||||
seeds[: current_batch + 1], device
|
||||
)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], seeds[current_batch])
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Text-to-Image-SDXL",
|
||||
current_batch + 1,
|
||||
batch_count,
|
||||
batch_size,
|
||||
)
|
||||
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
theme = gr.themes.Glass(
|
||||
primary_hue="slate",
|
||||
secondary_hue="gray",
|
||||
)
|
||||
|
||||
with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
width=150,
|
||||
height=50,
|
||||
)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=10):
|
||||
with gr.Row():
|
||||
t2i_sdxl_model_info = f"Custom Model Path: {str(get_custom_model_path())}"
|
||||
txt2img_sdxl_custom_model = gr.Dropdown(
|
||||
label=f"Models",
|
||||
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
choices=predefined_sdxl_models
|
||||
+ get_custom_model_files(
|
||||
custom_checkpoint_type="sdxl"
|
||||
),
|
||||
allow_custom_value=True,
|
||||
scale=2,
|
||||
)
|
||||
t2i_sdxl_vae_info = (
|
||||
str(get_custom_model_path("vae"))
|
||||
).replace("\\", "\n\\")
|
||||
t2i_sdxl_vae_info = (
|
||||
f"VAE Path: {t2i_sdxl_vae_info}"
|
||||
)
|
||||
custom_vae = gr.Dropdown(
|
||||
label=f"VAE Models",
|
||||
info=t2i_sdxl_vae_info,
|
||||
elem_id="custom_model",
|
||||
value="None",
|
||||
choices=[
|
||||
None,
|
||||
"madebyollin/sdxl-vae-fp16-fix",
|
||||
]
|
||||
+ get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
scale=1,
|
||||
)
|
||||
with gr.Column(scale=1, min_width=170):
|
||||
txt2img_sdxl_png_info_img = gr.Image(
|
||||
label="Import PNG info",
|
||||
elem_id="txt2img_prompt_image",
|
||||
type="pil",
|
||||
visible=True,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
txt2img_sdxl_autogen = gr.Checkbox(
|
||||
label="Auto-Generate Images",
|
||||
value=False,
|
||||
visible=False,
|
||||
)
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
lines=2,
|
||||
elem_id="prompt_box",
|
||||
show_copy_button=True,
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
lines=2,
|
||||
elem_id="negative_prompt_box",
|
||||
show_copy_button=True,
|
||||
)
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
t2i_sdxl_lora_info = (
|
||||
str(get_custom_model_path("lora"))
|
||||
).replace("\\", "\n\\")
|
||||
t2i_sdxl_lora_info = f"LoRA Path: {t2i_sdxl_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=t2i_sdxl_lora_info,
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
with gr.Row():
|
||||
lora_tags = gr.HTML(
|
||||
value="<div><i>No LoRA selected</i></div>",
|
||||
elem_classes="lora-tags",
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
elem_id="scheduler",
|
||||
label="Scheduler",
|
||||
value=args.scheduler,
|
||||
choices=[
|
||||
"DDIM",
|
||||
"EulerAncestralDiscrete",
|
||||
"EulerDiscrete",
|
||||
"LCMScheduler",
|
||||
],
|
||||
allow_custom_value=False,
|
||||
visible=True,
|
||||
)
|
||||
with gr.Column():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
label="Save prompt information to PNG",
|
||||
value=args.write_metadata_to_png,
|
||||
interactive=True,
|
||||
)
|
||||
save_metadata_to_json = gr.Checkbox(
|
||||
label="Save prompt information to JSON file",
|
||||
value=args.save_metadata_to_json,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
512,
|
||||
1024,
|
||||
value=1024,
|
||||
step=256,
|
||||
label="Height",
|
||||
visible=True,
|
||||
interactive=True,
|
||||
)
|
||||
width = gr.Slider(
|
||||
512,
|
||||
1024,
|
||||
value=1024,
|
||||
step=256,
|
||||
label="Width",
|
||||
visible=True,
|
||||
interactive=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="fp16",
|
||||
choices=[
|
||||
"fp16",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
value=77,
|
||||
choices=[
|
||||
64,
|
||||
77,
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
steps = gr.Slider(
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="Guidance Scale",
|
||||
)
|
||||
ondemand = gr.Checkbox(
|
||||
value=args.ondemand,
|
||||
label="Low VRAM",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=args.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
repeatable_seeds = gr.Checkbox(
|
||||
args.repeatable_seeds,
|
||||
label="Repeatable Seeds",
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
seed = gr.Textbox(
|
||||
value=args.seed,
|
||||
label="Seed",
|
||||
info="An integer or a JSON list of integers, -1 for random",
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Accordion(label="Prompt Examples!", open=False):
|
||||
ex = gr.Examples(
|
||||
examples=prompt_examples,
|
||||
inputs=prompt,
|
||||
cache_examples=False,
|
||||
elem_id="prompt_examples",
|
||||
)
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
txt2img_sdxl_gallery = gr.Gallery(
|
||||
label="Generated images",
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
columns=[2],
|
||||
object_fit="scale_down",
|
||||
)
|
||||
std_output = gr.Textbox(
|
||||
value=f"{t2i_sdxl_model_info}\n"
|
||||
f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
txt2img_sdxl_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
txt2img_sdxl_sendto_img2img = gr.Button(
|
||||
value="Send To Img2Img",
|
||||
visible=False,
|
||||
)
|
||||
txt2img_sdxl_sendto_inpaint = gr.Button(
|
||||
value="Send To Inpaint",
|
||||
visible=False,
|
||||
)
|
||||
txt2img_sdxl_sendto_outpaint = gr.Button(
|
||||
value="Send To Outpaint",
|
||||
visible=False,
|
||||
)
|
||||
txt2img_sdxl_sendto_upscaler = gr.Button(
|
||||
value="Send To Upscaler",
|
||||
visible=False,
|
||||
)
|
||||
|
||||
kwargs = dict(
|
||||
fn=txt2img_sdxl_inf,
|
||||
inputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
txt2img_sdxl_custom_model,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
],
|
||||
outputs=[txt2img_sdxl_gallery, std_output, txt2img_sdxl_status],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
queue=True,
|
||||
)
|
||||
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Text-to-Image-SDXL", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=txt2img_sdxl_status,
|
||||
concurrency_limit=1,
|
||||
)
|
||||
|
||||
def autogen_changed(checked):
|
||||
if checked:
|
||||
args.autogen = True
|
||||
else:
|
||||
args.autogen = False
|
||||
|
||||
def check_last_input(prompt):
|
||||
if not prompt.endswith(" "):
|
||||
return True
|
||||
elif not args.autogen:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
auto_gen_kwargs = dict(
|
||||
fn=check_last_input,
|
||||
inputs=[negative_prompt],
|
||||
outputs=[txt2img_sdxl_status],
|
||||
concurrency_limit=1,
|
||||
)
|
||||
|
||||
txt2img_sdxl_autogen.change(
|
||||
fn=autogen_changed,
|
||||
inputs=[txt2img_sdxl_autogen],
|
||||
outputs=None,
|
||||
)
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[
|
||||
prompt_submit,
|
||||
neg_prompt_submit,
|
||||
generate_click,
|
||||
],
|
||||
)
|
||||
|
||||
txt2img_sdxl_png_info_img.change(
|
||||
fn=import_png_metadata,
|
||||
inputs=[
|
||||
txt2img_sdxl_png_info_img,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
scheduler,
|
||||
guidance_scale,
|
||||
seed,
|
||||
width,
|
||||
height,
|
||||
txt2img_sdxl_custom_model,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
custom_vae,
|
||||
],
|
||||
outputs=[
|
||||
txt2img_sdxl_png_info_img,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
scheduler,
|
||||
guidance_scale,
|
||||
seed,
|
||||
width,
|
||||
height,
|
||||
txt2img_sdxl_custom_model,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
custom_vae,
|
||||
],
|
||||
)
|
||||
txt2img_sdxl_custom_model.change(
|
||||
fn=set_model_default_configs,
|
||||
inputs=[
|
||||
txt2img_sdxl_custom_model,
|
||||
],
|
||||
outputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
scheduler,
|
||||
guidance_scale,
|
||||
width,
|
||||
height,
|
||||
custom_vae,
|
||||
txt2img_sdxl_autogen,
|
||||
],
|
||||
)
|
||||
lora_weights.change(
|
||||
fn=lora_changed,
|
||||
inputs=[lora_weights],
|
||||
outputs=[lora_tags],
|
||||
queue=True,
|
||||
)
|
||||
@@ -4,18 +4,19 @@ import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from fastapi.exceptions import HTTPException
|
||||
from math import ceil
|
||||
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
get_custom_model_path,
|
||||
get_custom_model_files,
|
||||
scheduler_list,
|
||||
scheduler_list_cpu_only,
|
||||
predefined_models,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
|
||||
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
from apps.stable_diffusion.src import (
|
||||
@@ -26,10 +27,12 @@ from apps.stable_diffusion.src import (
|
||||
utils,
|
||||
save_output_img,
|
||||
prompt_examples,
|
||||
Image2ImagePipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generation_text_info,
|
||||
resampler_list,
|
||||
)
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
@@ -50,8 +53,7 @@ def txt2img_inf(
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
@@ -62,6 +64,11 @@ def txt2img_inf(
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
use_hiresfix: bool,
|
||||
hiresfix_height: int,
|
||||
hiresfix_width: int,
|
||||
hiresfix_strength: float,
|
||||
resample_type: str,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
@@ -84,21 +91,17 @@ def txt2img_inf(
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
args.custom_vae = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, "
|
||||
"both must not be empty",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
|
||||
# .safetensor or .chkpt on the custom model path
|
||||
if model_id in get_custom_model_files():
|
||||
args.ckpt_loc = get_custom_model_pathfile(model_id)
|
||||
# civitai download
|
||||
elif "civitai" in model_id:
|
||||
args.ckpt_loc = model_id
|
||||
# either predefined or huggingface
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
args.hf_model_id = model_id
|
||||
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
@@ -123,7 +126,7 @@ def txt2img_inf(
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
stencils=[],
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
@@ -138,6 +141,11 @@ def txt2img_inf(
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
args.width = width
|
||||
args.use_hiresfix = use_hiresfix
|
||||
args.hiresfix_height = hiresfix_height
|
||||
args.hiresfix_width = hiresfix_width
|
||||
args.hiresfix_strength = hiresfix_strength
|
||||
args.resample_type = resample_type
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
|
||||
args.iree_metal_target_platform = init_iree_metal_target_platform
|
||||
@@ -200,6 +208,82 @@ def txt2img_inf(
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
# TODO: allow user to save original image
|
||||
# TODO: add option to let user keep both pipelines loaded, and unload
|
||||
# either at will
|
||||
# TODO: add custom step value slider
|
||||
# TODO: add option to use secondary model for the img2img pass
|
||||
if use_hiresfix is True:
|
||||
new_config_obj = Config(
|
||||
"img2img",
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
precision,
|
||||
1,
|
||||
max_length,
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
stencils=[],
|
||||
ondemand=ondemand,
|
||||
)
|
||||
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
set_init_device_flags()
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-1-base"
|
||||
)
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(args.scheduler)
|
||||
|
||||
global_obj.set_sd_obj(
|
||||
Image2ImagePipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
1,
|
||||
hiresfix_height,
|
||||
hiresfix_width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
|
||||
global_obj.set_sd_scheduler(args.scheduler)
|
||||
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
out_imgs[0],
|
||||
batch_size,
|
||||
hiresfix_height,
|
||||
hiresfix_width,
|
||||
ceil(steps / hiresfix_strength),
|
||||
hiresfix_strength,
|
||||
guidance_scale,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
stencils=[],
|
||||
control_mode=None,
|
||||
resample_type=resample_type,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(
|
||||
seeds[: current_batch + 1], device
|
||||
@@ -219,71 +303,17 @@ def txt2img_inf(
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def encode_pil_to_base64(images):
|
||||
encoded_imgs = []
|
||||
for image in images:
|
||||
with BytesIO() as output_bytes:
|
||||
if args.output_img_format.lower() == "png":
|
||||
image.save(output_bytes, format="PNG")
|
||||
|
||||
elif args.output_img_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Invalid image format"
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
encoded_imgs.append(base64.b64encode(bytes_data))
|
||||
return encoded_imgs
|
||||
|
||||
|
||||
# Text2Img Rest API.
|
||||
def txt2img_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, '
|
||||
f'Negative Prompt: {InputData["negative_prompt"]}, '
|
||||
f'Seed: {InputData["seed"]}.'
|
||||
def resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
base_path = getattr(
|
||||
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
|
||||
)
|
||||
res = txt2img_inf(
|
||||
InputData["prompt"],
|
||||
InputData["negative_prompt"],
|
||||
InputData["height"],
|
||||
InputData["width"],
|
||||
InputData["steps"],
|
||||
InputData["cfg_scale"],
|
||||
InputData["seed"],
|
||||
batch_count=1,
|
||||
batch_size=1,
|
||||
scheduler="EulerDiscrete",
|
||||
custom_model="None",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
max_length=64,
|
||||
save_metadata_to_json=False,
|
||||
save_metadata_to_png=False,
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
# Convert Generator to Subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
"info": res[1],
|
||||
}
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
dark_theme = resource_path("ui/css/sd_dark_theme.css")
|
||||
|
||||
with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
@@ -302,32 +332,18 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=10):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
t2i_model_info = (
|
||||
str(get_custom_model_path())
|
||||
).replace("\\", "\n\\")
|
||||
t2i_model_info = (
|
||||
f"Custom Model Path: {t2i_model_info}"
|
||||
)
|
||||
t2i_model_info = f"Custom Model Path: {str(get_custom_model_path())}"
|
||||
txt2img_custom_model = gr.Dropdown(
|
||||
label=f"Models",
|
||||
info=t2i_model_info,
|
||||
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
choices=get_custom_model_files()
|
||||
+ predefined_models,
|
||||
)
|
||||
txt2img_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the dropdown "
|
||||
"on the left and enter model ID here.",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL.",
|
||||
lines=3,
|
||||
allow_custom_value=True,
|
||||
scale=2,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
t2i_vae_info = (
|
||||
@@ -343,13 +359,14 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
else "None",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
scale=1,
|
||||
)
|
||||
with gr.Column(scale=1, min_width=170):
|
||||
txt2img_png_info_img = gr.Image(
|
||||
label="Import PNG info",
|
||||
elem_id="txt2img_prompt_image",
|
||||
type="pil",
|
||||
tool="None",
|
||||
visible=True,
|
||||
)
|
||||
|
||||
@@ -360,6 +377,11 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
lines=2,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
# TODO: coming soon
|
||||
autogen = gr.Checkbox(
|
||||
label="Continuous Generation",
|
||||
visible=False,
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
@@ -379,6 +401,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
@@ -390,6 +413,11 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
with gr.Row():
|
||||
lora_tags = gr.HTML(
|
||||
value="<div><i>No LoRA selected</i></div>",
|
||||
elem_classes="lora-tags",
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
@@ -397,6 +425,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
label="Scheduler",
|
||||
value=args.scheduler,
|
||||
choices=scheduler_list,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Column():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -483,6 +512,41 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
args.repeatable_seeds,
|
||||
label="Repeatable Seeds",
|
||||
)
|
||||
with gr.Accordion(label="Hires Fix Options", open=False):
|
||||
with gr.Group():
|
||||
with gr.Row():
|
||||
use_hiresfix = gr.Checkbox(
|
||||
value=args.use_hiresfix,
|
||||
label="Use Hires Fix",
|
||||
interactive=True,
|
||||
)
|
||||
resample_type = gr.Dropdown(
|
||||
value=args.resample_type,
|
||||
choices=resampler_list,
|
||||
label="Resample Type",
|
||||
allow_custom_value=False,
|
||||
)
|
||||
hiresfix_height = gr.Slider(
|
||||
384,
|
||||
768,
|
||||
value=args.hiresfix_height,
|
||||
step=8,
|
||||
label="Hires Fix Height",
|
||||
)
|
||||
hiresfix_width = gr.Slider(
|
||||
384,
|
||||
768,
|
||||
value=args.hiresfix_width,
|
||||
step=8,
|
||||
label="Hires Fix Width",
|
||||
)
|
||||
hiresfix_strength = gr.Slider(
|
||||
0,
|
||||
1,
|
||||
value=args.hiresfix_strength,
|
||||
step=0.01,
|
||||
label="Hires Fix Denoising Strength",
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Textbox(
|
||||
value=args.seed,
|
||||
@@ -494,17 +558,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
with gr.Accordion(label="Prompt Examples!", open=False):
|
||||
ex = gr.Examples(
|
||||
examples=prompt_examples,
|
||||
@@ -523,13 +578,26 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
object_fit="contain",
|
||||
)
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at "
|
||||
value=f"{t2i_model_info}\n"
|
||||
f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
txt2img_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
blank_thing_for_row = None
|
||||
with gr.Row():
|
||||
txt2img_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
txt2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
@@ -554,7 +622,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
batch_size,
|
||||
scheduler,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
@@ -565,6 +632,11 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
use_hiresfix,
|
||||
hiresfix_height,
|
||||
hiresfix_width,
|
||||
hiresfix_strength,
|
||||
resample_type,
|
||||
],
|
||||
outputs=[txt2img_gallery, std_output, txt2img_status],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
@@ -599,7 +671,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
width,
|
||||
height,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
custom_vae,
|
||||
@@ -615,9 +686,35 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
width,
|
||||
height,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
custom_vae,
|
||||
],
|
||||
)
|
||||
|
||||
# SharkEulerDiscrete doesn't work with img2img which hires_fix uses
|
||||
def set_compatible_schedulers(hires_fix_selected):
|
||||
if hires_fix_selected:
|
||||
return gr.Dropdown(
|
||||
choices=scheduler_list_cpu_only,
|
||||
value="DEISMultistep",
|
||||
)
|
||||
else:
|
||||
return gr.Dropdown(
|
||||
choices=scheduler_list,
|
||||
value="SharkEulerDiscrete",
|
||||
)
|
||||
|
||||
use_hiresfix.change(
|
||||
fn=set_compatible_schedulers,
|
||||
inputs=[use_hiresfix],
|
||||
outputs=[scheduler],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
lora_weights.change(
|
||||
fn=lora_changed,
|
||||
inputs=[lora_weights],
|
||||
outputs=[lora_tags],
|
||||
queue=True,
|
||||
)
|
||||
|
||||
@@ -3,9 +3,7 @@ import torch
|
||||
import time
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -15,6 +13,7 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
predefined_upscaler_models,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
@@ -46,8 +45,7 @@ def upscaler_inf(
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
@@ -85,21 +83,17 @@ def upscaler_inf(
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
args.custom_vae = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, "
|
||||
"both must not be empty.",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
|
||||
# .safetensor or .chkpt on the custom model path
|
||||
if model_id in get_custom_model_files(custom_checkpoint_type="upscaler"):
|
||||
args.ckpt_loc = get_custom_model_pathfile(model_id)
|
||||
# civitai download
|
||||
elif "civitai" in model_id:
|
||||
args.ckpt_loc = model_id
|
||||
# either predefined or huggingface
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
args.hf_model_id = model_id
|
||||
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
@@ -126,7 +120,7 @@ def upscaler_inf(
|
||||
args.width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
stencils=[],
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
@@ -252,83 +246,6 @@ def upscaler_inf(
|
||||
yield generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
|
||||
try:
|
||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
return image
|
||||
except Exception as err:
|
||||
print(err)
|
||||
raise HTTPException(status_code=500, detail="Invalid encoded image")
|
||||
|
||||
|
||||
def encode_pil_to_base64(images):
|
||||
encoded_imgs = []
|
||||
for image in images:
|
||||
with BytesIO() as output_bytes:
|
||||
if args.output_img_format.lower() == "png":
|
||||
image.save(output_bytes, format="PNG")
|
||||
|
||||
elif args.output_img_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Invalid image format"
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
encoded_imgs.append(base64.b64encode(bytes_data))
|
||||
return encoded_imgs
|
||||
|
||||
|
||||
# Upscaler Rest API.
|
||||
def upscaler_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, '
|
||||
f'Negative Prompt: {InputData["negative_prompt"]}, '
|
||||
f'Seed: {InputData["seed"]}'
|
||||
)
|
||||
init_image = decode_base64_to_image(InputData["init_images"][0])
|
||||
res = upscaler_inf(
|
||||
InputData["prompt"],
|
||||
InputData["negative_prompt"],
|
||||
init_image,
|
||||
InputData["height"],
|
||||
InputData["width"],
|
||||
InputData["steps"],
|
||||
InputData["noise_level"],
|
||||
InputData["cfg_scale"],
|
||||
InputData["seed"],
|
||||
batch_count=1,
|
||||
batch_size=1,
|
||||
scheduler="EulerDiscrete",
|
||||
custom_model="None",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
max_length=64,
|
||||
save_metadata_to_json=False,
|
||||
save_metadata_to_png=False,
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
# Converts generator type to subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
"info": res[1],
|
||||
}
|
||||
|
||||
|
||||
with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
@@ -346,36 +263,22 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
upscaler_model_info = (
|
||||
str(get_custom_model_path())
|
||||
).replace("\\", "\n\\")
|
||||
upscaler_model_info = (
|
||||
f"Custom Model Path: {upscaler_model_info}"
|
||||
f"Custom Model Path: {str(get_custom_model_path())}"
|
||||
)
|
||||
upscaler_custom_model = gr.Dropdown(
|
||||
label=f"Models",
|
||||
info=upscaler_model_info,
|
||||
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "stabilityai/stable-diffusion-x4-upscaler",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files(
|
||||
choices=get_custom_model_files(
|
||||
custom_checkpoint_type="upscaler"
|
||||
)
|
||||
+ predefined_upscaler_models,
|
||||
)
|
||||
upscaler_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown "
|
||||
"on the left and enter model ID here "
|
||||
"e.g: SG161222/Realistic_Vision_V1.3, "
|
||||
"https://civitai.com/api/download/models/15236",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL",
|
||||
lines=3,
|
||||
allow_custom_value=True,
|
||||
scale=2,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
upscaler_vae_info = (
|
||||
@@ -390,6 +293,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
scale=1,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
@@ -425,6 +330,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
@@ -436,6 +342,11 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
with gr.Row():
|
||||
lora_tags = gr.HTML(
|
||||
value="<div><i>No LoRA selected</i></div>",
|
||||
elem_classes="lora-tags",
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
@@ -443,6 +354,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
label="Scheduler",
|
||||
value="DDIM",
|
||||
choices=scheduler_list_cpu_only,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -547,17 +459,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -569,14 +472,26 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
object_fit="contain",
|
||||
)
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at "
|
||||
value=f"{upscaler_model_info}\n"
|
||||
f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
lines=2,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
upscaler_status = gr.Textbox(visible=False)
|
||||
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
blank_thing_for_row = None
|
||||
with gr.Row():
|
||||
upscaler_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
upscaler_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
@@ -600,7 +515,6 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
batch_size,
|
||||
scheduler,
|
||||
upscaler_custom_model,
|
||||
upscaler_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
@@ -630,3 +544,10 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
lora_weights.change(
|
||||
fn=lora_changed,
|
||||
inputs=[lora_weights],
|
||||
outputs=[lora_tags],
|
||||
queue=True,
|
||||
)
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
import os
|
||||
import sys
|
||||
from apps.stable_diffusion.src import get_available_devices
|
||||
import glob
|
||||
import math
|
||||
import json
|
||||
import safetensors
|
||||
import gradio as gr
|
||||
|
||||
from pathlib import Path
|
||||
from apps.stable_diffusion.src import args
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
|
||||
from apps.stable_diffusion.src import get_available_devices
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
@@ -24,8 +31,17 @@ class Config:
|
||||
width: int
|
||||
device: str
|
||||
use_lora: str
|
||||
use_stencil: str
|
||||
ondemand: str
|
||||
stencils: list[str]
|
||||
ondemand: str # should this be expecting a bool instead?
|
||||
|
||||
|
||||
class HSLHue(IntEnum):
|
||||
RED = 0
|
||||
YELLOW = 60
|
||||
GREEN = 120
|
||||
CYAN = 180
|
||||
BLUE = 240
|
||||
MAGENTA = 300
|
||||
|
||||
|
||||
custom_model_filetypes = (
|
||||
@@ -49,9 +65,11 @@ scheduler_list_cpu_only = [
|
||||
"DPMSolverSinglestep",
|
||||
"DDPM",
|
||||
"HeunDiscrete",
|
||||
"LCMScheduler",
|
||||
]
|
||||
scheduler_list = scheduler_list_cpu_only + [
|
||||
"SharkEulerDiscrete",
|
||||
"SharkEulerAncestralDiscrete",
|
||||
]
|
||||
|
||||
predefined_models = [
|
||||
@@ -72,6 +90,10 @@ predefined_paint_models = [
|
||||
predefined_upscaler_models = [
|
||||
"stabilityai/stable-diffusion-x4-upscaler",
|
||||
]
|
||||
predefined_sdxl_models = [
|
||||
"stabilityai/sdxl-turbo",
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
]
|
||||
|
||||
|
||||
def resource_path(relative_path):
|
||||
@@ -125,6 +147,12 @@ def get_custom_model_files(model="models", custom_checkpoint_type=""):
|
||||
)
|
||||
]
|
||||
match custom_checkpoint_type:
|
||||
case "sdxl":
|
||||
files = [
|
||||
val
|
||||
for val in files
|
||||
if any(x in val for x in ["XL", "xl", "Xl"])
|
||||
]
|
||||
case "inpainting":
|
||||
files = [
|
||||
val
|
||||
@@ -161,6 +189,69 @@ def get_custom_vae_or_lora_weights(weights, hf_id, model):
|
||||
return use_weight
|
||||
|
||||
|
||||
def hsl_color(alpha: float, start, end):
|
||||
b = (end - start) * (alpha if alpha > 0 else 0)
|
||||
result = b + start
|
||||
|
||||
# Return a CSS HSL string
|
||||
return f"hsl({math.floor(result)}, 80%, 35%)"
|
||||
|
||||
|
||||
def get_lora_metadata(lora_filename):
|
||||
# get the metadata from the file
|
||||
filename = get_custom_model_pathfile(lora_filename, "lora")
|
||||
with safetensors.safe_open(filename, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata()
|
||||
|
||||
# guard clause for if there isn't any metadata
|
||||
if not metadata:
|
||||
return None
|
||||
|
||||
# metadata is a dictionary of strings, the values of the keys we're
|
||||
# interested in are actually json, and need to be loaded as such
|
||||
tag_frequencies = json.loads(metadata.get("ss_tag_frequency", str("{}")))
|
||||
dataset_dirs = json.loads(metadata.get("ss_dataset_dirs", str("{}")))
|
||||
tag_dirs = [dir for dir in tag_frequencies.keys()]
|
||||
|
||||
# gather the tag frequency information for all the datasets trained
|
||||
all_frequencies = {}
|
||||
for dataset in tag_dirs:
|
||||
frequencies = sorted(
|
||||
[entry for entry in tag_frequencies[dataset].items()],
|
||||
reverse=True,
|
||||
key=lambda x: x[1],
|
||||
)
|
||||
|
||||
# get a figure for the total number of images processed for this dataset
|
||||
# either then number actually listed or in its dataset_dir entry or
|
||||
# the highest frequency's number if that doesn't exist
|
||||
img_count = dataset_dirs.get(dir, {}).get(
|
||||
"img_count", frequencies[0][1]
|
||||
)
|
||||
|
||||
# add the dataset frequencies to the overall frequencies replacing the
|
||||
# frequency counts on the tags with a percentage/ratio
|
||||
all_frequencies.update(
|
||||
[(entry[0], entry[1] / img_count) for entry in frequencies]
|
||||
)
|
||||
|
||||
trained_model_id = " ".join(
|
||||
[
|
||||
metadata.get("ss_sd_model_hash", ""),
|
||||
metadata.get("ss_sd_model_name", ""),
|
||||
metadata.get("ss_base_model_version", ""),
|
||||
]
|
||||
).strip()
|
||||
|
||||
# return the topmost <count> of all frequencies in all datasets
|
||||
return {
|
||||
"model": trained_model_id,
|
||||
"frequencies": sorted(
|
||||
all_frequencies.items(), reverse=True, key=lambda x: x[1]
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def cancel_sd():
|
||||
# Try catch it, as gc can delete global_obj.sd_obj while switching model
|
||||
try:
|
||||
@@ -169,5 +260,99 @@ def cancel_sd():
|
||||
pass
|
||||
|
||||
|
||||
def set_model_default_configs(model_ckpt_or_id, jsonconfig=None):
|
||||
import gradio as gr
|
||||
|
||||
config_modelname = default_config_exists(model_ckpt_or_id)
|
||||
if jsonconfig:
|
||||
return get_config_from_json(jsonconfig)
|
||||
elif config_modelname:
|
||||
return default_configs[config_modelname]
|
||||
# TODO: Use HF metadata to setup pipeline if available
|
||||
# elif is_valid_hf_id(model_ckpt_or_id):
|
||||
# return get_HF_default_configs(model_ckpt_or_id)
|
||||
else:
|
||||
# We don't have default metadata to setup a good config. Do not change configs.
|
||||
return [
|
||||
gr.Textbox(label="Prompt", interactive=True, visible=True),
|
||||
gr.Textbox(label="Negative Prompt", interactive=True),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.Checkbox(
|
||||
label="Auto-Generate",
|
||||
visible=False,
|
||||
interactive=False,
|
||||
value=False,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_config_from_json(model_ckpt_or_id, jsonconfig):
|
||||
# TODO: make this work properly. It is currently not user-exposed.
|
||||
cfgdata = json.load(jsonconfig)
|
||||
return [
|
||||
cfgdata["prompt_box_behavior"],
|
||||
cfgdata["neg_prompt_box_behavior"],
|
||||
cfgdata["steps"],
|
||||
cfgdata["scheduler"],
|
||||
cfgdata["guidance_scale"],
|
||||
cfgdata["width"],
|
||||
cfgdata["height"],
|
||||
cfgdata["custom_vae"],
|
||||
]
|
||||
|
||||
|
||||
def default_config_exists(model_ckpt_or_id):
|
||||
if model_ckpt_or_id in [
|
||||
"stabilityai/sdxl-turbo",
|
||||
"stabilityai/stable_diffusion-xl-base-1.0",
|
||||
]:
|
||||
return model_ckpt_or_id
|
||||
elif "turbo" in model_ckpt_or_id.lower():
|
||||
return "stabilityai/sdxl-turbo"
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
default_configs = {
|
||||
"stabilityai/sdxl-turbo": [
|
||||
gr.Textbox(label="", interactive=False, value=None, visible=False),
|
||||
gr.Textbox(
|
||||
label="Prompt",
|
||||
value="masterpiece, a graceful shark leaping out of the water to catch a fish, eclipsing the sunset, epic, rays of light, silhouette",
|
||||
),
|
||||
gr.Slider(0, 10, value=2),
|
||||
gr.Dropdown(value="EulerAncestralDiscrete"),
|
||||
gr.Slider(0, value=0),
|
||||
512,
|
||||
512,
|
||||
"madebyollin/sdxl-vae-fp16-fix",
|
||||
gr.Checkbox(
|
||||
label="Auto-Generate", visible=False, interactive=True, value=False
|
||||
),
|
||||
],
|
||||
"stabilityai/stable-diffusion-xl-base-1.0": [
|
||||
gr.Textbox(label="Prompt", interactive=True, visible=True),
|
||||
gr.Textbox(label="Negative Prompt", interactive=True),
|
||||
40,
|
||||
"EulerDiscrete",
|
||||
7.5,
|
||||
gr.Slider(value=768, interactive=True),
|
||||
gr.Slider(value=768, interactive=True),
|
||||
"madebyollin/sdxl-vae-fp16-fix",
|
||||
gr.Checkbox(
|
||||
label="Auto-Generate",
|
||||
visible=False,
|
||||
interactive=False,
|
||||
value=False,
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
nodlogo_loc = resource_path("logos/nod-logo.png")
|
||||
nodicon_loc = resource_path("logos/nod-icon.png")
|
||||
available_devices = get_available_devices()
|
||||
|
||||
105
apps/stable_diffusion/web/utils/app.py
Normal file
105
apps/stable_diffusion/web/utils/app.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import os
|
||||
import sys
|
||||
import webview
|
||||
import webview.util
|
||||
import socket
|
||||
|
||||
from contextlib import closing
|
||||
from multiprocessing import Process
|
||||
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
|
||||
def webview2_installed():
|
||||
if sys.platform != "win32":
|
||||
return False
|
||||
|
||||
# On windows we want to ensure we have MS webview2 available so we don't fall back
|
||||
# to MSHTML (aka ye olde Internet Explorer) which is deprecated by pywebview, and
|
||||
# apparently causes SHARK not to load in properly.
|
||||
|
||||
# Checking these registry entries is how Microsoft says to detect a webview2 installation:
|
||||
# https://learn.microsoft.com/en-us/microsoft-edge/webview2/concepts/distribution
|
||||
import winreg
|
||||
|
||||
path = r"SOFTWARE\WOW6432Node\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}"
|
||||
|
||||
# only way can find if a registry entry even exists is to try and open it
|
||||
try:
|
||||
# check for an all user install
|
||||
with winreg.OpenKey(
|
||||
winreg.HKEY_LOCAL_MACHINE,
|
||||
path,
|
||||
0,
|
||||
winreg.KEY_QUERY_VALUE | winreg.KEY_WOW64_64KEY,
|
||||
) as registry_key:
|
||||
value, type = winreg.QueryValueEx(registry_key, "pv")
|
||||
|
||||
# if it didn't exist, we want to continue on...
|
||||
except WindowsError:
|
||||
try:
|
||||
# ...to check for a current user install
|
||||
with winreg.OpenKey(
|
||||
winreg.HKEY_CURRENT_USER,
|
||||
path,
|
||||
0,
|
||||
winreg.KEY_QUERY_VALUE | winreg.KEY_WOW64_64KEY,
|
||||
) as registry_key:
|
||||
value, type = winreg.QueryValueEx(registry_key, "pv")
|
||||
except WindowsError:
|
||||
value = None
|
||||
finally:
|
||||
return (value is not None) and value != "" and value != "0.0.0.0"
|
||||
|
||||
|
||||
def window(address):
|
||||
from tkinter import Tk
|
||||
|
||||
window = Tk()
|
||||
|
||||
# get screen width and height of display and make it more reasonably
|
||||
# sized as we aren't making it full-screen or maximized
|
||||
width = int(window.winfo_screenwidth() * 0.81)
|
||||
height = int(window.winfo_screenheight() * 0.91)
|
||||
webview.create_window(
|
||||
"SHARK AI Studio",
|
||||
url=address,
|
||||
width=width,
|
||||
height=height,
|
||||
text_select=True,
|
||||
)
|
||||
webview.start(private_mode=False, storage_path=os.getcwd())
|
||||
|
||||
|
||||
def usable_port():
|
||||
# Make sure we can actually use the port given in args.server_port. If
|
||||
# not ask the OS for a port and return that as our port to use.
|
||||
|
||||
port = args.server_port
|
||||
|
||||
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
|
||||
try:
|
||||
sock.bind(("0.0.0.0", port))
|
||||
except OSError:
|
||||
with closing(
|
||||
socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
) as sock:
|
||||
sock.bind(("0.0.0.0", 0))
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
return sock.getsockname()[1]
|
||||
|
||||
return port
|
||||
|
||||
|
||||
def launch(port):
|
||||
# setup to launch as an app if app mode has been requested and we're able
|
||||
# to do it, answering whether we succeeded.
|
||||
if args.ui == "app" and (sys.platform != "win32" or webview2_installed()):
|
||||
try:
|
||||
t = Process(target=window, args=[f"http://localhost:{port}"])
|
||||
t.start()
|
||||
return True
|
||||
except webview.util.WebViewException:
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
@@ -149,7 +149,6 @@ def import_png_metadata(
|
||||
width,
|
||||
height,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
custom_lora,
|
||||
hf_lora_id,
|
||||
custom_vae,
|
||||
@@ -175,10 +174,8 @@ def import_png_metadata(
|
||||
|
||||
if "Model" in metadata and png_custom_model:
|
||||
custom_model = png_custom_model
|
||||
hf_model_id = ""
|
||||
if "Model" in metadata and png_hf_model_id:
|
||||
custom_model = "None"
|
||||
hf_model_id = png_hf_model_id
|
||||
elif "Model" in metadata and png_hf_model_id:
|
||||
custom_model = png_hf_model_id
|
||||
|
||||
if "LoRA" in metadata and lora_custom_model:
|
||||
custom_lora = lora_custom_model
|
||||
@@ -217,7 +214,6 @@ def import_png_metadata(
|
||||
width,
|
||||
height,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
custom_lora,
|
||||
hf_lora_id,
|
||||
custom_vae,
|
||||
|
||||
@@ -5,11 +5,25 @@ from time import time
|
||||
shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")
|
||||
|
||||
|
||||
def config_gradio_tmp_imgs_folder():
|
||||
# create shark_tmp if it does not exist
|
||||
if not os.path.exists(shark_tmp):
|
||||
os.mkdir(shark_tmp)
|
||||
def clear_tmp_mlir():
|
||||
cleanup_start = time()
|
||||
print(
|
||||
"Clearing .mlir temporary files from a prior run. This may take some time..."
|
||||
)
|
||||
mlir_files = [
|
||||
filename
|
||||
for filename in os.listdir(shark_tmp)
|
||||
if os.path.isfile(os.path.join(shark_tmp, filename))
|
||||
and filename.endswith(".mlir")
|
||||
]
|
||||
for filename in mlir_files:
|
||||
os.remove(shark_tmp + filename)
|
||||
print(
|
||||
f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds."
|
||||
)
|
||||
|
||||
|
||||
def clear_tmp_imgs():
|
||||
# tell gradio to use a directory under shark_tmp for its temporary
|
||||
# image files unless somewhere else has been set
|
||||
if "GRADIO_TEMP_DIR" not in os.environ:
|
||||
@@ -52,3 +66,12 @@ def config_gradio_tmp_imgs_folder():
|
||||
)
|
||||
else:
|
||||
print("No temporary images files to clear.")
|
||||
|
||||
|
||||
def config_tmp():
|
||||
# create shark_tmp if it does not exist
|
||||
if not os.path.exists(shark_tmp):
|
||||
os.mkdir(shark_tmp)
|
||||
|
||||
clear_tmp_mlir()
|
||||
clear_tmp_imgs()
|
||||
@@ -129,12 +129,12 @@ pytest_benchmark_param = pytest.mark.parametrize(
|
||||
pytest.param(True, "cpu", marks=pytest.mark.skip),
|
||||
pytest.param(
|
||||
False,
|
||||
"gpu",
|
||||
"cuda",
|
||||
marks=pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("cuda"), reason="nvidia-smi not found"
|
||||
),
|
||||
),
|
||||
pytest.param(True, "gpu", marks=pytest.mark.skip),
|
||||
pytest.param(True, "cuda", marks=pytest.mark.skip),
|
||||
pytest.param(
|
||||
False,
|
||||
"vulkan",
|
||||
|
||||
@@ -78,7 +78,10 @@ def test_loop(
|
||||
os.mkdir("./test_images/golden")
|
||||
get_inpaint_inputs()
|
||||
hf_model_names = model_config_dicts[0].values()
|
||||
tuned_options = ["--no-use_tuned", "--use_tuned"]
|
||||
tuned_options = [
|
||||
"--no-use_tuned",
|
||||
"--use_tuned",
|
||||
]
|
||||
import_options = ["--import_mlir", "--no-import_mlir"]
|
||||
prompt_text = "--prompt=cyberpunk forest by Salvador Dali"
|
||||
inpaint_prompt_text = "--prompt=Face of a yellow cat, high resolution, sitting on a park bench"
|
||||
@@ -112,6 +115,8 @@ def test_loop(
|
||||
and use_tune == tuned_options[1]
|
||||
):
|
||||
continue
|
||||
elif use_tune == tuned_options[1]:
|
||||
continue
|
||||
command = (
|
||||
[
|
||||
executable, # executable is the python from the venv used to run this
|
||||
|
||||
@@ -40,7 +40,7 @@ cmake --build build/
|
||||
*Prepare the model*
|
||||
```bash
|
||||
wget https://storage.googleapis.com/shark_tank/latest/resnet50_tf/resnet50_tf.mlir
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux resnet50_tf.mlir -o resnet50_tf.vmfb
|
||||
```
|
||||
*Prepare the input*
|
||||
|
||||
@@ -65,18 +65,18 @@ A tool for benchmarking other models is built and can be invoked with a command
|
||||
see `./build/vulkan_gui/iree-vulkan-gui --help` for an explanation on the function input. For example, stable diffusion unet can be tested with the following commands:
|
||||
```bash
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/stable_diff_tf.mlir
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux stable_diff_tf.mlir -o stable_diff_tf.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=2x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32
|
||||
```
|
||||
VAE and Autoencoder are also available
|
||||
```bash
|
||||
# VAE
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/vae_tf/vae.mlir
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux vae.mlir -o vae.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x4x64x64xf32
|
||||
|
||||
# CLIP Autoencoder
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/clip_tf/clip_autoencoder.mlir
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux clip_autoencoder.mlir -o clip_autoencoder.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x77xi32 --function_input=1x77xi32
|
||||
```
|
||||
|
||||
@@ -55,7 +55,7 @@ The command line for compilation will start something like this, where the `-` n
|
||||
The `-o output_filename.vmfb` flag can be used to specify the location to save the compiled vmfb. Note that a dump of the
|
||||
dispatches that can be compiled + run in isolation can be generated by adding `--iree-hal-dump-executable-benchmarks-to=/some/directory`. Say, if they are in the `benchmarks` directory, the following compile/run commands would work for Vulkan on RDNA3.
|
||||
```
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna3-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.mlir -o benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna3-unknown-linux benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.mlir -o benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb
|
||||
|
||||
iree-benchmark-module --module=benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb --function=forward --device=vulkan
|
||||
```
|
||||
|
||||
140
docs/shark_sd_koboldcpp.md
Normal file
140
docs/shark_sd_koboldcpp.md
Normal file
@@ -0,0 +1,140 @@
|
||||
# Overview
|
||||
|
||||
In [1.47.2](https://github.com/LostRuins/koboldcpp/releases/tag/v1.47.2) [Koboldcpp](https://github.com/LostRuins/koboldcpp) added AUTOMATIC1111 integration for image generation. Since SHARK implements a small subset of the A1111 REST api, you can also use SHARK for this. This document gives a starting point for how to get this working.
|
||||
|
||||
## In Action
|
||||
|
||||

|
||||
|
||||
## Memory considerations
|
||||
|
||||
Since both Koboldcpp and SHARK will use VRAM on your graphic card(s) running both at the same time using the same card will impose extra limitations on the model size you can fully offload to the video card in Koboldcpp. For me, on a RX 7900 XTX on Windows with 24 GiB of VRAM, the limit was about a 13 Billion parameter model with Q5_K_M quantisation.
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
When using SHARK for image generation, especially with Koboldcpp, you need to be aware that it is currently designed to pay a large upfront cost in time compiling and tuning the model you select, to get an optimal individual image generation time. You need to be the judge as to whether this trade-off is going to be worth it for your OS and hardware combination.
|
||||
|
||||
It means that the first time you run a particular Stable Diffusion model for a particular combination of image size, LoRA, and VAE, SHARK will spend *many minutes* - even on a beefy machaine with very fast graphics card with lots of memory - building that model combination just so it can save it to disk. It may even have to go away and download the model if it doesn't already have it locally. Once it has done its build of a model combination for your hardware once, it shouldn't need to do it again until you upgrade to a newer SHARK version, install different drivers or change your graphics hardware. It will just upload the files it generated the first time to your graphics card and proceed from there.
|
||||
|
||||
This does mean however, that on a brand new fresh install of SHARK that has not generated any images on a model you haven't selected before, the first image Koboldcpp requests may look like it is *never* going finish and that the whole process has broken. Be forewarned, make yourself a cup of coffee, and expect a lot of messages about compilation and tuning from SHARK in the terminal you ran it from.
|
||||
|
||||
## Setup SHARK and prerequisites:
|
||||
|
||||
* Make sure you have suitable drivers for your graphics card installed. See the prerequisties section of the [README](https://github.com/nod-ai/SHARK#readme).
|
||||
* Download the latest SHARK studio .exe from [here](https://github.com/nod-ai/SHARK/releases) or follow the instructions in the [README](https://github.com/nod-ai/SHARK#readme) for an advanced, Linux or Mac install.
|
||||
* Run SHARK from terminal/PowerShell with the `--api` flag. Since koboldcpp also expects both CORS support and the image generator to be running on port `7860` rather than SHARK default of `8080`, also include both the `--api_accept_origin` flag with a suitable origin (use `="*"` to enable all origins) and `--server_port=7860` on the command line. (See the if you want to run SHARK on a different port)
|
||||
|
||||
```powershell
|
||||
## Run the .exe in API mode, with CORS support, on the A1111 endpoint port:
|
||||
.\node_ai_shark_studio_<date>_<ver>.exe --api --api_accept_origin="*" --server_port=7860
|
||||
|
||||
## Run trom the base directory of a source clone of SHARK on Windows:
|
||||
.\setup_venv.ps1
|
||||
python .\apps\stable_diffusion\web\index.py --api --api_accept_origin="*" --server_port=7860
|
||||
|
||||
## Run a the base directory of a source clone of SHARK on Linux:
|
||||
./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
python ./apps/stable_diffusion/web/index.py --api --api_accept_origin="*" --server_port=7860
|
||||
|
||||
## An example giving improved performance on AMD cards using vulkan, that runs on the same port as A1111
|
||||
.\node_ai_shark_studio_20320901_2525.exe --api --api_accept_origin="*" --device_allocator="caching" --server_port=7860
|
||||
|
||||
## Since the api respects most applicable SHARK command line arguments for options not specified,
|
||||
## or currently unimplemented by API, there might be some you want to set, as listed in `--help`
|
||||
.\node_ai_shark_studio_20320901_2525.exe --help
|
||||
|
||||
## For instance, the example above, but with a a custom VAE specified
|
||||
.\node_ai_shark_studio_20320901_2525.exe --api --api_accept_origin="*" --device_allocator="caching" --server_port=7860 --custom_vae="clearvae_v23.safetensors"
|
||||
|
||||
## An example with multiple specific CORS origins
|
||||
python apps/stable_diffusion/web/index.py --api --api_accept_origin="koboldcpp.example.com:7001" --api_accept_origin="koboldcpp.example.com:7002" --server_port=7860
|
||||
```
|
||||
|
||||
SHARK should start in server mode, and you should see something like this:
|
||||
|
||||

|
||||
|
||||
* Note: When running in api mode with `--api`, the .exe will not function as a webUI. Thus, the address or port shown in the terminal output will only be useful for API requests.
|
||||
|
||||
|
||||
## Configure Koboldcpp for local image generation:
|
||||
|
||||
* Get the latest [Koboldcpp](https://github.com/LostRuins/koboldcpp/releases) if you don't already have it. If you have a recent AMD card that has ROCm HIP [support for Windows](https://rocmdocs.amd.com/en/latest/release/windows_support.html#windows-supported-gpus) or [support for Linux](https://rocmdocs.amd.com/en/latest/release/gpu_os_support.html#linux-supported-gpus), you'll likely prefer [YellowRosecx's ROCm fork](https://github.com/YellowRoseCx/koboldcpp-rocm).
|
||||
* Start Koboldcpp in another terminal/Powershell and setup your model configuration. Refer to the [Koboldcpp README](https://github.com/YellowRoseCx/koboldcpp-rocm) for more details on how to do this if this is your first time using Koboldcpp.
|
||||
* Once the main UI has loaded into your browser click the settings button, go to the advanced tab, and then choose *Local A1111* from the generate images dropdown:
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
*if you get an error here, see the next section [below](#connecting-to-shark-on-a-different-address-or-port)*
|
||||
|
||||
* A list of Stable Diffusion models available to your SHARK instance should now be listed in the box below *generate images*. The default value will usually be set to `stabilityai/stable-diffusion-2-1-base`. Choose the model you want to use for image generation from the list (but see [performance considerations](#performance-considerations)).
|
||||
* You should now be ready to generate images, either by clicking the 'Add Img' button above the text entry box:
|
||||
|
||||

|
||||
|
||||
...or by selecting the 'Autogenerate' option in the settings:
|
||||
|
||||

|
||||
|
||||
*I often find that even if I have selected autogenerate I have to do an 'add img' to get things started off*
|
||||
|
||||
* There is one final piece of image generation configuration within Koboldcpp you might want to do. This is also in the generate images section of advanced settings. Here there is, not very obviously, a 'style' button:
|
||||
|
||||

|
||||
|
||||
This will bring up a dialog box where you can enter a short text that will sent as a prefix to the Prompt sent to SHARK:
|
||||
|
||||

|
||||
|
||||
|
||||
## Connecting to SHARK on a different address or port
|
||||
|
||||
If you didn't set the port to `--server_port=7860` when starting SHARK, or you are running it on different machine on your network than you are running Koboldcpp, or to where you are running the koboldcpp's kdlite client frontend, then you very likely got the following error:
|
||||
|
||||

|
||||
|
||||
As long as SHARK is running correctly, this means you need to set the url and port to the correct values in Koboldcpp. For instance. to set the port that Koboldcpp looks for an image generator to SHARK's default port of 8080:
|
||||
|
||||
* Select the cog icon the Generate Images section of Advanced settings:
|
||||
|
||||

|
||||
|
||||
* Then edit the port number at the end of the url in the 'A1111 Endpoint Selection' dialog box to read 8080:
|
||||
|
||||

|
||||
|
||||
* Similarly, when running SHARK on a different machine you will need to change host part of the endpoint url to the hostname or ip address where SHARK is running, similarly:
|
||||
|
||||

|
||||
|
||||
## Examples
|
||||
|
||||
Here's how Koboldcpp shows an image being requested:
|
||||
|
||||

|
||||
|
||||
The generated image in context in story mode:
|
||||
|
||||

|
||||
|
||||
And the same image when clicked on:
|
||||
|
||||

|
||||
|
||||
|
||||
## Where to find the images in SHARK
|
||||
|
||||
Even though Koboldcpp requests images at a size of 512x512, it resizes then to 256x256, converts them to `.jpeg`, and only shows them at 200x200 in the main text window. It does this so it can save them compactly embedded in your story as a `data://` uri.
|
||||
|
||||
However the images at the original size are saved by SHARK in its `output_dir` which is usually a folder named for the current date. inside `generated_imgs` folder in the SHARK installation directory.
|
||||
|
||||
You can browse these, either using the Output Gallery tab from within the SHARK web ui:
|
||||
|
||||

|
||||
|
||||
...or by browsing to the `output_dir` in your operating system's file manager:
|
||||
|
||||

|
||||
@@ -1,192 +0,0 @@
|
||||
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cmake_minimum_required(VERSION 3.17)
|
||||
|
||||
project(sharkbackend LANGUAGES C CXX)
|
||||
|
||||
#
|
||||
# Options
|
||||
#
|
||||
|
||||
option(TRITON_ENABLE_GPU "Enable GPU support in backend" ON)
|
||||
option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON)
|
||||
|
||||
set(TRITON_COMMON_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/common repo")
|
||||
set(TRITON_CORE_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/core repo")
|
||||
set(TRITON_BACKEND_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/backend repo")
|
||||
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
set(CMAKE_BUILD_TYPE Release)
|
||||
endif()
|
||||
|
||||
#
|
||||
# Dependencies
|
||||
#
|
||||
# FetchContent requires us to include the transitive closure of all
|
||||
# repos that we depend on so that we can override the tags.
|
||||
#
|
||||
include(FetchContent)
|
||||
|
||||
FetchContent_Declare(
|
||||
repo-common
|
||||
GIT_REPOSITORY https://github.com/triton-inference-server/common.git
|
||||
GIT_TAG ${TRITON_COMMON_REPO_TAG}
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_Declare(
|
||||
repo-core
|
||||
GIT_REPOSITORY https://github.com/triton-inference-server/core.git
|
||||
GIT_TAG ${TRITON_CORE_REPO_TAG}
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_Declare(
|
||||
repo-backend
|
||||
GIT_REPOSITORY https://github.com/triton-inference-server/backend.git
|
||||
GIT_TAG ${TRITON_BACKEND_REPO_TAG}
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_MakeAvailable(repo-common repo-core repo-backend)
|
||||
|
||||
#
|
||||
# The backend must be built into a shared library. Use an ldscript to
|
||||
# hide all symbols except for the TRITONBACKEND API.
|
||||
#
|
||||
configure_file(src/libtriton_dshark.ldscript libtriton_dshark.ldscript COPYONLY)
|
||||
|
||||
add_library(
|
||||
triton-dshark-backend SHARED
|
||||
src/dshark.cc
|
||||
#src/dshark_driver_module.c
|
||||
)
|
||||
|
||||
add_library(
|
||||
SharkBackend::triton-dshark-backend ALIAS triton-dshark-backend
|
||||
)
|
||||
|
||||
target_include_directories(
|
||||
triton-dshark-backend
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/src
|
||||
)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${PROJECT_BINARY_DIR}/lib/cmake/mlir")
|
||||
|
||||
add_subdirectory(thirdparty/srt EXCLUDE_FROM_ALL)
|
||||
|
||||
target_link_libraries(triton-dshark-backend PRIVATE iree_base_base
|
||||
iree_hal_hal
|
||||
iree_hal_cuda_cuda
|
||||
iree_hal_cuda_registration_registration
|
||||
iree_hal_vmvx_registration_registration
|
||||
iree_hal_dylib_registration_registration
|
||||
iree_modules_hal_hal
|
||||
iree_vm_vm
|
||||
iree_vm_bytecode_module
|
||||
iree_hal_local_loaders_system_library_loader
|
||||
iree_hal_local_loaders_vmvx_module_loader
|
||||
)
|
||||
|
||||
target_compile_features(triton-dshark-backend PRIVATE cxx_std_11)
|
||||
|
||||
|
||||
target_link_libraries(
|
||||
triton-dshark-backend
|
||||
PRIVATE
|
||||
triton-core-serverapi # from repo-core
|
||||
triton-core-backendapi # from repo-core
|
||||
triton-core-serverstub # from repo-core
|
||||
triton-backend-utils # from repo-backend
|
||||
)
|
||||
|
||||
if(WIN32)
|
||||
set_target_properties(
|
||||
triton-dshark-backend PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
OUTPUT_NAME triton_dshark
|
||||
)
|
||||
else()
|
||||
set_target_properties(
|
||||
triton-dshark-backend PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
OUTPUT_NAME triton_dshark
|
||||
LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libtriton_dshark.ldscript
|
||||
LINK_FLAGS "-Wl,--version-script libtriton_dshark.ldscript"
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
#
|
||||
# Install
|
||||
#
|
||||
include(GNUInstallDirs)
|
||||
set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/SharkBackend)
|
||||
|
||||
install(
|
||||
TARGETS
|
||||
triton-dshark-backend
|
||||
EXPORT
|
||||
triton-dshark-backend-targets
|
||||
LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
|
||||
RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
|
||||
)
|
||||
|
||||
install(
|
||||
EXPORT
|
||||
triton-dshark-backend-targets
|
||||
FILE
|
||||
SharkBackendTargets.cmake
|
||||
NAMESPACE
|
||||
SharkBackend::
|
||||
DESTINATION
|
||||
${INSTALL_CONFIGDIR}
|
||||
)
|
||||
|
||||
include(CMakePackageConfigHelpers)
|
||||
configure_package_config_file(
|
||||
${CMAKE_CURRENT_LIST_DIR}/cmake/SharkBackendConfig.cmake.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
|
||||
INSTALL_DESTINATION ${INSTALL_CONFIGDIR}
|
||||
)
|
||||
|
||||
install(
|
||||
FILES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
|
||||
DESTINATION ${INSTALL_CONFIGDIR}
|
||||
)
|
||||
|
||||
#
|
||||
# Export from build tree
|
||||
#
|
||||
export(
|
||||
EXPORT triton-dshark-backend-targets
|
||||
FILE ${CMAKE_CURRENT_BINARY_DIR}/SharkBackendTargets.cmake
|
||||
NAMESPACE SharkBackend::
|
||||
)
|
||||
|
||||
export(PACKAGE SharkBackend)
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
# SHARK Triton Backend
|
||||
|
||||
The triton backend for shark.
|
||||
|
||||
# Build
|
||||
|
||||
Install SHARK
|
||||
|
||||
```
|
||||
git clone https://github.com/nod-ai/SHARK.git
|
||||
# skip above step if dshark is already installed
|
||||
cd SHARK/inference
|
||||
```
|
||||
|
||||
install dependancies
|
||||
|
||||
```
|
||||
apt-get install patchelf rapidjson-dev python3-dev
|
||||
git submodule update --init
|
||||
```
|
||||
|
||||
update the submodules of iree
|
||||
|
||||
```
|
||||
cd thirdparty/srt
|
||||
git submodule update --init
|
||||
```
|
||||
|
||||
Next, make the backend and install it
|
||||
|
||||
```
|
||||
cd ../..
|
||||
mkdir build && cd build
|
||||
cmake -DTRITON_ENABLE_GPU=ON \
|
||||
-DIREE_HAL_DRIVER_CUDA=ON \
|
||||
-DIREE_TARGET_BACKEND_CUDA=ON \
|
||||
-DMLIR_ENABLE_CUDA_RUNNER=ON \
|
||||
-DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install \
|
||||
-DTRITON_BACKEND_REPO_TAG=r22.02 \
|
||||
-DTRITON_CORE_REPO_TAG=r22.02 \
|
||||
-DTRITON_COMMON_REPO_TAG=r22.02 ..
|
||||
make install
|
||||
```
|
||||
|
||||
# Incorporating into Triton
|
||||
|
||||
There are much more in depth explenations for the following steps in triton's documentation:
|
||||
https://github.com/triton-inference-server/server/blob/main/docs/compose.md#triton-with-unsupported-and-custom-backends
|
||||
|
||||
There should be a file at /build/install/backends/dshark/libtriton_dshark.so. You will need to copy it into your triton server image.
|
||||
More documentation is in the link above, but to create the docker image, you need to run the compose.py command in the triton-backend server repo
|
||||
|
||||
|
||||
To first build your image, clone the tritonserver repo.
|
||||
|
||||
```
|
||||
git clone https://github.com/triton-inference-server/server.git
|
||||
```
|
||||
|
||||
then run `compose.py` to build a docker compose file
|
||||
```
|
||||
cd server
|
||||
python3 compose.py --repoagent checksum --dry-run
|
||||
```
|
||||
|
||||
Because dshark is a third party backend, you will need to manually modify the `Dockerfile.compose` to include the dshark backend. To do this, in the Dockerfile.compose file produced, copy this line.
|
||||
the dshark backend will be located in the build folder from earlier under `/build/install/backends`
|
||||
|
||||
```
|
||||
COPY /path/to/build/install/backends/dshark /opt/tritonserver/backends/dshark
|
||||
```
|
||||
|
||||
Next run
|
||||
```
|
||||
docker build -t tritonserver_custom -f Dockerfile.compose .
|
||||
docker run -it --gpus=1 --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
|
||||
```
|
||||
|
||||
where `path/to/model_repos` is where you are storing the models you want to run
|
||||
|
||||
if your not using gpus, omit `--gpus=1`
|
||||
|
||||
```
|
||||
docker run -it --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
|
||||
```
|
||||
|
||||
# Setting up a model
|
||||
|
||||
to include a model in your backend, add a directory with your model name to your model repository directory. examples of models can be seen here: https://github.com/triton-inference-server/backend/tree/main/examples/model_repos/minimal_models
|
||||
|
||||
make sure to adjust the input correctly in the config.pbtxt file, and save a vmfb file under 1/model.vmfb
|
||||
|
||||
# CUDA
|
||||
|
||||
if you're having issues with cuda, make sure your correct drivers are installed, and that `nvidia-smi` works, and also make sure that the nvcc compiler is on the path.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
include(CMakeFindDependencyMacro)
|
||||
|
||||
get_filename_component(
|
||||
SHARKBACKEND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH
|
||||
)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH ${SHARKBACKEND_CMAKE_DIR})
|
||||
|
||||
if(NOT TARGET SharkBackend::triton-dshark-backend)
|
||||
include("${SHARKBACKEND_CMAKE_DIR}/SharkBackendTargets.cmake")
|
||||
endif()
|
||||
|
||||
set(SHARKBACKEND_LIBRARIES SharkBackend::triton-dshark-backend)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,30 +0,0 @@
|
||||
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
{
|
||||
global:
|
||||
TRITONBACKEND_*;
|
||||
local: *;
|
||||
};
|
||||
1
inference/thirdparty/shark-runtime
vendored
1
inference/thirdparty/shark-runtime
vendored
Submodule inference/thirdparty/shark-runtime deleted from 7b82d90c72
@@ -6,15 +6,15 @@ from distutils.sysconfig import get_python_lib
|
||||
import fileinput
|
||||
from pathlib import Path
|
||||
|
||||
# Temorary workaround for transformers/__init__.py.
|
||||
path_to_tranformers_hook = Path(
|
||||
# Temporary workaround for transformers/__init__.py.
|
||||
path_to_transformers_hook = Path(
|
||||
get_python_lib()
|
||||
+ "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py"
|
||||
)
|
||||
if path_to_tranformers_hook.is_file():
|
||||
if path_to_transformers_hook.is_file():
|
||||
pass
|
||||
else:
|
||||
with open(path_to_tranformers_hook, "w") as f:
|
||||
with open(path_to_transformers_hook, "w") as f:
|
||||
f.write("module_collection_mode = 'pyz+py'")
|
||||
|
||||
path_to_skipfiles = Path(get_python_lib() + "/torch/_dynamo/skipfiles.py")
|
||||
|
||||
@@ -8,19 +8,8 @@ torchvision
|
||||
tqdm
|
||||
|
||||
#iree-compiler | iree-runtime should already be installed
|
||||
#these dont work ok osx
|
||||
#iree-tools-tflite
|
||||
#iree-tools-xla
|
||||
#iree-tools-tf
|
||||
|
||||
# TensorFlow and JAX.
|
||||
gin-config
|
||||
tensorflow-macos
|
||||
tensorflow-metal
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
transformers
|
||||
tensorflow-probability
|
||||
#jax[cpu]
|
||||
|
||||
# tflitehub dependencies.
|
||||
|
||||
@@ -9,23 +9,13 @@ tabulate
|
||||
tqdm
|
||||
|
||||
#iree-compiler | iree-runtime should already be installed
|
||||
iree-tools-tflite
|
||||
iree-tools-xla
|
||||
iree-tools-tf
|
||||
|
||||
# TensorFlow and JAX.
|
||||
# Modelling and JAX.
|
||||
gin-config
|
||||
tf-nightly
|
||||
keras
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
transformers
|
||||
diffusers
|
||||
#tensorflow-probability
|
||||
#jax[cpu]
|
||||
|
||||
|
||||
# tflitehub dependencies.
|
||||
Pillow
|
||||
|
||||
# Testing and support.
|
||||
@@ -36,7 +26,7 @@ sacremoses
|
||||
sentencepiece
|
||||
|
||||
# web dependecies.
|
||||
gradio
|
||||
gradio==3.44.3
|
||||
altair
|
||||
scipy
|
||||
|
||||
|
||||
@@ -17,14 +17,16 @@ pytest-forked
|
||||
Pillow
|
||||
parameterized
|
||||
|
||||
#shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@main
|
||||
# Add transformers, diffusers and scipy since it most commonly used
|
||||
tokenizers==0.13.3
|
||||
transformers
|
||||
diffusers
|
||||
#accelerate is now required for diffusers import from ckpt.
|
||||
accelerate
|
||||
scipy
|
||||
ftfy
|
||||
gradio
|
||||
gradio==4.7.1
|
||||
altair
|
||||
omegaconf
|
||||
# 0.3.2 doesn't have binaries for arm64
|
||||
@@ -40,10 +42,16 @@ tiktoken # for codegen
|
||||
joblib # for langchain
|
||||
timm # for MiniGPT4
|
||||
langchain
|
||||
einops # for zoedepth
|
||||
pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions
|
||||
|
||||
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
|
||||
pefile
|
||||
pyinstaller
|
||||
|
||||
# vicuna quantization
|
||||
brevitas @ git+https://github.com/Xilinx/brevitas.git@dev
|
||||
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea
|
||||
|
||||
# For quantized GPTQ models
|
||||
optimum
|
||||
auto_gptq
|
||||
|
||||
@@ -4,7 +4,7 @@ import base64
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
def upscaler_test():
|
||||
def upscaler_test(verbose=False):
|
||||
# Define values here
|
||||
prompt = ""
|
||||
negative_prompt = ""
|
||||
@@ -44,10 +44,17 @@ def upscaler_test():
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"response from server was : {res.status_code}")
|
||||
print(
|
||||
f"[upscaler] response from server was : {res.status_code} {res.reason}"
|
||||
)
|
||||
|
||||
if verbose or res.status_code != 200:
|
||||
print(
|
||||
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
|
||||
)
|
||||
|
||||
|
||||
def img2img_test():
|
||||
def img2img_test(verbose=False):
|
||||
# Define values here
|
||||
prompt = "Paint a rabbit riding on the dog"
|
||||
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
|
||||
@@ -87,7 +94,16 @@ def img2img_test():
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"response from server was : {res.status_code}")
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(
|
||||
f"[img2img] response from server was : {res.status_code} {res.reason}"
|
||||
)
|
||||
|
||||
if verbose or res.status_code != 200:
|
||||
print(
|
||||
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
|
||||
)
|
||||
|
||||
# NOTE Uncomment below to save the picture
|
||||
|
||||
@@ -103,7 +119,7 @@ def img2img_test():
|
||||
# response_img.save(r"rest_api_tests/response_img.png")
|
||||
|
||||
|
||||
def inpainting_test():
|
||||
def inpainting_test(verbose=False):
|
||||
prompt = "Paint a rabbit riding on the dog"
|
||||
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
|
||||
seed = 2121991605
|
||||
@@ -150,10 +166,17 @@ def inpainting_test():
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"[Inpainting] response from server was : {res.status_code}")
|
||||
print(
|
||||
f"[inpaint] response from server was : {res.status_code} {res.reason}"
|
||||
)
|
||||
|
||||
if verbose or res.status_code != 200:
|
||||
print(
|
||||
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
|
||||
)
|
||||
|
||||
|
||||
def outpainting_test():
|
||||
def outpainting_test(verbose=False):
|
||||
prompt = "Paint a rabbit riding on the dog"
|
||||
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
|
||||
seed = 2121991605
|
||||
@@ -200,10 +223,17 @@ def outpainting_test():
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"[Outpaint] response from server was : {res.status_code}")
|
||||
print(
|
||||
f"[outpaint] response from server was : {res.status_code} {res.reason}"
|
||||
)
|
||||
|
||||
if verbose or res.status_code != 200:
|
||||
print(
|
||||
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
|
||||
)
|
||||
|
||||
|
||||
def txt2img_test():
|
||||
def txt2img_test(verbose=False):
|
||||
prompt = "Paint a rabbit in a top hate"
|
||||
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
|
||||
seed = 2121991605
|
||||
@@ -232,12 +262,119 @@ def txt2img_test():
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"[txt2img] response from server was : {res.status_code}")
|
||||
print(
|
||||
f"[txt2img] response from server was : {res.status_code} {res.reason}"
|
||||
)
|
||||
|
||||
if verbose or res.status_code != 200:
|
||||
print(
|
||||
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
|
||||
)
|
||||
|
||||
|
||||
def sd_models_test(verbose=False):
|
||||
url = "http://127.0.0.1:8080/sdapi/v1/sd-models"
|
||||
|
||||
headers = {
|
||||
"User-Agent": "PythonTest",
|
||||
"Accept": "*/*",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
}
|
||||
|
||||
res = requests.get(url=url, headers=headers, timeout=1000)
|
||||
|
||||
print(
|
||||
f"[sd_models] response from server was : {res.status_code} {res.reason}"
|
||||
)
|
||||
|
||||
if verbose or res.status_code != 200:
|
||||
print(f"\n{res.json() if res.status_code == 200 else res.content}\n")
|
||||
|
||||
|
||||
def sd_samplers_test(verbose=False):
|
||||
url = "http://127.0.0.1:8080/sdapi/v1/samplers"
|
||||
|
||||
headers = {
|
||||
"User-Agent": "PythonTest",
|
||||
"Accept": "*/*",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
}
|
||||
|
||||
res = requests.get(url=url, headers=headers, timeout=1000)
|
||||
|
||||
print(
|
||||
f"[sd_samplers] response from server was : {res.status_code} {res.reason}"
|
||||
)
|
||||
|
||||
if verbose or res.status_code != 200:
|
||||
print(f"\n{res.json() if res.status_code == 200 else res.content}\n")
|
||||
|
||||
|
||||
def options_test(verbose=False):
|
||||
url = "http://127.0.0.1:8080/sdapi/v1/options"
|
||||
|
||||
headers = {
|
||||
"User-Agent": "PythonTest",
|
||||
"Accept": "*/*",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
}
|
||||
|
||||
res = requests.get(url=url, headers=headers, timeout=1000)
|
||||
|
||||
print(
|
||||
f"[options] response from server was : {res.status_code} {res.reason}"
|
||||
)
|
||||
|
||||
if verbose or res.status_code != 200:
|
||||
print(f"\n{res.json() if res.status_code == 200 else res.content}\n")
|
||||
|
||||
|
||||
def cmd_flags_test(verbose=False):
|
||||
url = "http://127.0.0.1:8080/sdapi/v1/cmd-flags"
|
||||
|
||||
headers = {
|
||||
"User-Agent": "PythonTest",
|
||||
"Accept": "*/*",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
}
|
||||
|
||||
res = requests.get(url=url, headers=headers, timeout=1000)
|
||||
|
||||
print(
|
||||
f"[cmd-flags] response from server was : {res.status_code} {res.reason}"
|
||||
)
|
||||
|
||||
if verbose or res.status_code != 200:
|
||||
print(f"\n{res.json() if res.status_code == 200 else res.content}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
txt2img_test()
|
||||
img2img_test()
|
||||
upscaler_test()
|
||||
inpainting_test()
|
||||
outpainting_test()
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Exercises the Stable Diffusion REST API of Shark. Make sure "
|
||||
"Shark is running in API mode on 127.0.0.1:8080 before running"
|
||||
"this script."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
help=(
|
||||
"also display selected info from the JSON response for "
|
||||
"successful requests"
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sd_models_test(args.verbose)
|
||||
sd_samplers_test(args.verbose)
|
||||
options_test(args.verbose)
|
||||
cmd_flags_test(args.verbose)
|
||||
txt2img_test(args.verbose)
|
||||
img2img_test(args.verbose)
|
||||
upscaler_test(args.verbose)
|
||||
inpainting_test(args.verbose)
|
||||
outpainting_test(args.verbose)
|
||||
|
||||
@@ -89,7 +89,7 @@ else {python -m venv .\shark.venv\}
|
||||
python -m pip install --upgrade pip
|
||||
pip install wheel
|
||||
pip install -r requirements.txt
|
||||
pip install --pre torch-mlir torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
|
||||
pip install --pre torch-mlir torchvision torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
|
||||
pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler iree-runtime
|
||||
Write-Host "Building SHARK..."
|
||||
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
|
||||
|
||||
@@ -86,6 +86,7 @@ $PYTHON -m pip install --upgrade -r "$TD/requirements.txt"
|
||||
if [ "$torch_mlir_bin" = true ]; then
|
||||
if [[ $(uname -s) = 'Darwin' ]]; then
|
||||
echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
|
||||
$PYTHON -m pip uninstall -y timm #TEMP FIX FOR MAC
|
||||
$PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
|
||||
else
|
||||
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
|
||||
@@ -110,7 +111,7 @@ else
|
||||
fi
|
||||
if [[ -z "${NO_BACKEND}" ]]; then
|
||||
echo "Installing ${RUNTIME}..."
|
||||
$PYTHON -m pip install --pre --upgrade --find-links ${RUNTIME} iree-compiler iree-runtime
|
||||
$PYTHON -m pip install --pre --upgrade --no-index --find-links ${RUNTIME} iree-compiler iree-runtime
|
||||
else
|
||||
echo "Not installing a backend, please make sure to add your backend to PYTHONPATH"
|
||||
fi
|
||||
@@ -128,16 +129,21 @@ if [[ ! -z "${IMPORTER}" ]]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/cpu/
|
||||
if [[ $(uname -s) = 'Darwin' ]]; then
|
||||
PYTORCH_URL=https://download.pytorch.org/whl/nightly/torch/
|
||||
else
|
||||
PYTORCH_URL=https://download.pytorch.org/whl/nightly/cpu/
|
||||
fi
|
||||
|
||||
if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
|
||||
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f ${PYTORCH_URL}
|
||||
|
||||
if [[ $(uname -s) = 'Linux' && ! -z "${IMPORTER}" ]]; then
|
||||
T_VER=$($PYTHON -m pip show torch | grep Version)
|
||||
TORCH_VERSION=${T_VER:9:17}
|
||||
T_VER_MIN=${T_VER:14:12}
|
||||
TV_VER=$($PYTHON -m pip show torchvision | grep Version)
|
||||
TV_VERSION=${TV_VER:9:18}
|
||||
$PYTHON -m pip uninstall -y torch torchvision
|
||||
$PYTHON -m pip install -U --pre --no-warn-conflicts triton
|
||||
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu118/torch-${TORCH_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu118/torchvision-${TV_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl
|
||||
TV_VER_MAJ=${TV_VER:9:6}
|
||||
$PYTHON -m pip uninstall -y torchvision
|
||||
$PYTHON -m pip install torchvision==${TV_VER_MAJ}${T_VER_MIN} --no-deps -f https://download.pytorch.org/whl/nightly/cpu/torchvision/
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch + cu118."
|
||||
else
|
||||
@@ -146,7 +152,7 @@ if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
|
||||
fi
|
||||
|
||||
if [[ -z "${NO_BREVITAS}" ]]; then
|
||||
$PYTHON -m pip install git+https://github.com/Xilinx/brevitas.git@llm
|
||||
$PYTHON -m pip install git+https://github.com/Xilinx/brevitas.git@dev
|
||||
fi
|
||||
|
||||
if [[ -z "${CONDA_PREFIX}" && "$SKIP_VENV" != "1" ]]; then
|
||||
|
||||
@@ -177,7 +177,7 @@ def compile_through_fx(model, inputs, mlir_loc=None):
|
||||
mlir_model = str(module)
|
||||
func_name = "forward"
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
|
||||
mlir_model, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ if __name__ == "__main__":
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=False, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(minilm_mlir, func_name, mlir_dialect="mhlo")
|
||||
shark_module = SharkInference(minilm_mlir, mlir_dialect="mhlo")
|
||||
shark_module.compile()
|
||||
output_idx = 0
|
||||
data_idx = 1
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user