mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
69 Commits
20230929.9
...
20231111.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c4a908c3ea | ||
|
|
6285430d8a | ||
|
|
51afe19e20 | ||
|
|
31005bcf73 | ||
|
|
f41ad87ef6 | ||
|
|
d811524a00 | ||
|
|
51e1bd1c5d | ||
|
|
db89b1bdc1 | ||
|
|
2754e2e257 | ||
|
|
ab0e870c43 | ||
|
|
fb30e8c226 | ||
|
|
a07d542400 | ||
|
|
ad55cb696f | ||
|
|
488a172292 | ||
|
|
500c4f2306 | ||
|
|
92b694db4d | ||
|
|
322874f7f9 | ||
|
|
5001db3415 | ||
|
|
71846344a2 | ||
|
|
72e27c96fc | ||
|
|
7963abb8ec | ||
|
|
98244232dd | ||
|
|
679a452139 | ||
|
|
72c0a8abc8 | ||
|
|
ea920f2955 | ||
|
|
486202377a | ||
|
|
0c38c33d0a | ||
|
|
841773fa32 | ||
|
|
0361db46f9 | ||
|
|
a012433ffd | ||
|
|
5061193da3 | ||
|
|
bff48924be | ||
|
|
825b36cbdd | ||
|
|
134441957d | ||
|
|
7cd14fdc47 | ||
|
|
e6cb5cef57 | ||
|
|
66abee8e5b | ||
|
|
4797bb89f5 | ||
|
|
205e57683a | ||
|
|
2866d665ee | ||
|
|
71d25ec5d8 | ||
|
|
202ffff67b | ||
|
|
0b77059628 | ||
|
|
a208302bb9 | ||
|
|
b83d32fafe | ||
|
|
0a618e1863 | ||
|
|
a731eb6ed4 | ||
|
|
2004d16945 | ||
|
|
6e409bfb77 | ||
|
|
77727d149c | ||
|
|
66f6e79d68 | ||
|
|
3b825579a7 | ||
|
|
9f0a421764 | ||
|
|
c28682110c | ||
|
|
caf6cc5d8f | ||
|
|
8614a18474 | ||
|
|
86c1c0c215 | ||
|
|
8bb364bcb8 | ||
|
|
7abddd01ec | ||
|
|
2a451fa0c7 | ||
|
|
9c4610b9da | ||
|
|
a38cc9d216 | ||
|
|
1c382449ec | ||
|
|
7cc9b3f8e8 | ||
|
|
e54517e967 | ||
|
|
326327a799 | ||
|
|
785b65c7b0 | ||
|
|
0d16c81687 | ||
|
|
8dd7850c69 |
3
.github/workflows/test-models.yml
vendored
3
.github/workflows/test-models.yml
vendored
@@ -137,7 +137,8 @@ jobs:
|
||||
source shark.venv/bin/activate
|
||||
echo $PATH
|
||||
pip list | grep -E "torch|iree"
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
|
||||
# disabled due to a low-visibility memory issue with pytest on macos.
|
||||
# pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
|
||||
|
||||
- name: Validate Vulkan Models (a100)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -182,7 +182,7 @@ generated_imgs/
|
||||
|
||||
# Custom model related artefacts
|
||||
variants.json
|
||||
models/
|
||||
/models/
|
||||
|
||||
# models folder
|
||||
apps/stable_diffusion/web/models/
|
||||
@@ -199,3 +199,6 @@ apps/stable_diffusion/web/EBWebView/
|
||||
|
||||
# Llama2 tokenizer configs
|
||||
llama2_tokenizer_configs/
|
||||
|
||||
# Webview2 runtime artefacts
|
||||
EBWebView/
|
||||
|
||||
10
README.md
10
README.md
@@ -254,7 +254,6 @@ if you want to instead incorporate this into a python script, you can pass the `
|
||||
```
|
||||
shark_module = SharkInference(
|
||||
mlir_model,
|
||||
func_name,
|
||||
device=args.device,
|
||||
mlir_dialect="tm_tensor",
|
||||
dispatch_benchmarks="all",
|
||||
@@ -297,7 +296,7 @@ torch_mlir, func_name = mlir_importer.import_mlir(tracing_required=True)
|
||||
# SharkInference accepts mlir in linalg, mhlo, and tosa dialect.
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
shark_module = SharkInference(torch_mlir, func_name, device="cpu", mlir_dialect="linalg")
|
||||
shark_module = SharkInference(torch_mlir, device="cpu", mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((input))
|
||||
|
||||
@@ -320,12 +319,17 @@ mhlo_ir = r"""builtin.module {
|
||||
|
||||
arg0 = np.ones((1, 4)).astype(np.float32)
|
||||
arg1 = np.ones((4, 1)).astype(np.float32)
|
||||
shark_module = SharkInference(mhlo_ir, func_name="forward", device="cpu", mlir_dialect="mhlo")
|
||||
shark_module = SharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((arg0, arg1))
|
||||
```
|
||||
</details>
|
||||
|
||||
## Examples Using the REST API
|
||||
|
||||
* [Setting up SHARK for use with Blender](./docs/shark_sd_blender.md)
|
||||
* [Setting up SHARK for use with Koboldcpp](./docs/shark_sd_koboldcpp.md)
|
||||
|
||||
## Supported and Validated Models
|
||||
|
||||
SHARK is maintained to support the latest innovations in ML Models:
|
||||
|
||||
@@ -20,7 +20,7 @@ import gc
|
||||
from pathlib import Path
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
# Brevitas
|
||||
@@ -256,6 +256,11 @@ class H2OGPTSHARKModel(torch.nn.Module):
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
del module
|
||||
|
||||
bytecode = save_mlir(
|
||||
bytecode,
|
||||
model_name=f"h2ogpt_{precision}",
|
||||
frontend="torch",
|
||||
)
|
||||
return bytecode
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
|
||||
@@ -65,8 +65,8 @@ tiktoken==0.4.0
|
||||
openai==0.27.8
|
||||
|
||||
# optional for chat with PDF
|
||||
langchain==0.0.202
|
||||
pypdf==3.12.2
|
||||
langchain==0.0.325
|
||||
pypdf==3.17.0
|
||||
# avoid textract, requires old six
|
||||
#textract==1.6.5
|
||||
|
||||
|
||||
442
apps/language_models/scripts/llama_ir_conversion_utils.py
Normal file
442
apps/language_models/scripts/llama_ir_conversion_utils.py
Normal file
@@ -0,0 +1,442 @@
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
from argparse import RawTextHelpFormatter
|
||||
import re, gc
|
||||
|
||||
"""
|
||||
This script can be used as a standalone utility to convert IRs to dynamic + combine them.
|
||||
Following are the various ways this script can be used :-
|
||||
a. To convert a single Linalg IR to dynamic IR:
|
||||
--dynamic --first_ir_path=<PATH TO FIRST IR>
|
||||
b. To convert two Linalg IRs to dynamic IR:
|
||||
--dynamic --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
|
||||
c. To combine two Linalg IRs into one:
|
||||
--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
|
||||
d. To convert both IRs into dynamic as well as combine the IRs:
|
||||
--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
|
||||
|
||||
NOTE: For dynamic you'll also need to provide the following set of flags:-
|
||||
i. For First Llama : --dynamic_input_size (DEFAULT: 19)
|
||||
ii. For Second Llama: --model_name (DEFAULT: llama2_7b)
|
||||
--precision (DEFAULT: 'int4')
|
||||
You may use --save_dynamic to also save the dynamic IR in option d above.
|
||||
Else for option a. and b. the dynamic IR(s) will get saved by default.
|
||||
"""
|
||||
|
||||
|
||||
# Matches an affine-map definition line such as `#map3 = affine_map<...>`.
_MAP_DEF_RE = re.compile(r"#map\d*\s*=")


def _drain(lines):
    """Yield `lines` in order, removing each entry from the list as it goes.

    The list is reversed once so every removal is an O(1) `pop()` from the
    end (the previous `pop(0)` loop was quadratic), while still releasing
    each line as soon as it has been consumed.
    """
    lines.reverse()
    while lines:
        yield lines.pop()


def _split_ir(ir_text, forward_name, constants, skip_global_seed=False):
    """Split one IR into its `#map` definitions and its remaining body lines.

    `arith.constant` lines are added to the shared `constants` set, lines
    containing "module" are dropped, and every "forward" in the remaining
    lines is replaced by `forward_name`. The last kept line (the closing
    brace of the original module) is discarded.
    """
    maps, body = [], []
    for line in _drain(ir_text.splitlines()):
        if _MAP_DEF_RE.search(line):
            maps.append(line)
        elif skip_global_seed and "global_seed" in line:
            continue
        elif "arith.constant" in line:
            constants.add(line)
        elif "module" not in line:
            body.append(line.replace("forward", forward_name))
    return maps, body[:-1]


def _suffix_maps(maps, body, suffix):
    """Append `suffix` to every map name, in `maps` (in place) and in `body`.

    Returns the rewritten body. The `(?!\\d)` lookahead keeps `#map1` from
    matching inside `#map12`.
    """
    for i, map_line in enumerate(maps):
        map_var = map_line.split(" ")[0]
        pattern = re.compile(re.escape(map_var) + r"(?!\d)")
        renamed = map_var + suffix
        maps[i] = pattern.sub(renamed, map_line)
        body = [pattern.sub(renamed, func_line) for func_line in body]
    return body


def _inject_global_loads(body, loads):
    """Return `body` with `loads` inserted right after each `func.func` line."""
    result = []
    for line in body:
        result.append(line)
        if "func.func" in line:
            result.extend(loads)
    return result


def combine_mlir_scripts(
    first_vicuna_mlir,
    second_vicuna_mlir,
    output_name,
    return_ir=True,
):
    """Combine the first and second vicuna IRs into a single module file.

    Map definitions of the two inputs are suffixed with `_0` / `_1`, their
    `forward` functions are renamed to `first_vicuna_forward` /
    `second_vicuna_forward`, and every `arith.constant` line is turned into
    an `ml_program.global` that both functions load at their start.

    Args:
        first_vicuna_mlir: text of the first IR.
        second_vicuna_mlir: text of the second IR; lines containing
            "global_seed" are dropped from it.
        output_name: path of the combined file to write.
        return_ir: when true, the written file is read back and returned.

    Returns:
        The bytes of the combined file if `return_ir` is true, else None.
    """
    print(f"[DEBUG] combining first and second mlir")
    print(f"[DEBUG] output_name = {output_name}")
    constants = set()

    print(f"[DEBUG] processing first vicuna mlir")
    maps1, f1 = _split_ir(first_vicuna_mlir, "first_vicuna_forward", constants)
    del first_vicuna_mlir
    gc.collect()
    f1 = _suffix_maps(maps1, f1, "_0")

    print(f"[DEBUG] processing second vicuna mlir")
    maps2, f2 = _split_ir(
        second_vicuna_mlir,
        "second_vicuna_forward",
        constants,
        skip_global_seed=True,
    )
    del second_vicuna_mlir
    gc.collect()
    f2 = _suffix_maps(maps2, f2, "_1")

    module_start = 'module attributes {torch.debug_module_name = "_lambda"} {'
    module_end = "}"

    global_vars = []
    # Both functions load exactly the same globals, so one list is enough.
    global_var_loading = []

    print(f"[DEBUG] processing constants")
    counter = 0
    pending = list(constants)
    constants.clear()
    for constant in _drain(pending):
        vname, vbody = constant.split("=")
        vname = vname.replace("%", "").strip()
        vbody = vbody.replace("arith.constant", "").strip()
        if len(vbody.split(":")) < 2:
            # No explicit type on the constant; surface it for inspection.
            print(constant)
        vdtype = vbody.split(":")[-1].strip()
        if "c1_i64" in vname:
            print(constant)
            counter += 1
            if counter == 2:
                # Every second `c1_i64`-like constant is treated as a duplicate.
                counter = 0
                print("detected duplicate")
                continue
        # Constants whose name contains "true" are always declared as i1.
        dtype = "i1" if "true" in vname else vdtype
        global_vars.append(
            f"ml_program.global private @{vname}({vbody}) : {dtype}"
        )
        global_var_loading.append(
            f"\t\t%{vname} = ml_program.global_load_const @{vname} : {dtype}"
        )

    print(f"[DEBUG] processing f1")
    f1 = _inject_global_loads(f1, global_var_loading)
    print(f"[DEBUG] processing f2")
    f2 = _inject_global_loads(f2, global_var_loading)
    gc.collect()

    # doing it this way rather than assembling the whole string
    # to prevent OOM with 64GiB RAM when encoding the file.
    print(f"[DEBUG] Saving mlir to {output_name}")
    with open(output_name, "w") as f_:
        for section in (
            maps1,
            maps2,
            [module_start],
            global_vars,
            f1,
            f2,
            [module_end],
        ):
            f_.writelines(line + "\n" for line in section)

    # Release the large line lists before the file is read back in.
    del maps1, maps2, global_vars, global_var_loading, f1, f2, section
    gc.collect()

    if return_ir:
        print(f"[DEBUG] Reading combined mlir back in")
        with open(output_name, "rb") as f:
            return f.read()
|
||||
|
||||
|
||||
def write_in_dynamic_inputs0(module, dynamic_input_size):
    """Rewrite the first vicuna IR so the static input size becomes dynamic.

    Every `<size>x` tensor dimension is replaced with `?x`, matching
    `tensor.empty()` calls receive `%dim` operands, ` <size>,` list entries
    become ` %dim,`, and `c<size>` operands of `arith.cmpi` become `dim`.
    A `%dim = tensor.dim ...` definition is emitted right before the
    `%0 = tensor.empty(%dim) : tensor<?xi64>` line, and any such definition
    already present in the input is dropped.

    Args:
        module: text of the IR.
        dynamic_input_size: the static size to replace (e.g. 19).

    Returns:
        The rewritten IR as a single newline-joined string.
    """
    print("[DEBUG] writing dynamic inputs to first vicuna")
    # Current solution for ensuring mlir files support dynamic inputs
    # TODO: find a more elegant way to implement this
    size = re.escape(str(dynamic_input_size))
    # Digit guards keep e.g. size 6 from rewriting the tail of `4096x`
    # or `c19` from matching inside `c190`.
    shape_dim = re.compile(rf"(?<!\d){size}x")
    list_dim = re.compile(rf" {size},")
    const_dim = re.compile(rf"c{size}(?!\d)")
    dim_definition = "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>"

    new_lines = []
    pending = module.splitlines()
    # Reverse once so each line is released with an O(1) pop from the end.
    pending.reverse()
    while pending:
        line = shape_dim.sub("?x", pending.pop())
        if "?x" in line:
            line = line.replace("tensor.empty()", "tensor.empty(%dim)")
        line = list_dim.sub(" %dim,", line)
        if "tensor.empty" in line and "?x?" in line:
            # Two dynamic dimensions need two size operands.
            line = line.replace("tensor.empty(%dim)", "tensor.empty(%dim, %dim)")
        if "arith.cmpi" in line:
            line = const_dim.sub("dim", line)
        if "%0 = tensor.empty(%dim) : tensor<?xi64>" in line:
            new_lines.append(dim_definition)
        if dim_definition in line:
            continue
        new_lines.append(line)
    return "\n".join(new_lines)
|
||||
|
||||
|
||||
def write_in_dynamic_inputs1(module, model_name, precision):
    """Rewrite the second vicuna IR so the 19/20 static sizes become dynamic.

    The `%c19_i64` constant is replaced by a `tensor.dim` of `%arg1` (cast to
    i64 as `%dim_i64`), `%c20_i64` is recomputed as `%dim_i64 + 1` and cast
    to the index `%dimp1`, and the remaining lines have their 19/20
    dimensions rewritten to `?` with `%dim` / `%dimp1` operands.

    Args:
        module: text of the IR.
        model_name: selects the past-key-value tensor shape; names containing
            "llama2_13b" or "llama2_70b" get their own head count, anything
            else uses 32.
        precision: "fp16", "int4" and "int8" use an f16 element type for that
            tensor, everything else f32.

    Returns:
        The rewritten IR as a single newline-joined string.
    """
    print("[DEBUG] writing dynamic inputs to second vicuna")

    # The digit guards stop a larger dimension that merely ends in 19/20
    # (e.g. `5120x`) from being rewritten along with the real one.
    dim19 = re.compile(r"(?<!\d)19x")
    dim20 = re.compile(r"(?<!\d)20x")
    const19 = re.compile(r"c19(?!\d)")

    def remove_constant_dim(line):
        if "c19_i64" in line:
            line = line.replace("c19_i64", "dim_i64")
        if dim19.search(line):
            line = dim19.sub("?x", line)
            line = line.replace("tensor.empty()", "tensor.empty(%dim)")
        if "tensor.empty" in line and "?x?" in line:
            line = line.replace("tensor.empty(%dim)", "tensor.empty(%dim, %dim)")
        if "arith.cmpi" in line:
            line = const19.sub("dim", line)
        if " 19," in line:
            line = line.replace(" 19,", " %dim,")
        if "x20x" in line or "<20x" in line:
            line = dim20.sub("?x", line)
            line = line.replace("tensor.empty()", "tensor.empty(%dimp1)")
        if " 20," in line:
            line = line.replace(" 20,", " %dimp1,")
        return line

    if "llama2_13b" in model_name:
        pkv_tensor_shape = "tensor<1x40x?x128x"
    elif "llama2_70b" in model_name:
        pkv_tensor_shape = "tensor<1x8x?x128x"
    else:
        pkv_tensor_shape = "tensor<1x32x?x128x"
    if precision in ["fp16", "int4", "int8"]:
        pkv_tensor_shape += "f16>"
    else:
        pkv_tensor_shape += "f32>"

    new_lines = []
    pending = module.splitlines()
    # Consume the lines destructively to avoid keeping a second copy of the
    # module alive; reversing first makes every pop O(1).
    pending.reverse()
    while pending:
        line = pending.pop()
        if "%c19_i64 = arith.constant 19 : i64" in line:
            new_lines.append("%c2 = arith.constant 2 : index")
            new_lines.append(
                f"%dim_4_int = tensor.dim %arg1, %c2 : {pkv_tensor_shape}"
            )
            new_lines.append(
                "%dim_i64 = arith.index_cast %dim_4_int : index to i64"
            )
            continue
        if "%c2 = arith.constant 2 : index" in line:
            # Already emitted above; drop the original to avoid redefinition.
            continue
        if "%c20_i64 = arith.constant 20 : i64" in line:
            new_lines.append("%c1_i64 = arith.constant 1 : i64")
            new_lines.append("%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64")
            new_lines.append(
                "%dimp1 = arith.index_cast %c20_i64 : i64 to index"
            )
            continue
        new_lines.append(remove_constant_dim(line))

    return "\n".join(new_lines)
|
||||
|
||||
|
||||
def save_dynamic_ir(ir_to_save, output_file):
    """Write `ir_to_save` followed by a newline to `output_file`.

    Nothing is written (and no file is created) when `ir_to_save` is empty
    or None.
    """
    if not ir_to_save:
        return
    # We only get string output from the dynamic conversion utility, so it
    # can be written directly instead of redirecting the process-wide stdout.
    with open(output_file, "w") as f:
        f.write(f"{ir_to_save}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: optionally convert one or both IRs to dynamic
    # shapes, optionally save the converted IRs, and optionally combine both
    # IRs into `<model_name>_<precision>.mlir`.
    parser = argparse.ArgumentParser(
        prog="llama ir utility",
        description=(
            "\tThis script can be used as a standalone utility to convert IRs to dynamic + combine them.\n"
            "\tFollowing are the various ways this script can be used :-\n"
            "\t\ta. To convert a single Linalg IR to dynamic IR:\n"
            "\t\t\t--dynamic --first_ir_path=<PATH TO FIRST IR>\n"
            "\t\tb. To convert two Linalg IRs to dynamic IR:\n"
            "\t\t\t--dynamic --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n"
            "\t\tc. To combine two Linalg IRs into one:\n"
            "\t\t\t--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n"
            "\t\td. To convert both IRs into dynamic as well as combine the IRs:\n"
            "\t\t\t--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n\n"
            "\tNOTE: For dynamic you'll also need to provide the following set of flags:-\n"
            "\t\t i. For First Llama : --dynamic_input_size (DEFAULT: 19)\n"
            "\t\tii. For Second Llama: --model_name (DEFAULT: llama2_7b)\n"
            "\t\t\t--precision (DEFAULT: 'int4')\n"
            "\t You may use --save_dynamic to also save the dynamic IR in option d above.\n"
            "\t Else for option a. and b. the dynamic IR(s) will get saved by default.\n"
        ),
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument(
        "--precision",
        "-p",
        default="int4",
        choices=["fp32", "fp16", "int8", "int4"],
        help="Precision of the concerned IR",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="llama2_7b",
        choices=["vicuna", "llama2_7b", "llama2_13b", "llama2_70b"],
        help="Specify which model to run.",
    )
    parser.add_argument(
        "--first_ir_path",
        default=None,
        help="path to first llama mlir file",
    )
    parser.add_argument(
        "--second_ir_path",
        default=None,
        help="path to second llama mlir file",
    )
    parser.add_argument(
        "--dynamic_input_size",
        type=int,
        default=19,
        help="Specify the static input size to replace with dynamic dim.",
    )
    parser.add_argument(
        "--dynamic",
        default=False,
        action=argparse.BooleanOptionalAction,
        help="Converts the IR(s) to dynamic",
    )
    parser.add_argument(
        "--save_dynamic",
        default=False,
        action=argparse.BooleanOptionalAction,
        help="Save the individual IR(s) after converting to dynamic",
    )
    parser.add_argument(
        "--combine",
        default=False,
        action=argparse.BooleanOptionalAction,
        help="Combines the two IRs into one",
    )

    args, unknown = parser.parse_known_args()

    dynamic = args.dynamic
    combine = args.combine
    first_ir_path = args.first_ir_path
    second_ir_path = args.second_ir_path

    # Validate with parser.error rather than assert: asserts are stripped
    # under `python -O`, and parser.error also prints the usage line.
    if not (dynamic or combine):
        parser.error("neither `dynamic` nor `combine` flag is turned on")
    if not (first_ir_path or second_ir_path):
        parser.error("no input ir has been provided")
    if combine and not (first_ir_path and second_ir_path):
        parser.error("you will need to provide both IRs to combine")

    precision = args.precision
    model_name = args.model_name
    dynamic_input_size = args.dynamic_input_size
    save_dynamic = args.save_dynamic

    print(f"Dynamic conversion utility is turned {'ON' if dynamic else 'OFF'}")
    print(f"Combining IR utility is turned {'ON' if combine else 'OFF'}")

    # Without --combine the converted IRs are the only output, so always
    # save them.
    if dynamic and not combine:
        save_dynamic = True

    first_ir = None
    first_dynamic_ir_name = None
    second_ir = None
    second_dynamic_ir_name = None
    if first_ir_path:
        first_dynamic_ir_name = f"{Path(first_ir_path).stem}_dynamic"
        with open(first_ir_path, "r") as f:
            first_ir = f.read()
    if second_ir_path:
        second_dynamic_ir_name = f"{Path(second_ir_path).stem}_dynamic"
        with open(second_ir_path, "r") as f:
            second_ir = f.read()

    if dynamic:
        if first_ir:
            first_ir = write_in_dynamic_inputs0(first_ir, dynamic_input_size)
        else:
            first_ir = None
        if second_ir:
            second_ir = write_in_dynamic_inputs1(
                second_ir, model_name, precision
            )
        else:
            second_ir = None
        if save_dynamic:
            # save_dynamic_ir is a no-op for an IR that was not provided.
            save_dynamic_ir(first_ir, f"{first_dynamic_ir_name}.mlir")
            save_dynamic_ir(second_ir, f"{second_dynamic_ir_name}.mlir")

    if combine:
        combine_mlir_scripts(
            first_ir,
            second_ir,
            f"{model_name}_{precision}.mlir",
            return_ir=False,
        )
|
||||
@@ -4,6 +4,7 @@ import re
|
||||
import gc
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from statistics import mean, stdev
|
||||
from tqdm import tqdm
|
||||
from typing import List, Tuple
|
||||
import subprocess
|
||||
@@ -43,12 +44,18 @@ from apps.language_models.src.model_wrappers.vicuna_model import (
|
||||
SecondVicuna13B,
|
||||
SecondVicuna70B,
|
||||
)
|
||||
from apps.language_models.src.model_wrappers.vicuna_model_gpu import (
|
||||
FirstVicunaGPU,
|
||||
SecondVicuna7BGPU,
|
||||
SecondVicuna13BGPU,
|
||||
SecondVicuna70BGPU,
|
||||
)
|
||||
from apps.language_models.utils import (
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import get_f16_inputs
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
|
||||
@@ -103,7 +110,7 @@ parser.add_argument(
|
||||
"--download_vmfb",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="download vmfb from sharktank, system dependent, YMMV",
|
||||
help="Download vmfb from sharktank, system dependent, YMMV",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
@@ -130,6 +137,44 @@ parser.add_argument(
|
||||
default="",
|
||||
help="Specify target triple for vulkan.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--Xiree_compile",
|
||||
action='append',
|
||||
default=[],
|
||||
help="Extra command line arguments passed to the IREE compiler. This can be specified multiple times to pass multiple arguments."
|
||||
)
|
||||
|
||||
# Microbenchmarking options.
|
||||
parser.add_argument(
|
||||
"--enable_microbenchmark",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Enables the microbenchmarking mode (non-interactive). Uses the system and the user prompt from args.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--microbenchmark_iterations",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of microbenchmark iterations. Default: 5.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--microbenchmark_num_tokens",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Generate an exact number of output tokens. Default: 512.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--system_prompt",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify the system prompt. This is only used with `--enable_microbenchmark`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--user_prompt",
|
||||
type=str,
|
||||
default="Hi",
|
||||
help="Specify the user prompt. This is only used with `--enable_microbenchmark`",
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
|
||||
@@ -399,7 +444,7 @@ class VicunaBase(SharkLLMBase):
|
||||
is_first=is_first,
|
||||
)
|
||||
else:
|
||||
token = token.to(torch.int64).reshape([1, 1])
|
||||
token = torch.tensor(token).reshape([1, 1])
|
||||
second_input = (token,) + tuple(past_key_values)
|
||||
output = self.shark_model(
|
||||
"second_vicuna_forward", second_input, send_to_host=False
|
||||
@@ -409,6 +454,9 @@ class VicunaBase(SharkLLMBase):
|
||||
_logits = output["logits"]
|
||||
_past_key_values = output["past_key_values"]
|
||||
_token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
|
||||
elif "cpu" in self.device:
|
||||
_past_key_values = output[1:]
|
||||
_token = int(output[0].to_host())
|
||||
else:
|
||||
_logits = torch.tensor(output[0].to_host())
|
||||
_past_key_values = output[1:]
|
||||
@@ -418,9 +466,10 @@ class VicunaBase(SharkLLMBase):
|
||||
ret_dict = {
|
||||
"token": _token,
|
||||
"detok": _detok,
|
||||
"logits": _logits,
|
||||
"past_key_values": _past_key_values,
|
||||
}
|
||||
if "cpu" not in self.device:
|
||||
ret_dict["logits"] = _logits
|
||||
|
||||
if cli:
|
||||
print(f" token : {_token} | detok : {_detok}")
|
||||
@@ -641,9 +690,7 @@ class ShardedVicuna(VicunaBase):
|
||||
mlir_path = Path(f"lmhead.mlir")
|
||||
vmfb_path = Path(f"lmhead.vmfb")
|
||||
if mlir_path.exists():
|
||||
f_ = open(mlir_path, "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
print(f"Found bytecode module at {mlir_path}.")
|
||||
else:
|
||||
hidden_states = torch_mlir.TensorPlaceholder.like(
|
||||
hidden_states, dynamic_axes=[1]
|
||||
@@ -668,12 +715,10 @@ class ShardedVicuna(VicunaBase):
|
||||
filepath.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
f_ = open(f"lmhead.mlir", "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
mlir_path = filepath
|
||||
|
||||
shark_module = SharkInference(
|
||||
bytecode,
|
||||
mlir_path,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=device_idx,
|
||||
@@ -693,9 +738,7 @@ class ShardedVicuna(VicunaBase):
|
||||
mlir_path = Path(f"norm.mlir")
|
||||
vmfb_path = Path(f"norm.vmfb")
|
||||
if mlir_path.exists():
|
||||
f_ = open(mlir_path, "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
print(f"Found bytecode module at {mlir_path}.")
|
||||
else:
|
||||
hidden_states = torch_mlir.TensorPlaceholder.like(
|
||||
hidden_states, dynamic_axes=[1]
|
||||
@@ -714,12 +757,10 @@ class ShardedVicuna(VicunaBase):
|
||||
filepath.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
f_ = open(f"norm.mlir", "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
mlir_path = filepath
|
||||
|
||||
shark_module = SharkInference(
|
||||
bytecode,
|
||||
mlir_path,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=device_idx,
|
||||
@@ -739,9 +780,7 @@ class ShardedVicuna(VicunaBase):
|
||||
mlir_path = Path(f"embedding.mlir")
|
||||
vmfb_path = Path(f"embedding.vmfb")
|
||||
if mlir_path.exists():
|
||||
f_ = open(mlir_path, "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
print(f"Found bytecode module at {mlir_path}.")
|
||||
else:
|
||||
input_ids = torch_mlir.TensorPlaceholder.like(
|
||||
input_ids, dynamic_axes=[1]
|
||||
@@ -765,12 +804,10 @@ class ShardedVicuna(VicunaBase):
|
||||
filepath.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
f_ = open(f"embedding.mlir", "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
mlir_path = filepath
|
||||
|
||||
shark_module = SharkInference(
|
||||
bytecode,
|
||||
mlir_path,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=device_idx,
|
||||
@@ -1220,6 +1257,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
|
||||
hf_auth_token: str = None,
|
||||
max_num_tokens=512,
|
||||
min_num_tokens=0,
|
||||
device="cpu",
|
||||
vulkan_target_triple="",
|
||||
precision="int8",
|
||||
@@ -1249,6 +1287,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
self.hf_model_path = "meta-llama/Llama-2-70b-chat-hf"
|
||||
print(f"[DEBUG] hf model name: {self.hf_model_path}")
|
||||
self.max_sequence_length = 256
|
||||
self.min_num_tokens = min_num_tokens
|
||||
self.device = device
|
||||
self.vulkan_target_triple = vulkan_target_triple
|
||||
self.device_id = device_id
|
||||
@@ -1270,6 +1309,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
|
||||
def get_model_path(self, suffix="mlir"):
|
||||
safe_device = self.device.split("-")[0]
|
||||
safe_device = safe_device.split("://")[0]
|
||||
if suffix in ["mlirbc", "mlir"]:
|
||||
return Path(f"{self.model_name}_{self.precision}.{suffix}")
|
||||
|
||||
@@ -1426,10 +1466,12 @@ class UnshardedVicuna(VicunaBase):
|
||||
print(f"[DEBUG] vmfb found at {self.vicuna_vmfb_path.absolute()}")
|
||||
return
|
||||
|
||||
print(f"[DEBUG] vmfb not found")
|
||||
print(f"[DEBUG] vmfb not found (search path: {self.vicuna_vmfb_path})")
|
||||
mlir_generated = False
|
||||
for suffix in ["mlirbc", "mlir"]:
|
||||
self.vicuna_mlir_path = self.get_model_path(suffix)
|
||||
if "cpu" in self.device and "llama2_7b" in self.vicuna_mlir_path.name:
|
||||
self.vicuna_mlir_path = Path("llama2_7b_int4_f32.mlir")
|
||||
if not self.vicuna_mlir_path.exists() and self.load_mlir_from_shark_tank:
|
||||
print(
|
||||
f"Looking into gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}"
|
||||
@@ -1441,18 +1483,12 @@ class UnshardedVicuna(VicunaBase):
|
||||
)
|
||||
if self.vicuna_mlir_path.exists():
|
||||
print(f"[DEBUG] mlir found at {self.vicuna_mlir_path.absolute()}")
|
||||
with open(self.vicuna_mlir_path, "rb") as f:
|
||||
combined_module = f.read()
|
||||
combined_module = self.vicuna_mlir_path.absolute()
|
||||
mlir_generated = True
|
||||
break
|
||||
|
||||
if not mlir_generated:
|
||||
print(f"[DEBUG] mlir not found")
|
||||
# Disabling this path of IR generation for now as it is broken.
|
||||
print("Please check if the mlir file is present at the shark tank. Exiting.")
|
||||
self.shark_model = None
|
||||
sys.exit()
|
||||
return
|
||||
|
||||
print("[DEBUG] generating mlir on device")
|
||||
# Select a compilation prompt such that the resulting input_ids
|
||||
@@ -1474,13 +1510,24 @@ class UnshardedVicuna(VicunaBase):
|
||||
compilation_input_ids
|
||||
).reshape([1, 19])
|
||||
firstVicunaCompileInput = (compilation_input_ids,)
|
||||
model = FirstVicuna(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
if "cpu" in self.device:
|
||||
model = FirstVicuna(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
"fp32" if self.device=="cpu" else "fp16",
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
else:
|
||||
model = FirstVicunaGPU(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
"fp32" if self.device=="cpu" else "fp16",
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
is_f16 = self.precision in ["fp16", "int4"]
|
||||
ts_graph = import_with_fx(
|
||||
@@ -1512,6 +1559,9 @@ class UnshardedVicuna(VicunaBase):
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
if self.cache_vicunas:
|
||||
with open(first_model_path[:-5]+"_torch.mlir", "w+") as f:
|
||||
f.write(str(first_module))
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
first_module,
|
||||
@@ -1566,30 +1616,62 @@ class UnshardedVicuna(VicunaBase):
|
||||
for _ in range(total_tuple)
|
||||
)
|
||||
secondVicunaCompileInput = (compilation_input_ids,) + pkv
|
||||
if self.model_name == "llama2_13b":
|
||||
model = SecondVicuna13B(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
elif self.model_name == "llama2_70b":
|
||||
model = SecondVicuna70B(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
if "cpu" in self.device:
|
||||
if self.model_name == "llama2_13b":
|
||||
model = SecondVicuna13B(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
"fp32",
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
elif self.model_name == "llama2_70b":
|
||||
model = SecondVicuna70B(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
"fp32",
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
else:
|
||||
model = SecondVicuna7B(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
"fp32",
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
else:
|
||||
model = SecondVicuna7B(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
if self.model_name == "llama2_13b":
|
||||
model = SecondVicuna13BGPU(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
"fp16",
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
elif self.model_name == "llama2_70b":
|
||||
model = SecondVicuna70BGPU(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
"fp16",
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
else:
|
||||
model = SecondVicuna7BGPU(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
"fp16",
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
is_f16 = self.precision in ["fp16", "int4"]
|
||||
ts_graph = import_with_fx(
|
||||
@@ -1626,6 +1708,9 @@ class UnshardedVicuna(VicunaBase):
|
||||
verbose=False,
|
||||
)
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
if self.cache_vicunas:
|
||||
with open(second_model_path[:-5]+"_torch.mlir", "w+") as f:
|
||||
f.write(str(second_module))
|
||||
run_pipeline_with_repro_report(
|
||||
second_module,
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
@@ -1659,6 +1744,12 @@ class UnshardedVicuna(VicunaBase):
|
||||
second_module,
|
||||
self.vicuna_mlir_path,
|
||||
)
|
||||
combined_module = save_mlir(
|
||||
combined_module,
|
||||
model_name="combined_llama",
|
||||
mlir_dialect="tm_tensor",
|
||||
dir=self.vicuna_mlir_path,
|
||||
)
|
||||
del first_module, second_module
|
||||
|
||||
print(self.device)
|
||||
@@ -1709,7 +1800,8 @@ class UnshardedVicuna(VicunaBase):
|
||||
prefill_time = time.time() - prefill_st_time
|
||||
|
||||
token = generated_token_op["token"]
|
||||
logits = generated_token_op["logits"]
|
||||
if "cpu" not in self.device:
|
||||
logits = generated_token_op["logits"]
|
||||
pkv = generated_token_op["past_key_values"]
|
||||
detok = generated_token_op["detok"]
|
||||
yield detok, None, prefill_time
|
||||
@@ -1718,14 +1810,15 @@ class UnshardedVicuna(VicunaBase):
|
||||
if cli:
|
||||
print(f"Assistant: {detok}", end=" ", flush=True)
|
||||
|
||||
for _ in range(self.max_num_tokens - 2):
|
||||
for idx in range(self.max_num_tokens):
|
||||
params = {
|
||||
"token": token,
|
||||
"is_first": False,
|
||||
"logits": logits,
|
||||
"past_key_values": pkv,
|
||||
"sv": self.shark_model,
|
||||
}
|
||||
if "cpu" not in self.device:
|
||||
params["logits"] = logits
|
||||
|
||||
decode_st_time = time.time()
|
||||
generated_token_op = self.generate_new_token(
|
||||
@@ -1734,11 +1827,12 @@ class UnshardedVicuna(VicunaBase):
|
||||
decode_time_ms = (time.time() - decode_st_time)*1000
|
||||
|
||||
token = generated_token_op["token"]
|
||||
logits = generated_token_op["logits"]
|
||||
if "cpu" not in self.device:
|
||||
logits = generated_token_op["logits"]
|
||||
pkv = generated_token_op["past_key_values"]
|
||||
detok = generated_token_op["detok"]
|
||||
|
||||
if token == 2:
|
||||
if token == 2 and idx >= self.min_num_tokens:
|
||||
break
|
||||
res_tokens.append(token)
|
||||
if detok == "<0x0A>":
|
||||
@@ -1823,7 +1917,8 @@ def create_prompt(model_name, history):
|
||||
if __name__ == "__main__":
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
_extra_args = []
|
||||
_extra_args = list(args.Xiree_compile)
|
||||
|
||||
device_id = None
|
||||
# Process vulkan target triple.
|
||||
# TODO: This feature should just be in a common utils for other LLMs and in general
|
||||
@@ -1846,7 +1941,7 @@ if __name__ == "__main__":
|
||||
device_id = id
|
||||
break
|
||||
id += 1
|
||||
|
||||
|
||||
assert device_id, f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
|
||||
# Step 2. Add a few flags targeting specific hardware.
|
||||
if "rdna" in vulkan_target_triple:
|
||||
@@ -1854,7 +1949,7 @@ if __name__ == "__main__":
|
||||
"--iree-spirv-index-bits=64",
|
||||
]
|
||||
_extra_args = _extra_args + flags_to_add
|
||||
|
||||
|
||||
|
||||
vic = None
|
||||
if not args.sharded:
|
||||
@@ -1868,10 +1963,18 @@ if __name__ == "__main__":
|
||||
if args.vicuna_vmfb_path is None
|
||||
else Path(args.vicuna_vmfb_path)
|
||||
)
|
||||
min_tokens = 0
|
||||
max_tokens = 512
|
||||
if args.enable_microbenchmark:
|
||||
min_tokens = max_tokens = args.microbenchmark_num_tokens
|
||||
|
||||
vic = UnshardedVicuna(
|
||||
model_name=args.model_name,
|
||||
hf_auth_token=args.hf_auth_token,
|
||||
max_num_tokens=max_tokens,
|
||||
min_num_tokens=min_tokens,
|
||||
device=args.device,
|
||||
vulkan_target_triple=vulkan_target_triple,
|
||||
precision=args.precision,
|
||||
vicuna_mlir_path=vic_mlir_path,
|
||||
vicuna_vmfb_path=vic_vmfb_path,
|
||||
@@ -1897,17 +2000,6 @@ if __name__ == "__main__":
|
||||
weight_group_size=args.weight_group_size,
|
||||
extra_args_cmd=_extra_args,
|
||||
)
|
||||
if args.model_name == "vicuna":
|
||||
system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
else:
|
||||
system_message = """System: You are a helpful, respectful and honest assistant. Always answer "
|
||||
as helpfully as possible, while being safe. Your answers should not
|
||||
include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal
|
||||
content. Please ensure that your responses are socially unbiased and positive
|
||||
in nature. If a question does not make any sense, or is not factually coherent,
|
||||
explain why instead of answering something not correct. If you don't know the
|
||||
answer to a question, please don't share false information."""
|
||||
prologue_prompt = "ASSISTANT:\n"
|
||||
|
||||
history = []
|
||||
|
||||
@@ -1917,12 +2009,55 @@ if __name__ == "__main__":
|
||||
"llama2_13b": "llama2_13b=>meta-llama/Llama-2-13b-chat-hf",
|
||||
"llama2_70b": "llama2_70b=>meta-llama/Llama-2-70b-chat-hf",
|
||||
}
|
||||
|
||||
iteration = 0
|
||||
|
||||
prefill_times = []
|
||||
avg_decode_speed = []
|
||||
|
||||
while True:
|
||||
# TODO: Add break condition from user input
|
||||
user_prompt = input("User: ")
|
||||
history.append([user_prompt, ""])
|
||||
prompt = create_prompt(args.model_name, history)
|
||||
for text, msg in vic.generate(prompt, cli=True):
|
||||
if "formatted" in msg:
|
||||
print("Response:", text)
|
||||
iteration += 1
|
||||
if not args.enable_microbenchmark:
|
||||
user_prompt = input("User prompt: ")
|
||||
history.append([user_prompt, ""])
|
||||
prompt = create_prompt(args.model_name, history)
|
||||
else:
|
||||
if iteration > args.microbenchmark_iterations:
|
||||
break
|
||||
user_prompt = args.user_prompt
|
||||
prompt = args.system_prompt + user_prompt
|
||||
history = [[user_prompt, ""]]
|
||||
|
||||
token_count = 0
|
||||
total_time_ms = 0.001 # In order to avoid divide by zero error
|
||||
prefill_time = 0
|
||||
is_first = True
|
||||
for text, msg, exec_time in vic.generate(prompt, cli=True):
|
||||
if msg is None:
|
||||
if is_first:
|
||||
prefill_time = exec_time
|
||||
is_first = False
|
||||
else:
|
||||
total_time_ms += exec_time
|
||||
token_count += 1
|
||||
elif "formatted" in msg:
|
||||
history[-1][1] = text
|
||||
tokens_per_sec = (token_count / total_time_ms) * 1000
|
||||
prefill_times.append(prefill_time)
|
||||
avg_decode_speed.append(tokens_per_sec)
|
||||
|
||||
print("\nResponse:", text.strip())
|
||||
print(f"\nNum tokens: {token_count}")
|
||||
print(f"Prefill: {prefill_time:.2f} seconds")
|
||||
print(f"Decode: {tokens_per_sec:.2f} tokens/s")
|
||||
else:
|
||||
sys.exit(
|
||||
"unexpected message from the vicuna generate call, exiting."
|
||||
)
|
||||
|
||||
if args.enable_microbenchmark:
|
||||
print("\n### Final Statistics ###")
|
||||
print("Number of iterations:", iteration - 1)
|
||||
print(f"Prefill: avg. {mean(prefill_times):.2f} s, stdev {stdev(prefill_times):.2f}")
|
||||
print(f"Decode: avg. {mean(avg_decode_speed):.2f} tokens/s, stdev {stdev(avg_decode_speed):.2f}")
|
||||
|
||||
598
apps/language_models/src/model_wrappers/falcon_sharded_model.py
Normal file
598
apps/language_models/src/model_wrappers/falcon_sharded_model.py
Normal file
@@ -0,0 +1,598 @@
|
||||
import torch
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
class WordEmbeddingsLayer(torch.nn.Module):
    """Thin wrapper that exposes a word-embedding layer as its own module."""

    def __init__(self, word_embedding_layer):
        super().__init__()
        # The wrapped layer; its ``forward`` must accept an ``input`` keyword.
        self.model = word_embedding_layer

    def forward(self, input_ids):
        """Return the wrapped layer's output for ``input_ids``."""
        return self.model.forward(input=input_ids)
|
||||
|
||||
|
||||
class CompiledWordEmbeddingsLayer(torch.nn.Module):
    """Runs a compiled word-embedding module and returns a torch tensor."""

    def __init__(self, compiled_word_embedding_layer):
        super().__init__()
        # Callable invoked as ``model("forward", numpy_input)``.
        self.model = compiled_word_embedding_layer

    def forward(self, input_ids):
        """Embed ``input_ids`` through the compiled module.

        The compiled result is reshaped to ``[1, shape[0], shape[1]]`` before
        being converted back to a torch tensor.
        """
        ids_np = input_ids.detach().numpy()
        embedded = self.model("forward", ids_np)
        rows, cols = embedded.shape[0], embedded.shape[1]
        # Re-add the leading dimension of size 1.
        embedded = embedded.reshape([1, rows, cols])
        return torch.tensor(embedded)
|
||||
|
||||
|
||||
class LNFEmbeddingLayer(torch.nn.Module):
    """Thin wrapper that exposes the ``ln_f`` layer as its own module."""

    def __init__(self, ln_f):
        super().__init__()
        # The wrapped layer; its ``forward`` must accept an ``input`` keyword.
        self.model = ln_f

    def forward(self, hidden_states):
        """Return the wrapped layer's output for ``hidden_states``."""
        return self.model.forward(input=hidden_states)
|
||||
|
||||
|
||||
class CompiledLNFEmbeddingLayer(torch.nn.Module):
    """Runs a compiled ``ln_f`` module and returns a torch tensor."""

    def __init__(self, ln_f):
        super().__init__()
        # Callable invoked as ``model("forward", (numpy_input,))``.
        self.model = ln_f

    def forward(self, hidden_states):
        """Pass ``hidden_states`` through the compiled module."""
        states_np = hidden_states.detach().numpy()
        result = self.model("forward", (states_np,))
        return torch.tensor(result)
|
||||
|
||||
|
||||
class LMHeadEmbeddingLayer(torch.nn.Module):
    """Thin wrapper that exposes the LM head layer as its own module."""

    def __init__(self, embedding_layer):
        super().__init__()
        # The wrapped layer; its ``forward`` must accept an ``input`` keyword.
        self.model = embedding_layer

    def forward(self, hidden_states):
        """Return the wrapped layer's output for ``hidden_states``."""
        return self.model.forward(input=hidden_states)
|
||||
|
||||
|
||||
class CompiledLMHeadEmbeddingLayer(torch.nn.Module):
    """Runs a compiled LM head module and returns a torch tensor."""

    def __init__(self, lm_head):
        super().__init__()
        # Callable invoked as ``model("forward", (numpy_input,))``.
        self.model = lm_head

    def forward(self, hidden_states):
        """Pass ``hidden_states`` through the compiled module."""
        states_np = hidden_states.detach().numpy()
        result = self.model("forward", (states_np,))
        return torch.tensor(result)
|
||||
|
||||
|
||||
class DecoderLayer(torch.nn.Module):
    """Wraps one decoder layer and flattens its output into a 3-tuple."""

    def __init__(self, decoder_layer_model, falcon_variant):
        super().__init__()
        # ``falcon_variant`` is accepted but not stored or used by this class.
        self.model = decoder_layer_model

    def forward(self, hidden_states, attention_mask):
        """Run the wrapped layer with ``alibi=None`` and ``use_cache=True``.

        Returns:
            ``(output[0], output[1][0], output[1][1])`` where ``output`` is
            the wrapped layer's return value.
        """
        layer_out = self.model.forward(
            hidden_states=hidden_states,
            alibi=None,
            attention_mask=attention_mask,
            use_cache=True,
        )
        new_hidden_states = layer_out[0]
        cache = layer_out[1]
        return (new_hidden_states, cache[0], cache[1])
|
||||
|
||||
|
||||
class CompiledDecoderLayer(torch.nn.Module):
    """Runs one decoder layer from a per-layer ``.vmfb`` loaded on demand.

    The compiled module is loaded at the start of every ``forward`` call and
    dropped again once the call finishes.
    """

    def __init__(
        self, layer_id, device_idx, falcon_variant, device, precision
    ):
        super().__init__()
        self.layer_id = layer_id
        self.device_index = device_idx
        self.falcon_variant = falcon_variant
        self.device = device
        self.precision = precision

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        alibi: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        """Run the compiled layer on ``hidden_states`` and ``attention_mask``.

        ``head_mask``, ``use_cache`` and ``output_attentions`` are accepted
        but not used by this method.

        Returns:
            ``(new_hidden_states, (pkv1, pkv2))`` as torch tensors.

        Raises:
            ValueError: if ``alibi`` or ``layer_past`` is given, or if the
                layer's vmfb could not be loaded.
        """
        # Validate arguments first so a bad call fails before the vmfb load.
        if alibi is not None or layer_past is not None:
            raise ValueError("Past Key Values and alibi should be None")

        import gc

        torch.cuda.empty_cache()
        gc.collect()
        from pathlib import Path
        from apps.language_models.utils import get_vmfb_from_path

        self.falcon_vmfb_path = Path(
            f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
        )
        print("vmfb path for layer: ", self.falcon_vmfb_path)
        self.model = get_vmfb_from_path(
            self.falcon_vmfb_path,
            self.device,
            "linalg",
            device_id=self.device_index,
        )
        if self.model is None:
            raise ValueError("Layer vmfb not found")

        try:
            hidden_states = hidden_states.to(torch.float32).detach().numpy()
            attention_mask = attention_mask.detach().numpy()
            new_hidden_states, pkv1, pkv2 = self.model(
                "forward",
                (
                    hidden_states,
                    attention_mask,
                ),
            )
        finally:
            # Drop the compiled module even when the invocation fails, so it
            # is not kept referenced by this layer.
            del self.model

        return (
            torch.tensor(new_hidden_states),
            (
                torch.tensor(pkv1),
                torch.tensor(pkv2),
            ),
        )
|
||||
|
||||
|
||||
class EightDecoderLayer(torch.nn.Module):
    """Runs a group of decoder layers in sequence and flattens their caches.

    The expected group size depends on the Falcon variant: 8 layers for
    "7b", 15 for "40b" and 20 for "180b".
    """

    # Number of decoder layers expected in ``self.model`` per variant.
    _LAYERS_PER_GROUP = {"7b": 8, "40b": 15, "180b": 20}

    def __init__(self, decoder_layer_model, falcon_variant):
        super().__init__()
        self.model = decoder_layer_model
        self.falcon_variant = falcon_variant

    def forward(self, hidden_states, attention_mask):
        """Feed ``hidden_states`` through every layer of the group.

        Each layer is called with ``alibi=None`` and ``use_cache=True``; its
        first output becomes the next layer's ``hidden_states`` and the two
        leading entries of its last output are collected.

        Returns:
            A flat tuple ``(hidden_states, k0, v0, k1, v1, ...)`` with one
            pair per layer, in layer order.

        Raises:
            ValueError: if the variant is unsupported or the number of
                layers does not match the variant's expected group size.
        """
        expected_layers = self._LAYERS_PER_GROUP.get(self.falcon_variant)
        if expected_layers is None:
            raise ValueError(
                f"Unsupported Falcon variant: {self.falcon_variant}"
            )

        flat_pkvs = []
        num_layers = 0
        for layer in self.model:
            outputs = layer(
                hidden_states=hidden_states,
                alibi=None,
                attention_mask=attention_mask,
                use_cache=True,
            )
            hidden_states = outputs[0]
            flat_pkvs.append(outputs[-1][0])
            flat_pkvs.append(outputs[-1][1])
            num_layers += 1

        if num_layers != expected_layers:
            raise ValueError(
                f"Falcon variant {self.falcon_variant} expects "
                f"{expected_layers} decoder layers, got {num_layers}"
            )
        return (hidden_states, *flat_pkvs)
|
||||
|
||||
|
||||
class CompiledEightDecoderLayer(torch.nn.Module):
    """Runs a group of decoder layers from a ``.vmfb`` loaded on demand.

    The compiled module is loaded at the start of every ``forward`` call and
    dropped again once the call finishes. The number of key/value pairs read
    from its output depends on the Falcon variant: 8 for "7b", 15 for "40b"
    and 20 for "180b".
    """

    # Number of (key, value) pairs read from the compiled output per variant.
    _LAYERS_PER_GROUP = {"7b": 8, "40b": 15, "180b": 20}

    def __init__(
        self, layer_id, device_idx, falcon_variant, device, precision
    ):
        super().__init__()
        self.layer_id = layer_id
        self.device_index = device_idx
        self.falcon_variant = falcon_variant
        self.device = device
        self.precision = precision

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        alibi: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        """Run the compiled group on ``hidden_states`` and ``attention_mask``.

        ``head_mask``, ``use_cache`` and ``output_attentions`` are accepted
        but not used by this method.

        Returns:
            ``(hidden_states, (k0, v0), (k1, v1), ...)`` as torch tensors,
            built from ``output[0]`` and the consecutive pairs
            ``output[1], output[2]``, ``output[3], output[4]``, and so on.

        Raises:
            ValueError: if ``alibi`` or ``layer_past`` is given, if the
                variant is unsupported, or if the vmfb could not be loaded.
        """
        # Validate arguments first so a bad call fails before the vmfb load.
        if alibi is not None or layer_past is not None:
            raise ValueError("Past Key Values and alibi should be None")
        num_layers = self._LAYERS_PER_GROUP.get(self.falcon_variant)
        if num_layers is None:
            raise ValueError(
                f"Unsupported Falcon variant: {self.falcon_variant}"
            )

        import gc

        torch.cuda.empty_cache()
        gc.collect()
        from pathlib import Path
        from apps.language_models.utils import get_vmfb_from_path

        self.falcon_vmfb_path = Path(
            f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
        )
        print("vmfb path for layer: ", self.falcon_vmfb_path)
        self.model = get_vmfb_from_path(
            self.falcon_vmfb_path,
            self.device,
            "linalg",
            device_id=self.device_index,
        )
        if self.model is None:
            raise ValueError("Layer vmfb not found")

        try:
            hidden_states = hidden_states.to(torch.float32).detach().numpy()
            attention_mask = attention_mask.detach().numpy()
            output = self.model(
                "forward",
                (
                    hidden_states,
                    attention_mask,
                ),
            )
        finally:
            # Drop the compiled module even when the invocation fails, so it
            # is not kept referenced by this layer.
            del self.model

        new_hidden_states = torch.tensor(output[0])
        # Output index 1 onward holds alternating key/value entries.
        pkv_pairs = tuple(
            (torch.tensor(output[idx]), torch.tensor(output[idx + 1]))
            for idx in range(1, 2 * num_layers, 2)
        )
        return (new_hidden_states, *pkv_pairs)
|
||||
|
||||
|
||||
class ShardedFalconModel:
    """Falcon model whose sub-modules are replaced by the supplied ones."""

    def __init__(self, model, layers, word_embeddings, ln_f, lm_head):
        super().__init__()
        self.model = model
        transformer = self.model.transformer
        # Swap in the replacement decoder layers, embeddings and final norm.
        transformer.h = torch.nn.modules.container.ModuleList(layers)
        transformer.word_embeddings = word_embeddings
        transformer.ln_f = ln_f
        self.model.lm_head = lm_head

    def forward(
        self,
        input_ids,
        attention_mask=None,
    ):
        """Return the logits at the last position along dimension 1."""
        model_output = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        return model_output.logits[:, -1, :]
|
||||
@@ -54,7 +54,6 @@ from apps.language_models.utils import (
|
||||
)
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import get_f16_inputs
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
|
||||
@@ -7,6 +7,7 @@ class FirstVicuna(torch.nn.Module):
|
||||
self,
|
||||
model_path,
|
||||
precision="fp32",
|
||||
accumulates="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
hf_auth_token: str = None,
|
||||
@@ -15,6 +16,9 @@ class FirstVicuna(torch.nn.Module):
|
||||
kwargs = {"torch_dtype": torch.float32}
|
||||
if "llama2" in model_name:
|
||||
kwargs["use_auth_token"] = hf_auth_token
|
||||
self.accumulates = (
|
||||
torch.float32 if accumulates == "fp32" else torch.float16
|
||||
)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
@@ -29,7 +33,7 @@ class FirstVicuna(torch.nn.Module):
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
get_model_impl(self.model).layers,
|
||||
dtype=torch.float16 if precision == "int4" else torch.float32,
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
@@ -43,7 +47,9 @@ class FirstVicuna(torch.nn.Module):
|
||||
def forward(self, input_ids):
|
||||
op = self.model(input_ids=input_ids, use_cache=True)
|
||||
return_vals = []
|
||||
return_vals.append(op.logits)
|
||||
token = torch.argmax(op.logits[:, -1, :], dim=1)
|
||||
return_vals.append(token)
|
||||
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
@@ -56,6 +62,7 @@ class SecondVicuna7B(torch.nn.Module):
|
||||
self,
|
||||
model_path,
|
||||
precision="fp32",
|
||||
accumulates="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
hf_auth_token: str = None,
|
||||
@@ -67,6 +74,9 @@ class SecondVicuna7B(torch.nn.Module):
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
self.accumulates = (
|
||||
torch.float32 if accumulates == "fp32" else torch.float16
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
@@ -78,7 +88,7 @@ class SecondVicuna7B(torch.nn.Module):
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
get_model_impl(self.model).layers,
|
||||
dtype=torch.float16 if precision == "int4" else torch.float32,
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
@@ -289,7 +299,8 @@ class SecondVicuna7B(torch.nn.Module):
|
||||
input_ids=token, use_cache=True, past_key_values=past_key_values
|
||||
)
|
||||
return_vals = []
|
||||
return_vals.append(op.logits)
|
||||
token = torch.argmax(op.logits[:, -1, :], dim=1)
|
||||
return_vals.append(token)
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
@@ -302,6 +313,7 @@ class SecondVicuna13B(torch.nn.Module):
|
||||
self,
|
||||
model_path,
|
||||
precision="int8",
|
||||
accumulates="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
hf_auth_token: str = None,
|
||||
@@ -313,6 +325,9 @@ class SecondVicuna13B(torch.nn.Module):
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
self.accumulates = (
|
||||
torch.float32 if accumulates == "fp32" else torch.float16
|
||||
)
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
@@ -323,7 +338,7 @@ class SecondVicuna13B(torch.nn.Module):
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
get_model_impl(self.model).layers,
|
||||
dtype=torch.float16 if precision == "int4" else torch.float32,
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
@@ -595,6 +610,7 @@ class SecondVicuna70B(torch.nn.Module):
|
||||
self,
|
||||
model_path,
|
||||
precision="fp32",
|
||||
accumulates="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
hf_auth_token: str = None,
|
||||
@@ -606,6 +622,9 @@ class SecondVicuna70B(torch.nn.Module):
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
self.accumulates = (
|
||||
torch.float32 if accumulates == "fp32" else torch.float16
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
@@ -617,7 +636,7 @@ class SecondVicuna70B(torch.nn.Module):
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
get_model_impl(self.model).layers,
|
||||
dtype=torch.float16,
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
|
||||
1165
apps/language_models/src/model_wrappers/vicuna_model_gpu.py
Normal file
1165
apps/language_models/src/model_wrappers/vicuna_model_gpu.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,17 @@
|
||||
from apps.language_models.src.model_wrappers.falcon_model import FalconModel
|
||||
from apps.language_models.src.model_wrappers.falcon_sharded_model import (
|
||||
WordEmbeddingsLayer,
|
||||
CompiledWordEmbeddingsLayer,
|
||||
LNFEmbeddingLayer,
|
||||
CompiledLNFEmbeddingLayer,
|
||||
LMHeadEmbeddingLayer,
|
||||
CompiledLMHeadEmbeddingLayer,
|
||||
DecoderLayer,
|
||||
EightDecoderLayer,
|
||||
CompiledDecoderLayer,
|
||||
CompiledEightDecoderLayer,
|
||||
ShardedFalconModel,
|
||||
)
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from apps.language_models.utils import (
|
||||
get_vmfb_from_path,
|
||||
@@ -7,21 +20,22 @@ from io import BytesIO
|
||||
from pathlib import Path
|
||||
from contextlib import redirect_stdout
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
from shark.shark_inference import SharkInference
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, GPTQConfig
|
||||
from transformers.generation import (
|
||||
GenerationConfig,
|
||||
LogitsProcessorList,
|
||||
StoppingCriteriaList,
|
||||
)
|
||||
import copy
|
||||
|
||||
import time
|
||||
import re
|
||||
import torch
|
||||
import torch_mlir
|
||||
import os
|
||||
import argparse
|
||||
import gc
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="falcon runner",
|
||||
@@ -32,7 +46,13 @@ parser.add_argument(
|
||||
"--falcon_variant_to_use", default="7b", help="7b, 40b, 180b"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision", "-p", default="fp16", help="fp32, fp16, int8, int4"
|
||||
"--compressed",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Do the compression of sharded layers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"]
|
||||
)
|
||||
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
|
||||
parser.add_argument(
|
||||
@@ -67,9 +87,16 @@ parser.add_argument(
|
||||
default=None,
|
||||
help="Specify your own huggingface authentication token for falcon-180B model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--sharded",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Run model as sharded",
|
||||
)
|
||||
|
||||
|
||||
class Falcon(SharkLLMBase):
|
||||
class ShardedFalcon(SharkLLMBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
@@ -85,6 +112,532 @@ class Falcon(SharkLLMBase):
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
print("hf_model_path: ", self.hf_model_path)
|
||||
|
||||
if (
|
||||
"180b" in self.model_name
|
||||
and precision != "int4"
|
||||
and hf_auth_token == None
|
||||
):
|
||||
raise ValueError(
|
||||
""" HF auth token required for falcon-180b. Pass it using
|
||||
--hf_auth_token flag. You can ask for the access to the model
|
||||
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
|
||||
)
|
||||
self.hf_auth_token = hf_auth_token
|
||||
self.max_padding_length = 100
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.falcon_vmfb_path = falcon_vmfb_path
|
||||
self.falcon_mlir_path = falcon_mlir_path
|
||||
self.debug = debug
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.src_model = self.get_src_model()
|
||||
self.shark_model = self.compile(compressed=args.compressed)
|
||||
|
||||
def get_tokenizer(self):
    """Load and configure the tokenizer for ``self.hf_model_path``.

    The tokenizer is fetched with ``trust_remote_code=True`` and the stored
    HF auth token, then set to pad on the left using token id 11 (the same
    id ``generate`` passes as ``eos_token_id``).

    Returns:
        The configured tokenizer object.
    """
    load_options = {
        "trust_remote_code": True,
        "token": self.hf_auth_token,
    }
    hf_tokenizer = AutoTokenizer.from_pretrained(
        self.hf_model_path, **load_options
    )
    # Left padding keeps the most recent tokens at the end of the
    # fixed-length window.
    hf_tokenizer.padding_side = "left"
    hf_tokenizer.pad_token_id = 11
    return hf_tokenizer
|
||||
|
||||
def get_src_model(self):
    """Load the source HuggingFace causal-LM for ``self.hf_model_path``.

    For ``int4`` precision the checkpoint is loaded through a 4-bit
    ``GPTQConfig`` (exllama disabled) on the CPU and the resulting model is
    converted to float32 before being returned; for every other precision
    the model is returned exactly as loaded.

    Returns:
        The loaded model.
    """
    print("Loading src model: ", self.model_name)

    use_gptq = self.precision == "int4"
    load_options = {
        "torch_dtype": torch.float,
        "trust_remote_code": True,
        "token": self.hf_auth_token,
    }
    if use_gptq:
        load_options.update(
            quantization_config=GPTQConfig(bits=4, disable_exllama=True),
            load_gptq_on_cpu=True,
            device_map="cpu",
        )

    src_model = AutoModelForCausalLM.from_pretrained(
        self.hf_model_path, **load_options
    )
    if not use_gptq:
        return src_model
    return src_model.to(torch.float32)
|
||||
|
||||
def compile_layer(
    self, layer, falconCompileInput, layer_id, device_idx=None
):
    """Compile a single Falcon shard and return it ready to run.

    The MLIR and vmfb file names are derived from the falcon variant, the
    layer id, the precision and (for the vmfb) the device; both paths are
    stored on ``self``.  Resolution order:

    1. With ``--use_precompiled_model``: use the local vmfb, downloading it
       from shark_tank first when it is missing.
    2. Otherwise use an MLIR file already on disk, or one downloaded from
       shark_tank when ``--load_mlir_from_shark_tank`` is set.
    3. Otherwise lower ``layer`` locally (fx import -> torch-mlir ->
       bytecode) and save the MLIR file.

    Args:
        layer: torch module wrapping the shard to compile.
        falconCompileInput: list of sample inputs used for tracing.
        layer_id: ``"word_embeddings"``, ``"ln_f"``, ``"lm_head"``, an int
            decoder index, or a ``"<start>_<end>"`` string for a group.
        device_idx: device index forwarded to the runtime (may be None).

    Returns:
        Tuple ``(module, device_idx)``.

    Raises:
        ValueError: if ``layer_id`` is not one of the supported forms.
    """
    variant = args.falcon_variant_to_use
    self.falcon_mlir_path = Path(
        f"falcon_{variant}_layer_{layer_id}_{self.precision}.mlir"
    )
    self.falcon_vmfb_path = Path(
        f"falcon_{variant}_layer_{layer_id}_{self.precision}_{self.device}.vmfb"
    )

    if args.use_precompiled_model:
        if not self.falcon_vmfb_path.exists():
            # Downloading VMFB from shark_tank
            print("[DEBUG] Trying to download vmfb from shark_tank")
            download_public_file(
                f"gs://shark_tank/falcon/sharded/falcon_{variant}/vmfb/"
                + str(self.falcon_vmfb_path),
                self.falcon_vmfb_path.absolute(),
                single_file=True,
            )
        vmfb = get_vmfb_from_path(
            self.falcon_vmfb_path,
            self.device,
            "linalg",
            device_id=device_idx,
        )
        if vmfb is not None:
            return vmfb, device_idx

    print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")

    # SharkInference below is handed the MLIR *path*, so an existing file
    # only needs to be located, not read into memory.
    mlir_available = self.falcon_mlir_path.exists()
    if mlir_available:
        print(f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}")
    else:
        print(
            f"[DEBUG] mlir not found at {self.falcon_mlir_path.absolute()}"
        )
        if args.load_mlir_from_shark_tank:
            # Downloading MLIR from shark_tank
            print("[DEBUG] Trying to download mlir from shark_tank")
            download_public_file(
                f"gs://shark_tank/falcon/sharded/falcon_{variant}/mlir/"
                + str(self.falcon_mlir_path),
                self.falcon_mlir_path.absolute(),
                single_file=True,
            )
            mlir_available = self.falcon_mlir_path.exists()
            if mlir_available:
                print(
                    f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}"
                )

    if not mlir_available:
        print("[DEBUG] generating MLIR locally")
        if layer_id == "word_embeddings":
            f16_input_mask = [False]
        elif layer_id in ["ln_f", "lm_head"]:
            f16_input_mask = [True]
        # The int check must come first: `"_" in <int>` raises TypeError.
        elif isinstance(layer_id, int) or "_" in layer_id:
            f16_input_mask = [True, False]
        else:
            raise ValueError(f"Unsupported layer: {layer_id}")

        print("[DEBUG] generating torchscript graph")
        ts_graph = import_with_fx(
            layer,
            falconCompileInput,
            is_f16=True,
            f16_input_mask=f16_input_mask,
            mlir_type="torchscript",
            is_gptq=True,
        )
        # Drop intermediates as soon as possible to limit peak memory.
        del layer

        print("[DEBUG] generating torch mlir")
        module = torch_mlir.compile(
            ts_graph,
            falconCompileInput,
            torch_mlir.OutputType.LINALG_ON_TENSORS,
            use_tracing=False,
            verbose=False,
        )
        del ts_graph

        print("[DEBUG] converting to bytecode")
        bytecode_stream = BytesIO()
        module.operation.write_bytecode(bytecode_stream)
        bytecode = bytecode_stream.getvalue()
        del module

        # Context manager so the handle is closed even if the write fails.
        with open(self.falcon_mlir_path, "wb") as mlir_file:
            mlir_file.write(bytecode)
            print("Saved falcon mlir at ", str(self.falcon_mlir_path))
        del bytecode

    shark_module = SharkInference(
        mlir_module=self.falcon_mlir_path,
        device=self.device,
        mlir_dialect="linalg",
        device_idx=device_idx,
    )

    # The conditional expression binds looser than `+`, so the int4-only
    # flag is parenthesised; otherwise every base flag would be dropped
    # for non-int4 precisions.
    extra_args = [
        "--iree-vm-target-truncate-unsupported-floats",
        "--iree-codegen-check-ir-before-llvm-conversion=false",
        "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
    ] + (
        ["--iree-llvmcpu-use-fast-min-max-ops"]
        if self.precision == "int4"
        else []
    )
    path = shark_module.save_module(
        self.falcon_vmfb_path.parent.absolute(),
        self.falcon_vmfb_path.stem,
        extra_args=extra_args,
        debug=self.debug,
    )
    print("Saved falcon vmfb at ", str(path))
    shark_module.load_module(path)

    return shark_module, device_idx
|
||||
|
||||
def compile(self, compressed=False):
    """Compile every shard of the model and assemble a ShardedFalconModel.

    The lm_head, word-embedding and final layer-norm wrappers are compiled
    first (in that order), followed by the decoder layers.  With
    ``compressed`` set, decoder layers are compiled in groups whose size
    depends on the model variant (8 / 15 / 20 for 7b / 40b / other).

    Args:
        compressed: compile decoder layers in groups instead of one by one.

    Returns:
        The assembled ``ShardedFalconModel``.
    """
    sample_input_ids = torch.zeros([100], dtype=torch.int64)
    sample_attention_mask = torch.zeros(
        [1, 1, 100, 100], dtype=torch.float32
    )

    # (hidden size, decoder layers per compressed group) for each variant
    if "7b" in self.model_name:
        num_in_features, group_size = 4544, 8
    elif "40b" in self.model_name:
        num_in_features, group_size = 8192, 15
    else:
        num_in_features, group_size = 14848, 20
        sample_attention_mask = sample_attention_mask.to(dtype=torch.bool)
    num_group_layers = group_size if compressed else 1

    sample_hidden_states = torch.zeros(
        [1, 100, num_in_features], dtype=torch.float32
    )

    # Determine number of available devices
    on_rocm = self.device == "rocm"
    num_devices = 1
    if on_rocm:
        import iree.runtime as ireert

        haldriver = ireert.get_driver(self.device)
        num_devices = len(haldriver.query_available_devices())

    def pick_device(slot):
        # Round-robin over the devices on rocm; no explicit index elsewhere.
        return slot % num_devices if on_rocm else None

    lm_head = LMHeadEmbeddingLayer(self.src_model.lm_head)
    print("Compiling Layer lm_head")
    compiled_lm_head, _ = self.compile_layer(
        lm_head,
        [sample_hidden_states],
        "lm_head",
        device_idx=pick_device(0),
    )
    shark_lm_head = CompiledLMHeadEmbeddingLayer(compiled_lm_head)

    word_embedding = WordEmbeddingsLayer(
        self.src_model.transformer.word_embeddings
    )
    print("Compiling Layer word_embeddings")
    compiled_word_embedding, _ = self.compile_layer(
        word_embedding,
        [sample_input_ids],
        "word_embeddings",
        device_idx=pick_device(1),
    )
    shark_word_embedding = CompiledWordEmbeddingsLayer(
        compiled_word_embedding
    )

    ln_f = LNFEmbeddingLayer(self.src_model.transformer.ln_f)
    print("Compiling Layer ln_f")
    compiled_ln_f, _ = self.compile_layer(
        ln_f,
        [sample_hidden_states],
        "ln_f",
        device_idx=pick_device(2),
    )
    shark_ln_f = CompiledLNFEmbeddingLayer(compiled_ln_f)

    shark_layers = []
    num_shards = int(len(self.src_model.transformer.h) / num_group_layers)
    for i in range(num_shards):
        device_idx = pick_device(i)
        first = i * num_group_layers
        last = (i + 1) * num_group_layers

        if compressed:
            layer_id = str(first) + "_" + str(last)
            pytorch_class = EightDecoderLayer
            compiled_class = CompiledEightDecoderLayer
        else:
            layer_id = i
            pytorch_class = DecoderLayer
            compiled_class = CompiledDecoderLayer

        print("Compiling Layer {}".format(layer_id))
        decoder_stack = self.src_model.transformer.h
        layer_i = decoder_stack[first:last] if compressed else decoder_stack[i]

        pytorch_layer_i = pytorch_class(layer_i, args.falcon_variant_to_use)
        shark_module, device_idx = self.compile_layer(
            pytorch_layer_i,
            [sample_hidden_states, sample_attention_mask],
            layer_id,
            device_idx=device_idx,
        )
        # The compiled wrapper is built from the layer id / device / variant
        # / precision rather than from this module object, so release it.
        del shark_module
        shark_layers.append(
            compiled_class(
                layer_id,
                device_idx,
                args.falcon_variant_to_use,
                self.device,
                self.precision,
            )
        )

    return ShardedFalconModel(
        self.src_model,
        shark_layers,
        shark_word_embedding,
        shark_ln_f,
        shark_lm_head,
    )
|
||||
|
||||
def generate(self, prompt):
    """Generate a completion for ``prompt`` with the sharded model.

    The prompt is tokenized and padded to ``self.max_padding_length``, the
    generation state (logits processor/warper, stopping criteria, eos/pad
    ids, running ``input_ids`` and ``model_kwargs``) is stored on ``self``,
    and then up to ``self.max_num_tokens - 1`` tokens are sampled one at a
    time via ``generate_new_token``.  Each decoded token is printed as it
    is produced, followed by the average seconds per token.

    Args:
        prompt: prompt text.

    Returns:
        The prompt followed by all generated text.
    """
    model_inputs = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=self.max_padding_length,
        add_special_tokens=False,
        return_tensors="pt",
    )
    model_inputs["prompt_text"] = prompt

    input_ids = model_inputs["input_ids"]
    attention_mask = model_inputs.get("attention_mask", None)

    # Allow empty prompts
    if input_ids.shape[1] == 0:
        input_ids = None
        attention_mask = None

    generate_kwargs = {
        "max_length": self.max_num_tokens,
        "do_sample": True,
        "top_k": 10,
        "num_return_sequences": 1,
        "eos_token_id": 11,
    }
    generate_kwargs["input_ids"] = input_ids
    generate_kwargs["attention_mask"] = attention_mask
    generation_config_ = GenerationConfig.from_model_config(
        self.src_model.config
    )
    generation_config = copy.deepcopy(generation_config_)
    # update() hands back the kwargs it did not consume; they become the
    # model kwargs for the rest of the setup.
    model_kwargs = generation_config.update(**generate_kwargs)

    logits_processor = LogitsProcessorList()
    stopping_criteria = StoppingCriteriaList()

    eos_token_id = generation_config.eos_token_id
    generation_config.pad_token_id = eos_token_id

    (
        inputs_tensor,
        model_input_name,
        model_kwargs,
    ) = self.src_model._prepare_model_inputs(
        None, generation_config.bos_token_id, model_kwargs
    )

    model_kwargs["output_attentions"] = generation_config.output_attentions
    model_kwargs[
        "output_hidden_states"
    ] = generation_config.output_hidden_states
    model_kwargs["use_cache"] = generation_config.use_cache

    input_ids = (
        inputs_tensor
        if model_input_name == "input_ids"
        else model_kwargs.pop("input_ids")
    )

    self.logits_processor = self.src_model._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids.shape[-1],
        encoder_input_ids=inputs_tensor,
        prefix_allowed_tokens_fn=None,
        logits_processor=logits_processor,
    )

    self.stopping_criteria = self.src_model._get_stopping_criteria(
        generation_config=generation_config,
        stopping_criteria=stopping_criteria,
    )

    self.logits_warper = self.src_model._get_logits_warper(
        generation_config
    )

    (
        self.input_ids,
        self.model_kwargs,
    ) = self.src_model._expand_inputs_for_generation(
        input_ids=input_ids,
        expand_size=generation_config.num_return_sequences,  # 1
        is_encoder_decoder=self.src_model.config.is_encoder_decoder,  # False
        **model_kwargs,
    )

    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    self.eos_token_id_tensor = (
        torch.tensor(eos_token_id) if eos_token_id is not None else None
    )

    self.pad_token_id = generation_config.pad_token_id
    self.eos_token_id = eos_token_id

    output_scores = generation_config.output_scores  # False
    return_dict_in_generate = (
        generation_config.return_dict_in_generate  # False
    )

    # init attention / hidden states / scores tuples
    self.scores = (
        () if (return_dict_in_generate and output_scores) else None
    )

    # keep track of which sequences are already finished
    self.unfinished_sequences = torch.ones(
        input_ids.shape[0], dtype=torch.long, device=input_ids.device
    )

    all_text = prompt

    start = time.time()
    count = 0
    for _ in range(self.max_num_tokens - 1):
        count = count + 1

        next_token = self.generate_new_token()
        new_word = self.tokenizer.decode(
            next_token.cpu().numpy(),
            add_special_tokens=False,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        all_text = all_text + new_word

        print(f"{new_word}", end="", flush=True)
        # NOTE(review): this re-prints the whole accumulated text after
        # every token; looks like leftover debug output - confirm.
        print(f"{all_text}", end="", flush=True)

        # if eos_token was found in one sentence, set sentence to finished
        if self.eos_token_id_tensor is not None:
            self.unfinished_sequences = self.unfinished_sequences.mul(
                next_token.tile(self.eos_token_id_tensor.shape[0], 1)
                .ne(self.eos_token_id_tensor.unsqueeze(1))
                .prod(dim=0)
            )
            # stop when each sentence is finished
            # NOTE(review): the criteria see the local `input_ids` built
            # before the loop, not `self.input_ids` - confirm intent.
            if (
                self.unfinished_sequences.max() == 0
                or self.stopping_criteria(input_ids, self.scores)
            ):
                break

    end = time.time()
    # count stays 0 when max_num_tokens <= 1 (the loop never runs); report
    # 0 instead of dividing by zero.
    seconds_per_token = (end - start) / count if count else 0.0
    print(
        "\n\nTime taken is {:.2f} seconds/token\n".format(seconds_per_token)
    )

    torch.cuda.empty_cache()
    gc.collect()

    return all_text
|
||||
|
||||
def generate_new_token(self):
    """Sample the next token and advance the generation state on ``self``.

    Runs the sharded model on the current ``self.input_ids`` / attention
    mask, applies the stored logits processor and warper, samples one token
    per sequence, then appends it to ``self.input_ids`` (and a 1 to the
    attention mask) while dropping the oldest position so the window
    length stays constant.

    Returns:
        Tensor holding the sampled token id for each sequence.

    Raises:
        ValueError: if an eos token id is set but no pad token id is.
    """
    prepared = self.src_model.prepare_inputs_for_generation(
        self.input_ids, **self.model_kwargs
    )
    logits = self.shark_model.forward(
        input_ids=prepared["input_ids"],
        attention_mask=prepared["attention_mask"],
    )
    if self.precision in ("fp16", "int4"):
        logits = logits.to(dtype=torch.float32)

    # pre-process distribution
    scores = self.logits_processor(self.input_ids, logits)
    scores = self.logits_warper(self.input_ids, scores)

    # sample
    probabilities = scores.softmax(dim=-1)
    next_token = torch.multinomial(probabilities, num_samples=1).squeeze(1)

    # finished sentences should have their next token be a padding token
    if self.eos_token_id is not None:
        if self.pad_token_id is None:
            raise ValueError(
                "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
            )
        still_running = self.unfinished_sequences
        next_token = next_token * still_running + self.pad_token_id * (
            1 - still_running
        )

    self.input_ids = torch.cat(
        [self.input_ids, next_token[:, None]], dim=-1
    )

    self.model_kwargs["past_key_values"] = None
    if "attention_mask" in self.model_kwargs:
        mask = self.model_kwargs["attention_mask"]
        extra_column = mask.new_ones((mask.shape[0], 1))
        self.model_kwargs["attention_mask"] = torch.cat(
            [mask, extra_column], dim=-1
        )

    # Slide the window: drop the oldest position from ids and mask.
    self.input_ids = self.input_ids[:, 1:]
    trimmed_mask = self.model_kwargs["attention_mask"][:, 1:]
    self.model_kwargs["attention_mask"] = trimmed_mask

    return next_token
|
||||
|
||||
|
||||
class UnshardedFalcon(SharkLLMBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path="tiiuae/falcon-7b-instruct",
|
||||
hf_auth_token: str = "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",
|
||||
max_num_tokens=150,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
falcon_mlir_path=None,
|
||||
falcon_vmfb_path=None,
|
||||
debug=False,
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
print("hf_model_path: ", self.hf_model_path)
|
||||
|
||||
if "180b" in self.model_name and hf_auth_token == None:
|
||||
raise ValueError(
|
||||
""" HF auth token required for falcon-180b. Pass it using
|
||||
@@ -119,9 +672,16 @@ class Falcon(SharkLLMBase):
|
||||
"trust_remote_code": True,
|
||||
"token": self.hf_auth_token,
|
||||
}
|
||||
if self.precision == "int4":
|
||||
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
|
||||
kwargs["quantization_config"] = quantization_config
|
||||
kwargs["load_gptq_on_cpu"] = True
|
||||
kwargs["device_map"] = "cpu"
|
||||
falcon_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
)
|
||||
if self.precision == "int4":
|
||||
falcon_model = falcon_model.to(torch.float32)
|
||||
return falcon_model
|
||||
|
||||
def compile(self):
|
||||
@@ -173,8 +733,6 @@ class Falcon(SharkLLMBase):
|
||||
print(
|
||||
f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}"
|
||||
)
|
||||
with open(self.falcon_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
mlir_generated = True
|
||||
|
||||
if not mlir_generated:
|
||||
@@ -195,9 +753,10 @@ class Falcon(SharkLLMBase):
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
falconCompileInput,
|
||||
is_f16=self.precision == "fp16",
|
||||
is_f16=self.precision in ["fp16", "int4"],
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
is_gptq=self.precision == "int4",
|
||||
)
|
||||
del model
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
@@ -221,9 +780,12 @@ class Falcon(SharkLLMBase):
|
||||
f_.write(bytecode)
|
||||
print("Saved falcon mlir at ", str(self.falcon_mlir_path))
|
||||
f_.close()
|
||||
del bytecode
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=self.device, mlir_dialect="linalg"
|
||||
mlir_module=self.falcon_mlir_path,
|
||||
device=self.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
self.falcon_vmfb_path.parent.absolute(),
|
||||
@@ -232,7 +794,12 @@ class Falcon(SharkLLMBase):
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
]
|
||||
+ [
|
||||
"--iree-llvmcpu-use-fast-min-max-ops",
|
||||
]
|
||||
if self.precision == "int4"
|
||||
else [],
|
||||
debug=self.debug,
|
||||
)
|
||||
print("Saved falcon vmfb at ", str(path))
|
||||
@@ -367,7 +934,11 @@ class Falcon(SharkLLMBase):
|
||||
|
||||
all_text = prompt
|
||||
|
||||
start = time.time()
|
||||
count = 0
|
||||
for i in range(self.max_num_tokens - 1):
|
||||
count = count + 1
|
||||
|
||||
next_token = self.generate_new_token()
|
||||
new_word = self.tokenizer.decode(
|
||||
next_token.cpu().numpy(),
|
||||
@@ -394,6 +965,13 @@ class Falcon(SharkLLMBase):
|
||||
):
|
||||
break
|
||||
|
||||
end = time.time()
|
||||
print(
|
||||
"\n\nTime taken is {:.2f} seconds/token\n".format(
|
||||
(end - start) / count
|
||||
)
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
@@ -409,7 +987,7 @@ class Falcon(SharkLLMBase):
|
||||
(model_inputs["input_ids"], model_inputs["attention_mask"]),
|
||||
)
|
||||
)
|
||||
if self.precision == "fp16":
|
||||
if self.precision in ["fp16", "int4"]:
|
||||
outputs = outputs.to(dtype=torch.float32)
|
||||
next_token_logits = outputs
|
||||
|
||||
@@ -488,23 +1066,39 @@ if __name__ == "__main__":
|
||||
else Path(args.falcon_vmfb_path)
|
||||
)
|
||||
|
||||
if args.falcon_variant_to_use == "180b":
|
||||
hf_model_path_value = "tiiuae/falcon-180B-chat"
|
||||
if args.precision == "int4":
|
||||
if args.falcon_variant_to_use == "180b":
|
||||
hf_model_path_value = "TheBloke/Falcon-180B-Chat-GPTQ"
|
||||
else:
|
||||
hf_model_path_value = (
|
||||
"TheBloke/falcon-"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "-instruct-GPTQ"
|
||||
)
|
||||
else:
|
||||
hf_model_path_value = (
|
||||
"tiiuae/falcon-" + args.falcon_variant_to_use + "-instruct"
|
||||
if args.falcon_variant_to_use == "180b":
|
||||
hf_model_path_value = "tiiuae/falcon-180B-chat"
|
||||
else:
|
||||
hf_model_path_value = (
|
||||
"tiiuae/falcon-" + args.falcon_variant_to_use + "-instruct"
|
||||
)
|
||||
|
||||
if not args.sharded:
|
||||
falcon = UnshardedFalcon(
|
||||
model_name="falcon_" + args.falcon_variant_to_use,
|
||||
hf_model_path=hf_model_path_value,
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
falcon_mlir_path=falcon_mlir_path,
|
||||
falcon_vmfb_path=falcon_vmfb_path,
|
||||
)
|
||||
else:
|
||||
falcon = ShardedFalcon(
|
||||
model_name="falcon_" + args.falcon_variant_to_use,
|
||||
hf_model_path=hf_model_path_value,
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
)
|
||||
|
||||
falcon = Falcon(
|
||||
model_name="falcon_" + args.falcon_variant_to_use,
|
||||
hf_model_path=hf_model_path_value,
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
falcon_mlir_path=falcon_mlir_path,
|
||||
falcon_vmfb_path=falcon_vmfb_path,
|
||||
)
|
||||
|
||||
import gc
|
||||
|
||||
default_prompt_text = "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:"
|
||||
continue_execution = True
|
||||
@@ -524,7 +1118,11 @@ if __name__ == "__main__":
|
||||
prompt = input("Please enter the prompt text: ")
|
||||
print("\nPrompt Text: ", prompt)
|
||||
|
||||
res_str = falcon.generate(prompt)
|
||||
prompt_template = f"""A helpful assistant who helps the user with any questions asked.
|
||||
User: {prompt}
|
||||
Assistant:"""
|
||||
|
||||
res_str = falcon.generate(prompt_template)
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
print(
|
||||
|
||||
@@ -126,7 +126,7 @@ def is_url(input_url):
|
||||
import os
|
||||
import tempfile
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
import torch
|
||||
import torch_mlir
|
||||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||
@@ -235,6 +235,12 @@ def compile_int_precision(
|
||||
mlir_module = BytesIO(mlir_module)
|
||||
bytecode = mlir_module.read()
|
||||
print(f"Elided IR written for {extended_model_name}")
|
||||
bytecode = save_mlir(
|
||||
bytecode,
|
||||
model_name=extended_model_name,
|
||||
frontend="torch",
|
||||
dir=os.getcwd(),
|
||||
)
|
||||
return bytecode
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
|
||||
|
||||
@@ -53,6 +53,7 @@ datas += collect_data_files("jsonschema_specifications")
|
||||
datas += collect_data_files("cpuinfo")
|
||||
datas += collect_data_files("langchain")
|
||||
datas += collect_data_files("cv2")
|
||||
datas += collect_data_files("einops")
|
||||
datas += [
|
||||
("src/utils/resources/prompts.json", "resources"),
|
||||
("src/utils/resources/model_db.json", "resources"),
|
||||
@@ -74,6 +75,9 @@ datas += [
|
||||
# hidden imports for pyinstaller
|
||||
hiddenimports = ["shark", "shark.shark_inference", "apps"]
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
hiddenimports += [
|
||||
x for x in collect_submodules("diffusers") if "tests" not in x
|
||||
]
|
||||
blacklist = ["tests", "convert"]
|
||||
hiddenimports += [
|
||||
x
|
||||
|
||||
@@ -8,6 +8,7 @@ import traceback
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import requests
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
compile_through_fx,
|
||||
get_opt_flags,
|
||||
@@ -16,6 +17,7 @@ from apps.stable_diffusion.src.utils import (
|
||||
preprocessCKPT,
|
||||
convert_original_vae,
|
||||
get_path_to_diffusers_checkpoint,
|
||||
get_civitai_checkpoint,
|
||||
fetch_and_update_base_model_id,
|
||||
get_path_stem,
|
||||
get_extended_name,
|
||||
@@ -94,21 +96,19 @@ class SharkifyStableDiffusionModel:
|
||||
self.height = height // 8
|
||||
self.width = width // 8
|
||||
self.batch_size = batch_size
|
||||
self.custom_weights = custom_weights
|
||||
self.custom_weights = custom_weights.strip()
|
||||
self.use_quantize = use_quantize
|
||||
if custom_weights != "":
|
||||
if "civitai" in custom_weights:
|
||||
weights_id = custom_weights.split("/")[-1]
|
||||
# TODO: use model name and identify file type by civitai rest api
|
||||
weights_path = (
|
||||
str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
|
||||
)
|
||||
if not os.path.isfile(weights_path):
|
||||
subprocess.run(
|
||||
["wget", custom_weights, "-O", weights_path]
|
||||
)
|
||||
if custom_weights.startswith("https://civitai.com/api/"):
|
||||
# download the checkpoint from civitai if we don't already have it
|
||||
weights_path = get_civitai_checkpoint(custom_weights)
|
||||
|
||||
# act as if we were given the local file as custom_weights originally
|
||||
custom_weights = get_path_to_diffusers_checkpoint(weights_path)
|
||||
self.custom_weights = weights_path
|
||||
|
||||
# needed to ensure webui sets the correct model name metadata
|
||||
args.ckpt_loc = weights_path
|
||||
else:
|
||||
assert custom_weights.lower().endswith(
|
||||
(".ckpt", ".safetensors")
|
||||
@@ -116,6 +116,7 @@ class SharkifyStableDiffusionModel:
|
||||
custom_weights = get_path_to_diffusers_checkpoint(
|
||||
custom_weights
|
||||
)
|
||||
|
||||
self.model_id = model_id if custom_weights == "" else custom_weights
|
||||
# TODO: remove the following line when stable-diffusion-2-1 works
|
||||
if self.model_id == "stabilityai/stable-diffusion-2-1":
|
||||
@@ -710,8 +711,11 @@ class SharkifyStableDiffusionModel:
|
||||
return self.text_encoder(input)[0]
|
||||
|
||||
clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["clip"])
|
||||
save_dir = ""
|
||||
if self.debug:
|
||||
save_dir = os.path.join(
|
||||
self.sharktank_dir, self.model_name["clip"]
|
||||
)
|
||||
os.makedirs(
|
||||
save_dir,
|
||||
exist_ok=True,
|
||||
|
||||
@@ -84,9 +84,6 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
# Disable bindings fusion to work with moltenVK.
|
||||
if sys.platform == "darwin":
|
||||
iree_flags.append("-iree-stream-fuse-binding=false")
|
||||
|
||||
def _import(self):
|
||||
scaling_model = ScalingModel()
|
||||
|
||||
@@ -41,3 +41,4 @@ from apps.stable_diffusion.src.utils.utils import (
|
||||
resize_stencil,
|
||||
_compile_module,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.civitai import get_civitai_checkpoint
|
||||
|
||||
42
apps/stable_diffusion/src/utils/civitai.py
Normal file
42
apps/stable_diffusion/src/utils/civitai.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import re
|
||||
import requests
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_civitai_checkpoint(url: str):
    """Download a checkpoint from the civitai API unless already present.

    The file name is taken from the quoted value in the response's
    ``Content-Disposition`` header and the file is stored under
    ``<cwd>/<args.ckpt_dir or "models">/``.

    Args:
        url: civitai download URL.

    Returns:
        POSIX-style path of the local checkpoint file, as a string.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        ValueError: if the response does not carry a usable file name.
    """
    with requests.get(url, allow_redirects=True, stream=True) as response:
        response.raise_for_status()

        # civitai api returns the filename in the content disposition
        quoted_names = re.findall(
            '"([^"]*)"', response.headers.get("Content-Disposition", "")
        )
        # The header is server-controlled: keep only the final path
        # component so it cannot point outside the checkpoint directory.
        base_filename = Path(quoted_names[0]).name if quoted_names else ""
        if base_filename in ("", ".", ".."):
            raise ValueError(
                f"could not determine a checkpoint filename for {url}"
            )

        destination_path = (
            Path.cwd() / (args.ckpt_dir or "models") / base_filename
        )

        # we already have this model downloaded
        if destination_path.is_file():
            print(f"civitai model already downloaded to {destination_path}")
            return destination_path.as_posix()

        # we don't have this model downloaded yet
        print(f"downloading civitai model from {url} to {destination_path}")
        destination_path.parent.mkdir(parents=True, exist_ok=True)

        # The length header may be absent (e.g. chunked responses); tqdm
        # then shows a plain counter instead of a bar.
        size = int(response.headers.get("content-length", 0))

        # Write to a temporary name and rename at the end, so an
        # interrupted download is never mistaken for a complete checkpoint.
        partial_path = destination_path.with_name(
            destination_path.name + ".part"
        )
        with open(partial_path, "wb") as f:
            with tqdm(
                total=size or None, unit="iB", unit_scale=True
            ) as progress_bar:
                for chunk in response.iter_content(chunk_size=65536):
                    f.write(chunk)
                    progress_bar.update(len(chunk))
        partial_path.replace(destination_path)

    return destination_path.as_posix()
|
||||
@@ -11,12 +11,12 @@
|
||||
"untuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -28,7 +28,7 @@
|
||||
"specified_compilation_flags": {
|
||||
"cuda": [],
|
||||
"default_device": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
}
|
||||
},
|
||||
@@ -37,7 +37,7 @@
|
||||
"specified_compilation_flags": {
|
||||
"cuda": [],
|
||||
"default_device": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -45,12 +45,12 @@
|
||||
"untuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -203,8 +203,8 @@ def dump_after_mlir(input_mlir, use_winograd):
|
||||
if use_winograd:
|
||||
preprocess_flag = (
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module"
|
||||
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
|
||||
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
|
||||
"(func.func(iree-global-opt-detach-elementwise-from-named-ops,"
|
||||
"iree-global-opt-convert-1x1-filter-conv2d-to-matmul,"
|
||||
"iree-preprocessing-convert-conv2d-to-img2col,"
|
||||
"iree-preprocessing-pad-linalg-ops{pad-size=32},"
|
||||
"iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
@@ -212,8 +212,8 @@ def dump_after_mlir(input_mlir, use_winograd):
|
||||
else:
|
||||
preprocess_flag = (
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module"
|
||||
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
|
||||
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
|
||||
"(func.func(iree-global-opt-detach-elementwise-from-named-ops,"
|
||||
"iree-global-opt-convert-1x1-filter-conv2d-to-matmul,"
|
||||
"iree-preprocessing-convert-conv2d-to-img2col,"
|
||||
"iree-preprocessing-pad-linalg-ops{pad-size=32}))"
|
||||
)
|
||||
|
||||
@@ -253,28 +253,30 @@ p.add_argument(
|
||||
"--left",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend left for outpainting.",
|
||||
help="If extend left for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--right",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend right for outpainting.",
|
||||
help="If extend right for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--up",
|
||||
"--top",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend top for outpainting.",
|
||||
help="If extend top for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--down",
|
||||
"--bottom",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend bottom for outpainting.",
|
||||
help="If extend bottom for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -306,7 +308,7 @@ p.add_argument(
|
||||
|
||||
p.add_argument(
|
||||
"--import_mlir",
|
||||
default=False,
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Imports the model from torch module to shark_module otherwise "
|
||||
"downloads the model from shark_tank.",
|
||||
@@ -329,7 +331,7 @@ p.add_argument(
|
||||
|
||||
p.add_argument(
|
||||
"--use_tuned",
|
||||
default=True,
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Download and use the tuned version of the model if available.",
|
||||
)
|
||||
@@ -422,7 +424,7 @@ p.add_argument(
|
||||
|
||||
p.add_argument(
|
||||
"--use_stencil",
|
||||
choices=["canny", "openpose", "scribble"],
|
||||
choices=["canny", "openpose", "scribble", "zoedepth"],
|
||||
help="Enable the stencil feature.",
|
||||
)
|
||||
|
||||
@@ -458,6 +460,14 @@ p.add_argument(
|
||||
help="Specify your own huggingface authentication tokens for models like Llama2.",
|
||||
)
|
||||
|
||||
# Heap key for the device caching allocator; the empty default means that
# no key was supplied on the command line.
p.add_argument(
    "--device_allocator_heap_key",
    type=str,
    default="",
    # Adjacent literals are concatenated verbatim, so each piece must carry
    # its own trailing separator or the sentences run together in --help.
    help="Specify heap key for device caching allocator. "
    "Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count. "
    "Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
)
|
||||
##############################################################################
|
||||
# IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
@@ -633,6 +643,18 @@ p.add_argument(
|
||||
help="Flag for enabling rest API.",
|
||||
)
|
||||
|
||||
# Repeatable option (action="append"): each use adds one accepted origin.
# When the option is never given the parsed value is None.
p.add_argument(
    "--api_accept_origin",
    action="append",
    type=str,
    # The first literal needs a trailing space, otherwise --help renders
    # "Cross OriginResource Sharing".
    help="An origin to be accepted by the REST api for Cross Origin "
    "Resource Sharing (CORS). Use multiple times for multiple origins, "
    'or use --api_accept_origin="*" to accept all origins. If no origins '
    "are set no CORS headers will be returned by the api. Use, for "
    "instance, if you need to access the REST api from Javascript running "
    "in a web browser.",
)
|
||||
|
||||
p.add_argument(
|
||||
"--debug",
|
||||
default=False,
|
||||
@@ -717,6 +739,17 @@ p.add_argument(
|
||||
help="Specifies whether the docuchat's web version is running or not.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# rocm Flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--iree_rocm_target_chip",
|
||||
type=str,
|
||||
default="gfx1100",
|
||||
help="Add the rocm device architecture ex gfx1100, gfx90a, etc. Default gfx1100",
|
||||
)
|
||||
|
||||
args, unknown = p.parse_known_args()
|
||||
if args.import_debug:
|
||||
os.environ["IREE_SAVE_TEMPS"] = os.path.join(
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
from apps.stable_diffusion.src.utils.stencils.canny import CannyDetector
|
||||
from apps.stable_diffusion.src.utils.stencils.openpose import OpenposeDetector
|
||||
from apps.stable_diffusion.src.utils.stencils.zoe import ZoeDetector
|
||||
|
||||
@@ -4,6 +4,7 @@ import torch
|
||||
from apps.stable_diffusion.src.utils.stencils import (
|
||||
CannyDetector,
|
||||
OpenposeDetector,
|
||||
ZoeDetector,
|
||||
)
|
||||
|
||||
stencil = {}
|
||||
@@ -117,6 +118,9 @@ def controlnet_hint_conversion(
|
||||
case "scribble":
|
||||
print("Working with scribble")
|
||||
controlnet_hint = hint_scribble(image)
|
||||
case "zoedepth":
|
||||
print("Working with ZoeDepth")
|
||||
controlnet_hint = hint_zoedepth(image)
|
||||
case _:
|
||||
return None
|
||||
controlnet_hint = controlnet_hint_shaping(
|
||||
@@ -127,7 +131,7 @@ def controlnet_hint_conversion(
|
||||
|
||||
stencil_to_model_id_map = {
|
||||
"canny": "lllyasviel/control_v11p_sd15_canny",
|
||||
"depth": "lllyasviel/control_v11p_sd15_depth",
|
||||
"zoedepth": "lllyasviel/control_v11f1p_sd15_depth",
|
||||
"hed": "lllyasviel/sd-controlnet-hed",
|
||||
"mlsd": "lllyasviel/control_v11p_sd15_mlsd",
|
||||
"normal": "lllyasviel/control_v11p_sd15_normalbae",
|
||||
@@ -184,3 +188,16 @@ def hint_scribble(image: Image.Image):
|
||||
detected_map = np.zeros_like(input_image, dtype=np.uint8)
|
||||
detected_map[np.min(input_image, axis=2) < 127] = 255
|
||||
return detected_map
|
||||
|
||||
|
||||
# Stencil 4. Depth (Only Zoe Preprocessing)
def hint_zoedepth(image: Image.Image):
    """Produce a depth hint for ``image`` with the Zoe depth detector.

    The detector is created the first time this function runs and is kept
    in the module-level ``stencil`` dict under the ``"depth"`` key, so
    later calls reuse the same instance.

    Args:
        image: Input picture; it is converted to a numpy array before being
            handed to the detector.

    Returns:
        The detector output after it has been passed through ``HWC3``.
    """
    with torch.no_grad():
        pixels = np.array(image)

        # Lazily build the detector once and reuse it afterwards.
        if "depth" not in stencil:
            stencil["depth"] = ZoeDetector()

        depth_map = stencil["depth"](pixels)
        return HWC3(depth_map)
|
||||
|
||||
58
apps/stable_diffusion/src/utils/stencils/zoe/__init__.py
Normal file
58
apps/stable_diffusion/src/utils/stencils/zoe/__init__.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from pathlib import Path
|
||||
import requests
|
||||
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
remote_model_path = (
|
||||
"https://huggingface.co/lllyasviel/Annotators/resolve/main/ZoeD_M12_N.pt"
|
||||
)
|
||||
|
||||
|
||||
class ZoeDetector:
    """Callable wrapper around the ZoeD_N depth model.

    Construction makes sure the checkpoint ``ZoeD_M12_N.pt`` exists in
    ``<cwd>/stencil_annotator`` (downloading it from ``remote_model_path``
    when missing), loads the ``ZoeD_N`` entry point through ``torch.hub``
    and restores the checkpoint's ``"model"`` state dict into it.
    """

    def __init__(self):
        ckpt_path = Path(Path.cwd(), "stencil_annotator")
        ckpt_path.mkdir(parents=True, exist_ok=True)
        modelpath = ckpt_path / "ZoeD_M12_N.pt"

        # Fetch the checkpoint only when it is not on disk yet; previously
        # it was downloaded again on every construction.
        if not modelpath.is_file():
            # Stream into a sibling temp file and rename at the end so an
            # interrupted download never leaves a truncated checkpoint
            # under the final name.
            partial_path = modelpath.with_name(modelpath.name + ".part")
            with requests.get(remote_model_path, stream=True) as r:
                r.raise_for_status()
                with open(partial_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            partial_path.replace(modelpath)

        model = torch.hub.load(
            "monorimet/ZoeDepth:torch_update",
            "ZoeD_N",
            pretrained=False,
            force_reload=False,
        )
        model.load_state_dict(
            torch.load(modelpath, map_location=model.device)["model"]
        )
        model.eval()
        self.model = model

    def __call__(self, input_image):
        """Estimate depth for one image and return it as a uint8 map.

        Args:
            input_image: numpy array with three dimensions, laid out as
                ``h w c`` (see the rearrange pattern below); values are
                divided by 255 before inference.

        Returns:
            A 2-D ``np.uint8`` array. Depth is normalised between its 2nd
            and 85th percentiles and then inverted, so smaller depth values
            map to larger output values.
        """
        assert input_image.ndim == 3
        image_depth = input_image
        with torch.no_grad():
            image_depth = torch.from_numpy(image_depth).float()
            image_depth = image_depth / 255.0
            image_depth = rearrange(image_depth, "h w c -> 1 c h w")
            depth = self.model.infer(image_depth)

            depth = depth[0, 0].cpu().numpy()

            vmin = np.percentile(depth, 2)
            vmax = np.percentile(depth, 85)

            depth -= vmin
            # Guard against a zero span (e.g. a flat depth map), which
            # would otherwise fill the map with NaN/inf before the cast.
            depth /= max(vmax - vmin, np.finfo(depth.dtype).eps)
            depth = 1.0 - depth
            depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8)

        return depth_image
|
||||
@@ -18,7 +18,7 @@ import tempfile
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
set_iree_vulkan_runtime_flags,
|
||||
get_vulkan_target_triple,
|
||||
@@ -154,8 +154,8 @@ def compile_through_fx(
|
||||
f16_input_mask=f16_input_mask,
|
||||
debug=debug,
|
||||
model_name=extended_model_name,
|
||||
save_dir=save_dir,
|
||||
)
|
||||
|
||||
if use_tuned:
|
||||
if "vae" in extended_model_name.split("_")[0]:
|
||||
args.annotation_model = "vae"
|
||||
@@ -168,6 +168,14 @@ def compile_through_fx(
|
||||
mlir_module, extended_model_name, base_model_id
|
||||
)
|
||||
|
||||
if not os.path.isdir(save_dir):
|
||||
save_dir = ""
|
||||
|
||||
mlir_module = save_mlir(
|
||||
mlir_module,
|
||||
model_name=extended_model_name,
|
||||
dir=save_dir,
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device=args.device if device is None else device,
|
||||
@@ -179,17 +187,22 @@ def compile_through_fx(
|
||||
mlir_module,
|
||||
)
|
||||
|
||||
del mlir_module
|
||||
gc.collect()
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
    """Collect the Vulkan runtime flags implied by ``args`` and apply them.

    Starts from ``get_iree_vulkan_runtime_flags()``, appends the RGP/debug
    flags when ``args.enable_rgp`` is set and the caching-allocator flag
    when ``args.device_allocator_heap_key`` is non-empty, then passes the
    result to ``set_iree_vulkan_runtime_flags``.
    """
    # TODO: This function should be device-agnostic and piped properly
    # to general runtime driver init.
    flags = get_iree_vulkan_runtime_flags()

    if args.enable_rgp:
        flags += ["--enable_rgp=true", "--vulkan_debug_utils=true"]

    heap_key = args.device_allocator_heap_key
    if heap_key:
        flags += [f"--device_allocator=caching:device_local={heap_key}"]

    set_iree_vulkan_runtime_flags(flags=flags)
|
||||
|
||||
|
||||
@@ -464,7 +477,14 @@ def get_available_devices():
|
||||
f"{device_name} => {driver_name.replace('local', 'cpu')}"
|
||||
)
|
||||
else:
|
||||
device_list.append(f"{device_name} => {driver_name}://{i}")
|
||||
# for drivers with single devices
|
||||
# let the default device be selected without any indexing
|
||||
if len(device_list_dict) == 1:
|
||||
device_list.append(f"{device_name} => {driver_name}")
|
||||
else:
|
||||
device_list.append(
|
||||
f"{device_name} => {driver_name}://{i}"
|
||||
)
|
||||
return device_list
|
||||
|
||||
set_iree_runtime_flags()
|
||||
@@ -522,10 +542,6 @@ def get_opt_flags(model, precision="fp16"):
|
||||
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
|
||||
)
|
||||
|
||||
# Disable bindings fusion to work with moltenVK.
|
||||
if sys.platform == "darwin":
|
||||
iree_flags.append("-iree-stream-fuse-binding=false")
|
||||
|
||||
if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
|
||||
iree_flags += opt_flags[model][is_tuned][precision][
|
||||
"default_compilation_flags"
|
||||
@@ -795,11 +811,12 @@ def batch_seeds(
|
||||
seeds = seeds[:batch_count] + [-1] * (batch_count - len(seeds))
|
||||
|
||||
if repeatable:
|
||||
# set seed for the rng based on what we have so far
|
||||
saved_random_state = random_getstate()
|
||||
if all(seed < 0 for seed in seeds):
|
||||
seeds[0] = sanitize_seed(seeds[0])
|
||||
seed_random(str(seeds))
|
||||
|
||||
# set seed for the rng based on what we have so far
|
||||
saved_random_state = random_getstate()
|
||||
seed_random(str([n for n in seeds if n > -1]))
|
||||
|
||||
# generate any seeds that are unspecified
|
||||
seeds = [sanitize_seed(seed) for seed in seeds]
|
||||
@@ -885,6 +902,13 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
pngInfo = PngImagePlugin.PngInfo()
|
||||
|
||||
if args.write_metadata_to_png:
|
||||
# Using a conditional expression caused problems, so setting a new
|
||||
# variable for now.
|
||||
if args.use_hiresfix:
|
||||
png_size_text = f"{args.hiresfix_width}x{args.hiresfix_height}"
|
||||
else:
|
||||
png_size_text = f"{args.width}x{args.height}"
|
||||
|
||||
pngInfo.add_text(
|
||||
"parameters",
|
||||
f"{args.prompts[0]}"
|
||||
@@ -893,7 +917,7 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
f"Sampler: {args.scheduler}, "
|
||||
f"CFG scale: {args.guidance_scale}, "
|
||||
f"Seed: {img_seed},"
|
||||
f"Size: {args.width}x{args.height}, "
|
||||
f"Size: {png_size_text}, "
|
||||
f"Model: {img_model}, "
|
||||
f"VAE: {img_vae}, "
|
||||
f"LoRA: {img_lora}",
|
||||
@@ -920,8 +944,10 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
"CFG_SCALE": args.guidance_scale,
|
||||
"PRECISION": args.precision,
|
||||
"STEPS": args.steps,
|
||||
"HEIGHT": args.height,
|
||||
"WIDTH": args.width,
|
||||
"HEIGHT": args.height
|
||||
if not args.use_hiresfix
|
||||
else args.hiresfix_height,
|
||||
"WIDTH": args.width if not args.use_hiresfix else args.hiresfix_width,
|
||||
"MAX_LENGTH": args.max_length,
|
||||
"OUTPUT": out_img_path,
|
||||
"VAE": img_vae,
|
||||
@@ -959,6 +985,10 @@ def get_generation_text_info(seeds, device):
|
||||
)
|
||||
text_output += (
|
||||
f"\nsize={args.height}x{args.width}, "
|
||||
if not args.use_hiresfix
|
||||
else f"\nsize={args.hiresfix_height}x{args.hiresfix_width}, "
|
||||
)
|
||||
text_output += (
|
||||
f"batch_count={args.batch_count}, "
|
||||
f"batch_size={args.batch_size}, "
|
||||
f"max_length={args.max_length}"
|
||||
|
||||
1
apps/stable_diffusion/web/api/__init__.py
Normal file
1
apps/stable_diffusion/web/api/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from apps.stable_diffusion.web.api.sdapi_v1 import sdapi
|
||||
579
apps/stable_diffusion/web/api/sdapi_v1.py
Normal file
579
apps/stable_diffusion/web/api/sdapi_v1.py
Normal file
@@ -0,0 +1,579 @@
|
||||
import os
|
||||
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel, Field, conlist, model_validator
|
||||
|
||||
from apps.stable_diffusion.web.api.utils import (
|
||||
frozen_args,
|
||||
sampler_aliases,
|
||||
encode_pil_to_base64,
|
||||
decode_base64_to_image,
|
||||
get_model_from_request,
|
||||
get_scheduler_from_request,
|
||||
get_lora_params,
|
||||
get_device,
|
||||
GenerationInputData,
|
||||
GenerationResponseData,
|
||||
)
|
||||
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_files,
|
||||
get_custom_model_pathfile,
|
||||
predefined_models,
|
||||
predefined_paint_models,
|
||||
predefined_upscaler_models,
|
||||
scheduler_list,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.txt2img_ui import txt2img_inf
|
||||
from apps.stable_diffusion.web.ui.img2img_ui import img2img_inf
|
||||
from apps.stable_diffusion.web.ui.inpaint_ui import inpaint_inf
|
||||
from apps.stable_diffusion.web.ui.outpaint_ui import outpaint_inf
|
||||
from apps.stable_diffusion.web.ui.upscaler_ui import upscaler_inf
|
||||
|
||||
sdapi = FastAPI()
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/sd-models (lists available models)
|
||||
class AppParam(str, Enum):
    """Accepted values of the ``app`` parameter of ``/v1/sd-models``.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    txt2img = "txt2img"
    img2img = "img2img"
    inpaint = "inpaint"
    outpaint = "outpaint"
    upscaler = "upscaler"
|
||||
|
||||
|
||||
@sdapi.get(
    "/v1/sd-models",
    summary="lists available models",
    description=(
        "This is all the models that this server currently knows about.\n "
        "Models listed may still have a compilation and build pending that "
        "will be triggered the first time they are used."
    ),
)
def sd_models_api(app: AppParam = frozen_args.app):
    """List custom model files followed by the predefined models for ``app``.

    Args:
        app: Which application the models are for. "inpaint"/"outpaint"
            select the inpainting models, "upscaler" the upscaler models,
            anything else the general models.

    Returns:
        A list of dicts with the keys ``title``, ``model_name``, ``hash``,
        ``sha256``, ``filename`` and ``config``. Only custom model files
        carry a ``filename``; every other optional value is ``None``.
    """
    if app in ("inpaint", "outpaint"):
        checkpoint_type, predefined = "inpainting", predefined_paint_models
    elif app == "upscaler":
        checkpoint_type, predefined = "upscaler", predefined_upscaler_models
    else:
        checkpoint_type, predefined = "", predefined_models

    def describe(name, path=None):
        # One response entry; title and model_name are always the same.
        return {
            "title": name,
            "model_name": name,
            "hash": None,
            "sha256": None,
            "filename": path,
            "config": None,
        }

    custom_entries = [
        describe(model_file, get_custom_model_pathfile(model_file))
        for model_file in get_custom_model_files(
            custom_checkpoint_type=checkpoint_type
        )
    ]
    predefined_entries = [describe(model) for model in predefined]
    return custom_entries + predefined_entries
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/samplers (lists schedulers)
|
||||
@sdapi.get(
    "/v1/samplers",
    summary="lists available schedulers/samplers",
    description=(
        "These are all the Schedulers defined and available. Not "
        "every scheduler is compatible with all apis. Aliases are "
        "equivalent samplers in A1111 if they are known."
    ),
)
def sd_samplers_api():
    """Describe every scheduler in ``scheduler_list`` with its aliases.

    ``sampler_aliases`` maps alias -> scheduler name; it is inverted here
    so each scheduler can report all aliases that point at it.

    Returns:
        A generator of dicts with the keys ``name``, ``aliases`` (possibly
        empty list) and ``options`` (always an empty dict).
    """
    aliases_by_scheduler = {}
    for alias, scheduler_name in sampler_aliases.items():
        aliases_by_scheduler.setdefault(scheduler_name, []).append(alias)

    return (
        {
            "name": scheduler,
            "aliases": aliases_by_scheduler.get(scheduler, []),
            "options": {},
        }
        for scheduler in scheduler_list
    )
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/options (lists application level options)
|
||||
@sdapi.get(
    "/v1/options",
    summary="lists current settings of application level options",
    description=(
        "A subset of the command line arguments set at startup renamed "
        "to correspond to the A1111 naming. Only a small subset of A1111 "
        "options are returned."
    ),
)
def options_api():
    """Report a few startup settings under their A1111 option names.

    Returns:
        A dict with ``samples_save``, ``samples_format``,
        ``sd_model_checkpoint``, ``sd_lora``, ``sd_vae`` and
        ``enable_pnginfo``.
    """
    # This is mostly just enough to support what Koboldcpp wants, with a
    # few other things that seemed obvious

    # A checkpoint given by location is reported by its file name only;
    # without one the Hugging Face model id is reported instead.
    if frozen_args.ckpt_loc:
        checkpoint = os.path.basename(frozen_args.ckpt_loc)
    else:
        checkpoint = frozen_args.hf_model_id

    return {
        "samples_save": True,
        "samples_format": frozen_args.output_img_format,
        "sd_model_checkpoint": checkpoint,
        "sd_lora": frozen_args.use_lora,
        "sd_vae": frozen_args.custom_vae or "Automatic",
        "enable_pnginfo": frozen_args.write_metadata_to_png,
    }
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/cmd-flags (lists command line argument settings)
|
||||
@sdapi.get(
    "/v1/cmd-flags",
    summary="lists the command line arguments value that were set on startup.",
)
def cmd_flags_api():
    """Return the startup arguments as a name -> value dict.

    Non-empty values of arguments whose name ends in ``token`` are replaced
    by ``"<redacted>"`` so credentials passed on the command line are not
    handed to every API client.
    NOTE(review): the suffix match is meant to cover the Hugging Face auth
    token argument -- confirm its exact name and extend the rule if other
    secret-bearing arguments exist.

    Returns:
        A new dict; the live ``frozen_args`` namespace is not exposed.
    """
    return {
        name: "<redacted>"
        if value and name.lower().endswith("token")
        else value
        for name, value in vars(frozen_args).items()
    }
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/txt2img (Text to image)
|
||||
class ModelOverrideSettings(BaseModel):
    """Per-request override of the model used for generation."""

    # The default is computed once, when the class body runs, by calling
    # get_model_from_request without a request object.
    sd_model_checkpoint: str = get_model_from_request(
        fallback_model="stabilityai/stable-diffusion-2-1-base"
    )
|
||||
|
||||
|
||||
class Txt2ImgInputData(GenerationInputData):
    """Request body of ``/v1/txt2img``; defaults come from ``frozen_args``."""

    # Whether to run the hires-fix pass.
    enable_hr: bool = frozen_args.use_hiresfix
    # Hires-fix target height/width: 128..768 in steps of 8.
    hr_resize_y: int = Field(
        default=frozen_args.hiresfix_height, ge=128, le=768, multiple_of=8
    )
    hr_resize_x: int = Field(
        default=frozen_args.hiresfix_width, ge=128, le=768, multiple_of=8
    )
    # Declared nullable to match its None default, so a client sending an
    # explicit null is accepted instead of failing validation.
    override_settings: ModelOverrideSettings | None = None
|
||||
|
||||
|
||||
@sdapi.post(
    "/v1/txt2img",
    summary="Does text to image generation",
    response_model=GenerationResponseData,
)
def txt2img_api(InputData: Txt2ImgInputData):
    """Run text-to-image generation for one request.

    Args:
        InputData: Validated request body.

    Returns:
        Dict with ``images`` (base64-encoded), ``parameters`` (always an
        empty dict) and ``info`` taken from the last item yielded by
        ``txt2img_inf``.

    Raises:
        RuntimeError: If ``txt2img_inf`` yields nothing.
    """
    model_id = get_model_from_request(
        InputData,
        fallback_model="stabilityai/stable-diffusion-2-1-base",
    )
    scheduler = get_scheduler_from_request(
        InputData, "txt2img_hires" if InputData.enable_hr else "txt2img"
    )
    (lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)

    print(
        f"Prompt: {InputData.prompt}, "
        f"Negative Prompt: {InputData.negative_prompt}, "
        f"Seed: {InputData.seed}, "
        f"Model: {model_id}, "
        f"Scheduler: {scheduler}."
    )

    res = txt2img_inf(
        InputData.prompt,
        InputData.negative_prompt,
        InputData.height,
        InputData.width,
        InputData.steps,
        InputData.cfg_scale,
        InputData.seed,
        batch_count=InputData.n_iter,
        batch_size=1,
        scheduler=scheduler,
        model_id=model_id,
        custom_vae=frozen_args.custom_vae or "None",
        precision="fp16",
        device=get_device(frozen_args.device),
        max_length=frozen_args.max_length,
        save_metadata_to_json=frozen_args.save_metadata_to_json,
        save_metadata_to_png=frozen_args.write_metadata_to_png,
        lora_weights=lora_weights,
        lora_hf_id=lora_hf_id,
        ondemand=frozen_args.ondemand,
        repeatable_seeds=False,
        use_hiresfix=InputData.enable_hr,
        hiresfix_height=InputData.hr_resize_y,
        hiresfix_width=InputData.hr_resize_x,
        hiresfix_strength=frozen_args.hiresfix_strength,
        resample_type=frozen_args.resample_type,
    )

    # Since we're not streaming we just want the last generator result
    items = None
    for items in res:
        pass
    if items is None:
        # Fail with a clear message rather than an unbound-variable error.
        raise RuntimeError("txt2img generation produced no results")

    return {
        "images": encode_pil_to_base64(items[0]),
        "parameters": {},
        "info": items[1],
    }
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/img2img (Image to image)
|
||||
class StencilParam(str, Enum):
    """Stencil names accepted in the ``use_stencil`` request field.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    canny = "canny"
    openpose = "openpose"
    scribble = "scribble"
    zoedepth = "zoedepth"
|
||||
|
||||
|
||||
class Img2ImgInputData(GenerationInputData):
    """Request body of ``/v1/img2img``; defaults come from ``frozen_args``."""

    # One or two base64-encoded images: the init image and, optionally, a
    # second image that the endpoint passes on as the mask.
    init_images: conlist(str, min_length=1, max_length=2)
    denoising_strength: float = frozen_args.strength
    # Declared nullable because no stencil is a valid state; an explicit
    # null from the client is then accepted.
    use_stencil: StencilParam | None = frozen_args.use_stencil
    override_settings: ModelOverrideSettings | None = None

    @model_validator(mode="after")
    def check_image_supplied_for_scribble_stencil(self) -> "Img2ImgInputData":
        """Reject a scribble-stencil request that has no second image."""
        if (
            self.use_stencil == StencilParam.scribble
            and len(self.init_images) < 2
        ):
            raise ValueError(
                "a second image must be supplied for the controlnet:scribble stencil"
            )

        return self
|
||||
|
||||
|
||||
@sdapi.post(
    "/v1/img2img",
    summary="Does image to image generation",
    response_model=GenerationResponseData,
)
def img2img_api(
    InputData: Img2ImgInputData,
):
    """Run image-to-image generation for one request.

    Args:
        InputData: Validated request body. ``init_images[0]`` is the init
            image; ``init_images[1]``, when present, is used as the mask.

    Returns:
        Dict with ``images`` (base64-encoded), ``parameters`` (always an
        empty dict) and ``info`` taken from the last item yielded by
        ``img2img_inf``.

    Raises:
        RuntimeError: If ``img2img_inf`` yields nothing.
    """
    model_id = get_model_from_request(
        InputData,
        fallback_model="stabilityai/stable-diffusion-2-1-base",
    )
    scheduler = get_scheduler_from_request(InputData, "img2img")
    (lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)

    init_image = decode_base64_to_image(InputData.init_images[0])
    mask_image = (
        decode_base64_to_image(InputData.init_images[1])
        if len(InputData.init_images) > 1
        else None
    )

    print(
        f"Prompt: {InputData.prompt}, "
        f"Negative Prompt: {InputData.negative_prompt}, "
        f"Seed: {InputData.seed}, "
        f"Model: {model_id}, "
        f"Scheduler: {scheduler}."
    )

    res = img2img_inf(
        InputData.prompt,
        InputData.negative_prompt,
        {"image": init_image, "mask": mask_image},
        InputData.height,
        InputData.width,
        InputData.steps,
        InputData.denoising_strength,
        InputData.cfg_scale,
        InputData.seed,
        batch_count=InputData.n_iter,
        batch_size=1,
        scheduler=scheduler,
        model_id=model_id,
        custom_vae=frozen_args.custom_vae or "None",
        precision="fp16",
        device=get_device(frozen_args.device),
        max_length=frozen_args.max_length,
        use_stencil=InputData.use_stencil,
        save_metadata_to_json=frozen_args.save_metadata_to_json,
        save_metadata_to_png=frozen_args.write_metadata_to_png,
        lora_weights=lora_weights,
        lora_hf_id=lora_hf_id,
        ondemand=frozen_args.ondemand,
        repeatable_seeds=False,
        resample_type=frozen_args.resample_type,
    )

    # Since we're not streaming we just want the last generator result
    items = None
    for items in res:
        pass
    if items is None:
        # Fail with a clear message rather than an unbound-variable error.
        raise RuntimeError("img2img generation produced no results")

    return {
        "images": encode_pil_to_base64(items[0]),
        "parameters": {},
        "info": items[1],
    }
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/inpaint (Inpainting)
|
||||
class PaintModelOverideSettings(BaseModel):
    """Per-request model override for the inpaint and outpaint endpoints."""

    # The default is computed once, when the class body runs, by calling
    # get_model_from_request for the "inpainting" checkpoint type.
    sd_model_checkpoint: str = get_model_from_request(
        checkpoint_type="inpainting",
        fallback_model="stabilityai/stable-diffusion-2-inpainting",
    )
|
||||
|
||||
|
||||
class InpaintInputData(GenerationInputData):
    """Request body of ``/v1/inpaint``; defaults come from ``frozen_args``."""

    image: str = Field(description="Base64 encoded input image")
    mask: str = Field(description="Base64 encoded mask image")
    is_full_res: bool = False  # Is this setting backwards in the UI?
    # Padding: 0..256 in steps of 4.
    full_res_padding: int = Field(default=32, ge=0, le=256, multiple_of=4)
    denoising_strength: float = frozen_args.strength
    # Declared nullable because no stencil is a valid state; an explicit
    # null from the client is then accepted.
    use_stencil: StencilParam | None = frozen_args.use_stencil
    override_settings: PaintModelOverideSettings | None = None
|
||||
|
||||
|
||||
@sdapi.post(
    "/v1/inpaint",
    summary="Does inpainting generation on an image",
    response_model=GenerationResponseData,
)
def inpaint_api(
    InputData: InpaintInputData,
):
    """Run inpainting for one request.

    Args:
        InputData: Validated request body carrying the base64-encoded
            image and mask.

    Returns:
        Dict with ``images`` (base64-encoded), ``parameters`` (always an
        empty dict) and ``info`` taken from the last item yielded by
        ``inpaint_inf``.

    Raises:
        RuntimeError: If ``inpaint_inf`` yields nothing.
    """
    model_id = get_model_from_request(
        InputData,
        checkpoint_type="inpainting",
        fallback_model="stabilityai/stable-diffusion-2-inpainting",
    )
    scheduler = get_scheduler_from_request(InputData, "inpaint")
    (lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)

    init_image = decode_base64_to_image(InputData.image)
    mask = decode_base64_to_image(InputData.mask)

    # The stray '"' characters that used to follow the negative prompt and
    # the seed are gone, matching the log line of the other endpoints.
    print(
        f"Prompt: {InputData.prompt}, "
        f"Negative Prompt: {InputData.negative_prompt}, "
        f"Seed: {InputData.seed}, "
        f"Model: {model_id}, "
        f"Scheduler: {scheduler}."
    )

    res = inpaint_inf(
        InputData.prompt,
        InputData.negative_prompt,
        {"image": init_image, "mask": mask},
        InputData.height,
        InputData.width,
        InputData.is_full_res,
        InputData.full_res_padding,
        InputData.steps,
        InputData.cfg_scale,
        InputData.seed,
        batch_count=InputData.n_iter,
        batch_size=1,
        scheduler=scheduler,
        model_id=model_id,
        custom_vae=frozen_args.custom_vae or "None",
        precision="fp16",
        device=get_device(frozen_args.device),
        max_length=frozen_args.max_length,
        save_metadata_to_json=frozen_args.save_metadata_to_json,
        save_metadata_to_png=frozen_args.write_metadata_to_png,
        lora_weights=lora_weights,
        lora_hf_id=lora_hf_id,
        ondemand=frozen_args.ondemand,
        repeatable_seeds=False,
    )

    # Since we're not streaming we just want the last generator result
    items = None
    for items in res:
        pass
    if items is None:
        # Fail with a clear message rather than an unbound-variable error.
        raise RuntimeError("inpaint generation produced no results")

    return {
        "images": encode_pil_to_base64(items[0]),
        "parameters": {},
        "info": items[1],
    }
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/outpaint (Outpainting)
|
||||
class DirectionParam(str, Enum):
    """Direction names accepted in the outpaint ``directions`` field.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    left = "left"
    right = "right"
    up = "up"
    down = "down"
|
||||
|
||||
|
||||
class OutpaintInputData(GenerationInputData):
    """Request body of ``/v1/outpaint``; defaults come from ``frozen_args``."""

    # At least one image is required: the endpoint reads init_images[0],
    # so an empty list is rejected at validation time instead of raising
    # IndexError inside the handler.
    init_images: conlist(str, min_length=1)
    # Pixel count: 8..256 in steps of 8.
    pixels: int = Field(
        default=frozen_args.pixels, ge=8, le=256, multiple_of=8
    )
    mask_blur: int = Field(default=frozen_args.mask_blur, ge=0, le=64)
    # Defaults to the directions whose flags were set on the command line.
    # NOTE(review): this default is a list of plain strings while the
    # annotation is a set of DirectionParam -- confirm outpaint_inf accepts
    # both forms.
    directions: set[DirectionParam] = [
        direction
        for direction in ["left", "right", "up", "down"]
        if vars(frozen_args)[direction]
    ]
    noise_q: float = frozen_args.noise_q
    color_variation: float = frozen_args.color_variation
    # Declared nullable to match its None default, so a client sending an
    # explicit null is accepted instead of failing validation.
    override_settings: PaintModelOverideSettings | None = None
|
||||
|
||||
|
||||
@sdapi.post(
    "/v1/outpaint",
    summary="Does outpainting generation on an image",
    response_model=GenerationResponseData,
)
def outpaint_api(
    InputData: OutpaintInputData,
):
    """Run outpainting for one request.

    Args:
        InputData: Validated request body; only ``init_images[0]`` is used
            as the image to extend.

    Returns:
        Dict with ``images`` (base64-encoded), ``parameters`` (always an
        empty dict) and ``info`` taken from the last item yielded by
        ``outpaint_inf``.

    Raises:
        RuntimeError: If ``outpaint_inf`` yields nothing.
    """
    model_id = get_model_from_request(
        InputData,
        checkpoint_type="inpainting",
        fallback_model="stabilityai/stable-diffusion-2-inpainting",
    )
    scheduler = get_scheduler_from_request(InputData, "outpaint")
    (lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)

    init_image = decode_base64_to_image(InputData.init_images[0])

    print(
        f"Prompt: {InputData.prompt}, "
        f"Negative Prompt: {InputData.negative_prompt}, "
        f"Seed: {InputData.seed}, "
        f"Model: {model_id}, "
        f"Scheduler: {scheduler}."
    )

    res = outpaint_inf(
        InputData.prompt,
        InputData.negative_prompt,
        init_image,
        InputData.pixels,
        InputData.mask_blur,
        InputData.directions,
        InputData.noise_q,
        InputData.color_variation,
        InputData.height,
        InputData.width,
        InputData.steps,
        InputData.cfg_scale,
        InputData.seed,
        batch_count=InputData.n_iter,
        batch_size=1,
        scheduler=scheduler,
        model_id=model_id,
        custom_vae=frozen_args.custom_vae or "None",
        precision="fp16",
        device=get_device(frozen_args.device),
        max_length=frozen_args.max_length,
        save_metadata_to_json=frozen_args.save_metadata_to_json,
        save_metadata_to_png=frozen_args.write_metadata_to_png,
        lora_weights=lora_weights,
        lora_hf_id=lora_hf_id,
        ondemand=frozen_args.ondemand,
        repeatable_seeds=False,
    )

    # Since we're not streaming we just want the last generator result
    items = None
    for items in res:
        pass
    if items is None:
        # Fail with a clear message rather than an unbound-variable error.
        raise RuntimeError("outpaint generation produced no results")

    return {
        "images": encode_pil_to_base64(items[0]),
        "parameters": {},
        "info": items[1],
    }
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/upscaler (Upscaling)
|
||||
class UpscalerModelOverideSettings(BaseModel):
    """Shape of the ``override_settings`` object for upscaler requests."""

    # The default is computed once, when this class body runs: with no
    # request data, get_model_from_request falls back to the command line
    # model if it is a valid upscaler model, else to the model named here.
    sd_model_checkpoint: str = get_model_from_request(
        fallback_model="stabilityai/stable-diffusion-x4-upscaler",
        checkpoint_type="upscaler",
    )
|
||||
|
||||
|
||||
class UpscalerInputData(GenerationInputData):
    """Request body accepted by the ``/v1/upscaler`` route.

    Extends the shared generation fields with the upscaler specific ones.
    """

    # Required; no default.
    init_images: list[str] = Field(
        description="Base64 encoded image to upscale"
    )
    # Default is read from the frozen command line arguments.
    noise_level: int = frozen_args.noise_level
    # Explicitly optional: the default is None, and callers only test this
    # field for truthiness before reading from it.
    override_settings: UpscalerModelOverideSettings | None = None
|
||||
|
||||
|
||||
@sdapi.post(
    "/v1/upscaler",
    summary="Does image upscaling",
    response_model=GenerationResponseData,
)
def upscaler_api(
    InputData: UpscalerInputData,
):
    """Handle a POST to ``/v1/upscaler``.

    Resolves the model, scheduler and LoRA settings for the request,
    decodes the first entry of ``init_images`` and runs ``upscaler_inf``.

    Returns a dict with the base64 encoded ``images``, an empty
    ``parameters`` dict and the ``info`` value produced by the last result
    of the inference generator.
    """
    model_id = get_model_from_request(
        InputData,
        checkpoint_type="upscaler",
        fallback_model="stabilityai/stable-diffusion-x4-upscaler",
    )
    scheduler = get_scheduler_from_request(InputData, "upscaler")
    lora_weights, lora_hf_id = get_lora_params(frozen_args.use_lora)

    # only the first supplied image is used
    init_image = decode_base64_to_image(InputData.init_images[0])

    request_summary = (
        f"Prompt: {InputData.prompt}, "
        f"Negative Prompt: {InputData.negative_prompt}, "
        f"Seed: {InputData.seed}, "
        f"Model: {model_id}, "
        f"Scheduler: {scheduler}."
    )
    print(request_summary)

    # positional arguments, in the order upscaler_inf is called with them
    request_args = (
        InputData.prompt,
        InputData.negative_prompt,
        init_image,
        InputData.height,
        InputData.width,
        InputData.steps,
        InputData.noise_level,
        InputData.cfg_scale,
        InputData.seed,
    )
    res = upscaler_inf(
        *request_args,
        batch_count=InputData.n_iter,
        batch_size=1,
        scheduler=scheduler,
        model_id=model_id,
        custom_vae=frozen_args.custom_vae or "None",
        precision="fp16",
        device=get_device(frozen_args.device),
        max_length=frozen_args.max_length,
        save_metadata_to_json=frozen_args.save_metadata_to_json,
        save_metadata_to_png=frozen_args.write_metadata_to_png,
        lora_weights=lora_weights,
        lora_hf_id=lora_hf_id,
        ondemand=frozen_args.ondemand,
        repeatable_seeds=False,
    )

    # Since we're not streaming, run the generator to exhaustion and keep
    # only the final value it produced
    for items in res:
        pass

    encoded_images = encode_pil_to_base64(items[0])
    return {
        "images": encoded_images,
        "parameters": {},
        "info": items[1],
    }
|
||||
211
apps/stable_diffusion/web/api/utils.py
Normal file
211
apps/stable_diffusion/web/api/utils.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import base64
|
||||
import pickle
|
||||
|
||||
from argparse import Namespace
|
||||
from fastapi.exceptions import HTTPException
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from apps.stable_diffusion.src import args
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
get_custom_model_files,
|
||||
predefined_models,
|
||||
predefined_paint_models,
|
||||
predefined_upscaler_models,
|
||||
scheduler_list,
|
||||
scheduler_list_cpu_only,
|
||||
)
|
||||
|
||||
|
||||
# Snapshot of the command line args as they were when this module was
# imported. This is probably overly cautious, but the code assigns
# `args.<whatever> = <changed_value>` in lots of places and, in testing,
# those changes appeared to leak into subsequent api calls. The api reads
# its defaults from this snapshot so each call sees the starting values.

# The deep copy is done by round-tripping through pickle; there is probably
# a better way.
frozen_args = Namespace(**pickle.loads(pickle.dumps(vars(args))))
|
||||
|
||||
# An attempt to map some of the A1111 sampler names to scheduler names.
# The (not so obvious) ones come from
# https://github.com/huggingface/diffusers/issues/4167
sampler_aliases = dict(
    (
        # a1111/onnx (these point to diffusers classes in A1111)
        ("pndm", "PNDM"),
        ("heun", "HeunDiscrete"),
        ("ddim", "DDIM"),
        ("ddpm", "DDPM"),
        ("euler", "EulerDiscrete"),
        ("euler-ancestral", "EulerAncestralDiscrete"),
        ("dpm", "DPMSolverMultistep"),
        # a1111/k_diffusion (the obvious ones)
        ("Euler a", "EulerAncestralDiscrete"),
        ("Euler", "EulerDiscrete"),
        ("LMS", "LMSDiscrete"),
        ("Heun", "HeunDiscrete"),
        # a1111/k_diffusion (not so obvious)
        ("DPM++ 2M", "DPMSolverMultistep"),
        ("DPM++ 2M Karras", "DPMSolverMultistepKarras"),
        ("DPM++ 2M SDE", "DPMSolverMultistep++"),
        ("DPM++ 2M SDE Karras", "DPMSolverMultistepKarras++"),
        ("DPM2", "KDPM2Discrete"),
        ("DPM2 a", "KDPM2AncestralDiscrete"),
    )
)
|
||||
|
||||
# For every api operation: the scheduler names a request may select, and the
# scheduler used instead when the requested one is not in that list.
allowed_schedulers = {
    operation: {"schedulers": schedulers, "fallback": fallback}
    for operation, schedulers, fallback in (
        ("txt2img", scheduler_list, "SharkEulerDiscrete"),
        ("txt2img_hires", scheduler_list_cpu_only, "DEISMultistep"),
        ("img2img", scheduler_list_cpu_only, "EulerDiscrete"),
        ("inpaint", scheduler_list_cpu_only, "DDIM"),
        ("outpaint", scheduler_list_cpu_only, "DDIM"),
        ("upscaler", scheduler_list_cpu_only, "DDIM"),
    )
}
|
||||
|
||||
# base pydantic model for sd generation apis
|
||||
|
||||
|
||||
class GenerationInputData(BaseModel):
    """Request fields common to the sd generation apis.

    Defaults that are not literals are read from the frozen command line
    arguments.
    """

    prompt: str = ""
    negative_prompt: str = ""
    hf_model_id: str | None = None
    # height and width: multiples of 8 in the range 128..768
    height: int = Field(
        default=frozen_args.height,
        ge=128,
        le=768,
        multiple_of=8,
    )
    width: int = Field(
        default=frozen_args.width,
        ge=128,
        le=768,
        multiple_of=8,
    )
    sampler_name: str = frozen_args.scheduler
    # must be at least 1
    cfg_scale: float = Field(default=frozen_args.guidance_scale, ge=1)
    # must be in the range 1..100
    steps: int = Field(default=frozen_args.steps, ge=1, le=100)
    seed: int = frozen_args.seed
    n_iter: int = Field(default=frozen_args.batch_count)
|
||||
|
||||
|
||||
class GenerationResponseData(BaseModel):
    """Response body declared for the sd generation apis."""

    images: list[str] = Field(description="Generated images, Base64 encoded")
    # NOTE(review): handlers that return a "parameters" key rather than
    # "properties" would not populate this field - confirm which name is
    # intended.
    properties: dict = {}
    info: str
|
||||
|
||||
|
||||
# image encoding/decoding
|
||||
|
||||
|
||||
def encode_pil_to_base64(images: list[Image.Image]):
    """Encode each image as base64, using the configured output format.

    The format comes from ``frozen_args.output_img_format`` (compared case
    insensitively): "png" saves as PNG, "jpg"/"jpeg" saves as JPEG.

    Returns a list with one ``base64.b64encode`` result (bytes) per image.

    Raises HTTPException (status 500) when the configured format is
    anything else; this is only detected once there is an image to encode.
    """
    encoded_imgs = []
    for image in images:
        with BytesIO() as buffer:
            configured_format = frozen_args.output_img_format.lower()
            if configured_format == "png":
                pil_format = "PNG"
            elif configured_format in ("jpg", "jpeg"):
                pil_format = "JPEG"
            else:
                raise HTTPException(
                    status_code=500, detail="Invalid image format"
                )
            image.save(buffer, format=pil_format)
            encoded_imgs.append(base64.b64encode(buffer.getvalue()))
    return encoded_imgs
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding: str):
    """Decode a base64 string, optionally wrapped in a data URI, to an image.

    Args:
        encoding: base64 encoded image data. If it starts with
            ``data:image/`` everything up to and including the first comma
            is treated as the data URI header and discarded.

    Returns:
        The image opened by ``PIL.Image.open`` from the decoded bytes.

    Raises:
        HTTPException: status 400 when the data URI header has no comma, the
            payload is not valid base64, or the bytes cannot be opened as an
            image.
    """
    try:
        if encoding.startswith("data:image/"):
            # Stripping the header inside the try block means a malformed
            # header is reported as a bad request, like any other undecodable
            # input, rather than escaping as an unhandled IndexError.
            _, separator, payload = encoding.partition(",")
            if not separator:
                raise ValueError("data URI has no ',' before the payload")
            encoding = payload
        return Image.open(BytesIO(base64.b64decode(encoding)))
    except Exception as err:
        print(err)
        raise HTTPException(
            status_code=400, detail="Invalid encoded image"
        ) from err
|
||||
|
||||
|
||||
# get valid sd models/vaes/schedulers etc.
|
||||
|
||||
|
||||
def get_predefined_models(custom_checkpoint_type: str):
    """Return the predefined model list for a checkpoint type.

    "inpainting" selects the paint models and "upscaler" the upscaler
    models; any other value selects the general predefined models.
    """
    if custom_checkpoint_type == "inpainting":
        return predefined_paint_models
    if custom_checkpoint_type == "upscaler":
        return predefined_upscaler_models
    return predefined_models
|
||||
|
||||
|
||||
def get_model_from_request(
    request_data=None,
    checkpoint_type: str = "",
    fallback_model: str = "",
):
    """Pick the model to use for a request.

    Candidates are tried in this order: the request's ``hf_model_id``, the
    request's ``override_settings.sd_model_checkpoint``, then the command
    line ``ckpt_loc`` and ``hf_model_id``. The first truthy candidate is
    returned if it is one of the custom model files or predefined models for
    ``checkpoint_type``; otherwise ``fallback_model`` is returned.
    """
    requested = None
    if request_data:
        if request_data.hf_model_id:
            requested = request_data.hf_model_id
        elif request_data.override_settings:
            requested = request_data.override_settings.sd_model_checkpoint

    # if the request didn't specify a model try the command line args
    candidate = requested or frozen_args.ckpt_loc or frozen_args.hf_model_id

    # make sure whatever we have is a valid model for the checkpoint type,
    # and if not return what was specified as the fallback
    valid_models = get_custom_model_files(
        custom_checkpoint_type=checkpoint_type
    ) + get_predefined_models(checkpoint_type)
    return candidate if candidate in valid_models else fallback_model
|
||||
|
||||
|
||||
def get_scheduler_from_request(
    request_data: GenerationInputData, operation: str
):
    """Pick the scheduler to use for a request.

    The request's ``sampler_name`` is first translated through
    ``sampler_aliases`` (unknown names pass through unchanged). The result
    is returned if it is allowed for ``operation``, otherwise that
    operation's fallback scheduler is returned.
    """
    permitted = allowed_schedulers[operation]
    sampler = request_data.sampler_name
    scheduler_name = sampler_aliases.get(sampler, sampler)

    if scheduler_name in permitted["schedulers"]:
        return scheduler_name
    return permitted["fallback"]
|
||||
|
||||
|
||||
def get_lora_params(use_lora: str):
    """Split ``use_lora`` into a ``(lora_weights, lora_hf_id)`` pair.

    When ``use_lora`` names one of the custom "lora" model files it is
    returned as the weights with an empty id; otherwise the weights are
    "None" and ``use_lora`` is returned as the id.
    """
    # TODO: since the inference functions in the webui, which we are
    # still calling into for the api, jam these back together again before
    # handing them off to the pipeline, we should remove this nonsense
    # and unify their selection in the UI and command line args proper
    is_custom_lora_file = use_lora in get_custom_model_files("lora")
    return (use_lora, "") if is_custom_lora_file else ("None", use_lora)
|
||||
|
||||
|
||||
def get_device(device_str: str):
    """Return the first available device whose name contains ``device_str``.

    Falls back to the first entry of ``available_devices`` when no device
    name contains the given substring.
    """
    for candidate in available_devices:
        if device_str in candidate:
            return candidate
    return available_devices[0]
|
||||
@@ -1,7 +1,8 @@
|
||||
from multiprocessing import Process, freeze_support
|
||||
from multiprocessing import freeze_support
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import apps.stable_diffusion.web.utils.app as app
|
||||
|
||||
if sys.platform == "darwin":
|
||||
# import before IREE to avoid torch-MLIR library issues
|
||||
@@ -21,26 +22,6 @@ if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
|
||||
def launch_app(address):
    """Open ``address`` in a "SHARK AI Studio" pywebview window."""
    from tkinter import Tk
    import webview

    # Tk is only used to query the size of the display
    screen = Tk()

    # we aren't making the window full-screen or maximized, so size it to
    # a more reasonable fraction of the display
    window_size = {
        "width": int(screen.winfo_screenwidth() * 0.81),
        "height": int(screen.winfo_screenheight() * 0.91),
    }
    webview.create_window(
        "SHARK AI Studio",
        url=address,
        text_select=True,
        **window_size,
    )
    webview.start(private_mode=False, storage_path=os.getcwd())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.debug:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
@@ -48,39 +29,47 @@ if __name__ == "__main__":
|
||||
freeze_support()
|
||||
if args.api or "api" in args.ui.split(","):
|
||||
from apps.stable_diffusion.web.ui import (
|
||||
txt2img_api,
|
||||
img2img_api,
|
||||
upscaler_api,
|
||||
inpaint_api,
|
||||
outpaint_api,
|
||||
llm_chat_api,
|
||||
)
|
||||
from apps.stable_diffusion.web.api import sdapi
|
||||
|
||||
from fastapi import FastAPI, APIRouter
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import uvicorn
|
||||
|
||||
# init global sd pipeline and config
|
||||
global_obj._init()
|
||||
|
||||
app = FastAPI()
|
||||
app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
|
||||
api = FastAPI()
|
||||
api.mount("/sdapi/", sdapi)
|
||||
|
||||
# chat APIs needed for compatibility with multiple extensions using OpenAI API
|
||||
app.add_api_route(
|
||||
api.add_api_route(
|
||||
"/v1/chat/completions", llm_chat_api, methods=["post"]
|
||||
)
|
||||
app.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
|
||||
app.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
|
||||
app.add_api_route("/completions", llm_chat_api, methods=["post"])
|
||||
app.add_api_route(
|
||||
api.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
|
||||
api.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
|
||||
api.add_api_route("/completions", llm_chat_api, methods=["post"])
|
||||
api.add_api_route(
|
||||
"/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
|
||||
)
|
||||
app.include_router(APIRouter())
|
||||
uvicorn.run(app, host="0.0.0.0", port=args.server_port)
|
||||
api.include_router(APIRouter())
|
||||
|
||||
# deal with CORS requests if CORS accept origins are set
|
||||
if args.api_accept_origin:
|
||||
print(
|
||||
f"API Configured for CORS. Accepting origins: { args.api_accept_origin }"
|
||||
)
|
||||
api.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=args.api_accept_origin,
|
||||
allow_methods=["GET", "POST"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
else:
|
||||
print("API not configured for CORS")
|
||||
|
||||
uvicorn.run(api, host="0.0.0.0", port=args.server_port)
|
||||
sys.exit(0)
|
||||
|
||||
# Setup to use shark_tmp for gradio's temporary image files and clear any
|
||||
@@ -94,7 +83,10 @@ if __name__ == "__main__":
|
||||
import gradio as gr
|
||||
|
||||
# Create custom models folders if they don't exist
|
||||
from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
create_custom_models_folders,
|
||||
nodicon_loc,
|
||||
)
|
||||
|
||||
create_custom_models_folders()
|
||||
|
||||
@@ -110,7 +102,6 @@ if __name__ == "__main__":
|
||||
from apps.stable_diffusion.web.ui import (
|
||||
txt2img_web,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
txt2img_gallery,
|
||||
txt2img_png_info_img,
|
||||
txt2img_status,
|
||||
@@ -122,7 +113,6 @@ if __name__ == "__main__":
|
||||
# h2ogpt_web,
|
||||
img2img_web,
|
||||
img2img_custom_model,
|
||||
img2img_hf_model_id,
|
||||
img2img_gallery,
|
||||
img2img_init_image,
|
||||
img2img_status,
|
||||
@@ -131,7 +121,6 @@ if __name__ == "__main__":
|
||||
img2img_sendto_upscaler,
|
||||
inpaint_web,
|
||||
inpaint_custom_model,
|
||||
inpaint_hf_model_id,
|
||||
inpaint_gallery,
|
||||
inpaint_init_image,
|
||||
inpaint_status,
|
||||
@@ -140,7 +129,6 @@ if __name__ == "__main__":
|
||||
inpaint_sendto_upscaler,
|
||||
outpaint_web,
|
||||
outpaint_custom_model,
|
||||
outpaint_hf_model_id,
|
||||
outpaint_gallery,
|
||||
outpaint_init_image,
|
||||
outpaint_status,
|
||||
@@ -149,16 +137,15 @@ if __name__ == "__main__":
|
||||
outpaint_sendto_upscaler,
|
||||
upscaler_web,
|
||||
upscaler_custom_model,
|
||||
upscaler_hf_model_id,
|
||||
upscaler_gallery,
|
||||
upscaler_init_image,
|
||||
upscaler_status,
|
||||
upscaler_sendto_img2img,
|
||||
upscaler_sendto_inpaint,
|
||||
upscaler_sendto_outpaint,
|
||||
lora_train_web,
|
||||
model_web,
|
||||
model_config_web,
|
||||
# lora_train_web,
|
||||
# model_web,
|
||||
# model_config_web,
|
||||
hf_models,
|
||||
modelmanager_sendto_txt2img,
|
||||
modelmanager_sendto_img2img,
|
||||
@@ -213,7 +200,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
with gr.Blocks(
|
||||
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
|
||||
css=dark_theme, analytics_enabled=False, title="SHARK AI Studio"
|
||||
) as sd_web:
|
||||
with gr.Tabs() as tabs:
|
||||
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
|
||||
@@ -250,16 +237,16 @@ if __name__ == "__main__":
|
||||
upscaler_status,
|
||||
]
|
||||
)
|
||||
with gr.TabItem(label="Model Manager", id=6):
|
||||
model_web.render()
|
||||
with gr.TabItem(label="LoRA Training (Experimental)", id=7):
|
||||
lora_train_web.render()
|
||||
with gr.TabItem(label="Chat Bot (Experimental)", id=8):
|
||||
# with gr.TabItem(label="Model Manager", id=6):
|
||||
# model_web.render()
|
||||
# with gr.TabItem(label="LoRA Training (Experimental)", id=7):
|
||||
# lora_train_web.render()
|
||||
with gr.TabItem(label="Chat Bot", id=8):
|
||||
stablelm_chat.render()
|
||||
with gr.TabItem(
|
||||
label="Generate Sharding Config (Experimental)", id=9
|
||||
):
|
||||
model_config_web.render()
|
||||
# with gr.TabItem(
|
||||
# label="Generate Sharding Config (Experimental)", id=9
|
||||
# ):
|
||||
# model_config_web.render()
|
||||
with gr.TabItem(label="MultiModal (Experimental)", id=10):
|
||||
minigpt4_web.render()
|
||||
# with gr.TabItem(label="DocuChat Upload", id=11):
|
||||
@@ -267,6 +254,15 @@ if __name__ == "__main__":
|
||||
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
|
||||
# h2ogpt_web.render()
|
||||
|
||||
actual_port = app.usable_port()
|
||||
if actual_port != args.server_port:
|
||||
sd_web.load(
|
||||
fn=lambda: gr.Info(
|
||||
f"Port {args.server_port} is in use by another application. "
|
||||
f"Shark is running on port {actual_port} instead."
|
||||
)
|
||||
)
|
||||
|
||||
# send to buttons
|
||||
register_button_click(
|
||||
txt2img_sendto_img2img,
|
||||
@@ -399,42 +395,38 @@ if __name__ == "__main__":
|
||||
modelmanager_sendto_txt2img,
|
||||
0,
|
||||
[hf_models],
|
||||
[txt2img_custom_model, txt2img_hf_model_id, tabs],
|
||||
[txt2img_custom_model, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_img2img,
|
||||
1,
|
||||
[hf_models],
|
||||
[img2img_custom_model, img2img_hf_model_id, tabs],
|
||||
[img2img_custom_model, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_inpaint,
|
||||
2,
|
||||
[hf_models],
|
||||
[inpaint_custom_model, inpaint_hf_model_id, tabs],
|
||||
[inpaint_custom_model, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_outpaint,
|
||||
3,
|
||||
[hf_models],
|
||||
[outpaint_custom_model, outpaint_hf_model_id, tabs],
|
||||
[outpaint_custom_model, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_upscaler,
|
||||
4,
|
||||
[hf_models],
|
||||
[upscaler_custom_model, upscaler_hf_model_id, tabs],
|
||||
[upscaler_custom_model, tabs],
|
||||
)
|
||||
|
||||
sd_web.queue()
|
||||
if args.ui == "app":
|
||||
t = Process(
|
||||
target=launch_app, args=[f"http://localhost:{args.server_port}"]
|
||||
)
|
||||
t.start()
|
||||
sd_web.launch(
|
||||
share=args.share,
|
||||
inbrowser=args.ui == "web",
|
||||
inbrowser=not app.launch(actual_port),
|
||||
server_name="0.0.0.0",
|
||||
server_port=args.server_port,
|
||||
server_port=actual_port,
|
||||
favicon_path=nodicon_loc,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from apps.stable_diffusion.web.ui.txt2img_ui import (
|
||||
txt2img_inf,
|
||||
txt2img_api,
|
||||
txt2img_web,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
txt2img_gallery,
|
||||
txt2img_png_info_img,
|
||||
txt2img_status,
|
||||
@@ -14,10 +12,8 @@ from apps.stable_diffusion.web.ui.txt2img_ui import (
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.img2img_ui import (
|
||||
img2img_inf,
|
||||
img2img_api,
|
||||
img2img_web,
|
||||
img2img_custom_model,
|
||||
img2img_hf_model_id,
|
||||
img2img_gallery,
|
||||
img2img_init_image,
|
||||
img2img_status,
|
||||
@@ -27,10 +23,8 @@ from apps.stable_diffusion.web.ui.img2img_ui import (
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.inpaint_ui import (
|
||||
inpaint_inf,
|
||||
inpaint_api,
|
||||
inpaint_web,
|
||||
inpaint_custom_model,
|
||||
inpaint_hf_model_id,
|
||||
inpaint_gallery,
|
||||
inpaint_init_image,
|
||||
inpaint_status,
|
||||
@@ -40,10 +34,8 @@ from apps.stable_diffusion.web.ui.inpaint_ui import (
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.outpaint_ui import (
|
||||
outpaint_inf,
|
||||
outpaint_api,
|
||||
outpaint_web,
|
||||
outpaint_custom_model,
|
||||
outpaint_hf_model_id,
|
||||
outpaint_gallery,
|
||||
outpaint_init_image,
|
||||
outpaint_status,
|
||||
@@ -53,10 +45,8 @@ from apps.stable_diffusion.web.ui.outpaint_ui import (
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.upscaler_ui import (
|
||||
upscaler_inf,
|
||||
upscaler_api,
|
||||
upscaler_web,
|
||||
upscaler_custom_model,
|
||||
upscaler_hf_model_id,
|
||||
upscaler_gallery,
|
||||
upscaler_init_image,
|
||||
upscaler_status,
|
||||
|
||||
@@ -212,6 +212,7 @@ with gr.Blocks(title="DocuChat") as h2ogpt_web:
|
||||
else "Only CUDA Supported for now",
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
|
||||
@@ -5,9 +5,6 @@ import gradio as gr
|
||||
import PIL
|
||||
from math import ceil
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from fastapi.exceptions import HTTPException
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -55,8 +52,7 @@ def img2img_inf(
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
@@ -103,21 +99,17 @@ def img2img_inf(
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
args.custom_vae = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, "
|
||||
"both must not be empty.",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
|
||||
# .safetensor or .chkpt on the custom model path
|
||||
if model_id in get_custom_model_files():
|
||||
args.ckpt_loc = get_custom_model_pathfile(model_id)
|
||||
# civitai download
|
||||
elif "civitai" in model_id:
|
||||
args.ckpt_loc = model_id
|
||||
# either predefined or huggingface
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
args.hf_model_id = model_id
|
||||
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
@@ -282,88 +274,6 @@ def img2img_inf(
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
    """Decode a base64 string, optionally wrapped in a data URI, to an image.

    For input starting with ``data:image/`` only the text after the first
    ";" and then after the first "," is decoded. Decoding or image opening
    failures are printed and re-raised as HTTPException with status 500.
    """
    if encoding.startswith("data:image/"):
        after_media_type = encoding.split(";", 1)[1]
        encoding = after_media_type.split(",", 1)[1]
    try:
        raw_bytes = base64.b64decode(encoding)
        return Image.open(BytesIO(raw_bytes))
    except Exception as err:
        print(err)
        raise HTTPException(status_code=500, detail="Invalid encoded image")
|
||||
|
||||
|
||||
def encode_pil_to_base64(images):
    """Encode each image as base64, using the configured output format.

    The format comes from ``args.output_img_format`` (compared case
    insensitively): "png" saves as PNG, "jpg"/"jpeg" saves as JPEG. Any
    other value raises HTTPException with status 500.

    Returns a list with one ``base64.b64encode`` result (bytes) per image.
    """
    encoded_imgs = []
    for image in images:
        with BytesIO() as buffer:
            configured_format = args.output_img_format.lower()
            if configured_format == "png":
                pil_format = "PNG"
            elif configured_format in ("jpg", "jpeg"):
                pil_format = "JPEG"
            else:
                raise HTTPException(
                    status_code=500, detail="Invalid image format"
                )
            image.save(buffer, format=pil_format)
            encoded_imgs.append(base64.b64encode(buffer.getvalue()))
    return encoded_imgs
|
||||
|
||||
|
||||
# Img2Img Rest API.
|
||||
def img2img_api(
    InputData: dict,
):
    """Run img2img for a request supplied as a plain dict.

    Reads the required keys "prompt", "negative_prompt", "seed",
    "init_images", "height", "width", "steps", "denoising_strength" and
    "cfg_scale"; "hf_model_id" and "use_stencil" are optional. Only the
    first entry of "init_images" is used.

    Returns a dict with the base64 encoded ``images``, an empty
    ``parameters`` dict and the ``info`` text, all taken from the first
    result yielded by ``img2img_inf``.
    """
    prompt = InputData["prompt"]
    negative_prompt = InputData["negative_prompt"]
    seed = InputData["seed"]
    print(
        f"Prompt: {prompt}, "
        f"Negative Prompt: {negative_prompt}, "
        f"Seed: {seed}."
    )
    init_image = decode_base64_to_image(InputData["init_images"][0])
    results = img2img_inf(
        prompt,
        negative_prompt,
        init_image,
        InputData["height"],
        InputData["width"],
        InputData["steps"],
        InputData["denoising_strength"],
        InputData["cfg_scale"],
        seed,
        batch_count=1,
        batch_size=1,
        scheduler="EulerDiscrete",
        custom_model="None",
        hf_model_id=InputData.get(
            "hf_model_id", "stabilityai/stable-diffusion-2-1-base"
        ),
        custom_vae="None",
        precision="fp16",
        device=available_devices[0],
        max_length=64,
        use_stencil=InputData.get("use_stencil", "None"),
        save_metadata_to_json=False,
        save_metadata_to_png=False,
        lora_weights="None",
        lora_hf_id="",
        ondemand=False,
        repeatable_seeds=False,
        resample_type="Lanczos",
    )

    # img2img_inf is a generator; take its first result so it can be indexed
    first_result = next(results)

    return {
        "images": encode_pil_to_base64(first_result[0]),
        "parameters": {},
        "info": first_result[1],
    }
|
||||
|
||||
|
||||
with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
@@ -382,31 +292,19 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
i2i_model_info = (str(get_custom_model_path())).replace(
|
||||
"\\", "\n\\"
|
||||
i2i_model_info = (
|
||||
f"Custom Model Path: {str(get_custom_model_path())}"
|
||||
)
|
||||
i2i_model_info = f"Custom Model Path: {i2i_model_info}"
|
||||
img2img_custom_model = gr.Dropdown(
|
||||
label=f"Models",
|
||||
info=i2i_model_info,
|
||||
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
)
|
||||
img2img_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown "
|
||||
"on the left and enter model ID here "
|
||||
"e.g: SG161222/Realistic_Vision_V1.3, "
|
||||
"https://civitai.com/api/download/models/15236",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL",
|
||||
lines=3,
|
||||
choices=get_custom_model_files() + predefined_models,
|
||||
allow_custom_value=True,
|
||||
scale=2,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
i2i_vae_info = (str(get_custom_model_path("vae"))).replace(
|
||||
@@ -421,6 +319,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
scale=1,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
@@ -451,7 +351,13 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
elem_id="stencil_model",
|
||||
label="Stencil model",
|
||||
value="None",
|
||||
choices=["None", "canny", "openpose", "scribble"],
|
||||
choices=[
|
||||
"None",
|
||||
"canny",
|
||||
"openpose",
|
||||
"scribble",
|
||||
"zoedepth",
|
||||
],
|
||||
)
|
||||
|
||||
def show_canvas(choice):
|
||||
@@ -512,6 +418,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
).replace("\\", "\n\\")
|
||||
i2i_lora_info = f"LoRA Path: {i2i_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
allow_custom_value=True,
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=i2i_lora_info,
|
||||
elem_id="lora_weights",
|
||||
@@ -535,6 +442,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
label="Scheduler",
|
||||
value="EulerDiscrete",
|
||||
choices=scheduler_list_cpu_only,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -590,6 +498,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
"Cubic",
|
||||
],
|
||||
label="Resample Type",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
ondemand = gr.Checkbox(
|
||||
value=args.ondemand,
|
||||
@@ -648,17 +557,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -670,13 +570,26 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
object_fit="contain",
|
||||
)
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at "
|
||||
value=f"{i2i_model_info}\n"
|
||||
f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
lines=2,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
img2img_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
blank_thing_for_row = None
|
||||
with gr.Row():
|
||||
img2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
img2img_sendto_outpaint = gr.Button(
|
||||
@@ -702,7 +615,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
batch_size,
|
||||
scheduler,
|
||||
img2img_custom_model,
|
||||
img2img_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
|
||||
@@ -4,9 +4,6 @@ import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from fastapi.exceptions import HTTPException
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -53,8 +50,7 @@ def inpaint_inf(
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
@@ -89,21 +85,17 @@ def inpaint_inf(
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
args.custom_vae = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, "
|
||||
"both must not be empty.",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
|
||||
# .safetensor or .chkpt on the custom model path
|
||||
if model_id in get_custom_model_files(custom_checkpoint_type="inpainting"):
|
||||
args.ckpt_loc = get_custom_model_pathfile(model_id)
|
||||
# civitai download
|
||||
elif "civitai" in model_id:
|
||||
args.ckpt_loc = model_id
|
||||
# either predefined or huggingface
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
args.hf_model_id = model_id
|
||||
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
@@ -228,86 +220,6 @@ def inpaint_inf(
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
|
||||
try:
|
||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
return image
|
||||
except Exception as err:
|
||||
print(err)
|
||||
raise HTTPException(status_code=500, detail="Invalid encoded image")
|
||||
|
||||
|
||||
def encode_pil_to_base64(images):
|
||||
encoded_imgs = []
|
||||
for image in images:
|
||||
with BytesIO() as output_bytes:
|
||||
if args.output_img_format.lower() == "png":
|
||||
image.save(output_bytes, format="PNG")
|
||||
|
||||
elif args.output_img_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Invalid image format"
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
encoded_imgs.append(base64.b64encode(bytes_data))
|
||||
return encoded_imgs
|
||||
|
||||
|
||||
# Inpaint Rest API.
|
||||
def inpaint_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, '
|
||||
f'Negative Prompt: {InputData["negative_prompt"]}, '
|
||||
f'Seed: {InputData["seed"]}.'
|
||||
)
|
||||
init_image = decode_base64_to_image(InputData["image"])
|
||||
mask = decode_base64_to_image(InputData["mask"])
|
||||
res = inpaint_inf(
|
||||
InputData["prompt"],
|
||||
InputData["negative_prompt"],
|
||||
{"image": init_image, "mask": mask},
|
||||
InputData["height"],
|
||||
InputData["width"],
|
||||
InputData["is_full_res"],
|
||||
InputData["full_res_padding"],
|
||||
InputData["steps"],
|
||||
InputData["cfg_scale"],
|
||||
InputData["seed"],
|
||||
batch_count=1,
|
||||
batch_size=1,
|
||||
scheduler="EulerDiscrete",
|
||||
custom_model="None",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
max_length=64,
|
||||
save_metadata_to_json=False,
|
||||
save_metadata_to_png=False,
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
# Converts generator type to subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
"info": res[1],
|
||||
}
|
||||
|
||||
|
||||
with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
@@ -327,34 +239,21 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
inpaint_model_info = (
|
||||
str(get_custom_model_path())
|
||||
).replace("\\", "\n\\")
|
||||
inpaint_model_info = (
|
||||
f"Custom Model Path: {inpaint_model_info}"
|
||||
f"Custom Model Path: {str(get_custom_model_path())}"
|
||||
)
|
||||
inpaint_custom_model = gr.Dropdown(
|
||||
label=f"Models",
|
||||
info=inpaint_model_info,
|
||||
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files(
|
||||
choices=get_custom_model_files(
|
||||
custom_checkpoint_type="inpainting"
|
||||
)
|
||||
+ predefined_paint_models,
|
||||
)
|
||||
inpaint_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown "
|
||||
"on the left and enter model ID here "
|
||||
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
|
||||
"https://civitai.com/api/download/models/3433",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL",
|
||||
lines=3,
|
||||
allow_custom_value=True,
|
||||
scale=2,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
inpaint_vae_info = (
|
||||
@@ -369,6 +268,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
scale=1,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
@@ -406,6 +307,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
@@ -424,6 +326,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
label="Scheduler",
|
||||
value="EulerDiscrete",
|
||||
choices=scheduler_list_cpu_only,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -527,17 +430,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -549,14 +443,26 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
object_fit="contain",
|
||||
)
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at "
|
||||
value=f"{inpaint_model_info}\n"
|
||||
"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
lines=2,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
inpaint_status = gr.Textbox(visible=False)
|
||||
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
blank_thing_for_row = None
|
||||
with gr.Row():
|
||||
inpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
inpaint_sendto_outpaint = gr.Button(
|
||||
@@ -583,7 +489,6 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
batch_size,
|
||||
scheduler,
|
||||
inpaint_custom_model,
|
||||
inpaint_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
|
||||
BIN
apps/stable_diffusion/web/ui/logos/nod-icon.png
Normal file
BIN
apps/stable_diffusion/web/ui/logos/nod-icon.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 16 KiB |
@@ -50,6 +50,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
@@ -73,6 +74,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
@@ -105,6 +107,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
label="Scheduler",
|
||||
value=args.scheduler,
|
||||
choices=scheduler_list,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
@@ -177,6 +180,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
|
||||
@@ -143,6 +143,7 @@ with gr.Blocks() as minigpt4_web:
|
||||
# else "Only CUDA Supported for now",
|
||||
choices=["cuda"],
|
||||
interactive=False,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
with gr.Column():
|
||||
|
||||
@@ -98,6 +98,7 @@ with gr.Blocks() as model_web:
|
||||
choices=None,
|
||||
value=None,
|
||||
visible=False,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
# TODO: select and SendTo
|
||||
civit_models = gr.Gallery(
|
||||
|
||||
@@ -53,8 +53,7 @@ def outpaint_inf(
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
@@ -88,21 +87,17 @@ def outpaint_inf(
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
args.custom_vae = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, "
|
||||
"both must not be empty.",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
|
||||
# .safetensor or .chkpt on the custom model path
|
||||
if model_id in get_custom_model_files(custom_checkpoint_type="inpainting"):
|
||||
args.ckpt_loc = get_custom_model_pathfile(model_id)
|
||||
# civitai download
|
||||
elif "civitai" in model_id:
|
||||
args.ckpt_loc = model_id
|
||||
# either predefined or huggingface
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
args.hf_model_id = model_id
|
||||
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
@@ -233,88 +228,6 @@ def outpaint_inf(
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
|
||||
try:
|
||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
return image
|
||||
except Exception as err:
|
||||
print(err)
|
||||
raise HTTPException(status_code=500, detail="Invalid encoded image")
|
||||
|
||||
|
||||
def encode_pil_to_base64(images):
|
||||
encoded_imgs = []
|
||||
for image in images:
|
||||
with BytesIO() as output_bytes:
|
||||
if args.output_img_format.lower() == "png":
|
||||
image.save(output_bytes, format="PNG")
|
||||
|
||||
elif args.output_img_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Invalid image format"
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
encoded_imgs.append(base64.b64encode(bytes_data))
|
||||
return encoded_imgs
|
||||
|
||||
|
||||
# Inpaint Rest API.
|
||||
def outpaint_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, '
|
||||
f'Negative Prompt: {InputData["negative_prompt"]}, '
|
||||
f'Seed: {InputData["seed"]}'
|
||||
)
|
||||
init_image = decode_base64_to_image(InputData["init_images"][0])
|
||||
res = outpaint_inf(
|
||||
InputData["prompt"],
|
||||
InputData["negative_prompt"],
|
||||
init_image,
|
||||
InputData["pixels"],
|
||||
InputData["mask_blur"],
|
||||
InputData["directions"],
|
||||
InputData["noise_q"],
|
||||
InputData["color_variation"],
|
||||
InputData["height"],
|
||||
InputData["width"],
|
||||
InputData["steps"],
|
||||
InputData["cfg_scale"],
|
||||
InputData["seed"],
|
||||
batch_count=1,
|
||||
batch_size=1,
|
||||
scheduler="EulerDiscrete",
|
||||
custom_model="None",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
max_length=64,
|
||||
save_metadata_to_json=False,
|
||||
save_metadata_to_png=False,
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
# Convert Generator to Subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
"info": res[1],
|
||||
}
|
||||
|
||||
|
||||
with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
@@ -332,36 +245,22 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
outpaint_model_info = (
|
||||
str(get_custom_model_path())
|
||||
).replace("\\", "\n\\")
|
||||
outpaint_model_info = (
|
||||
f"Custom Model Path: {outpaint_model_info}"
|
||||
f"Custom Model Path: {str(get_custom_model_path())}"
|
||||
)
|
||||
outpaint_custom_model = gr.Dropdown(
|
||||
label=f"Models",
|
||||
info=outpaint_model_info,
|
||||
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files(
|
||||
choices=get_custom_model_files(
|
||||
custom_checkpoint_type="inpainting"
|
||||
)
|
||||
+ predefined_paint_models,
|
||||
)
|
||||
outpaint_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown "
|
||||
"on the left and enter model ID here "
|
||||
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
|
||||
"https://civitai.com/api/download/models/3433",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL",
|
||||
lines=3,
|
||||
allow_custom_value=True,
|
||||
scale=2,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
outpaint_vae_info = (
|
||||
@@ -376,8 +275,9 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
scale=1,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
@@ -411,6 +311,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
@@ -429,6 +330,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
label="Scheduler",
|
||||
value="EulerDiscrete",
|
||||
choices=scheduler_list_cpu_only,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -555,17 +457,8 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -577,13 +470,26 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
object_fit="contain",
|
||||
)
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at "
|
||||
value=f"{outpaint_model_info}\n"
|
||||
f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
lines=2,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
outpaint_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
blank_thing_for_row = None
|
||||
with gr.Row():
|
||||
outpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
outpaint_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
@@ -611,7 +517,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
batch_size,
|
||||
scheduler,
|
||||
outpaint_custom_model,
|
||||
outpaint_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
|
||||
@@ -109,6 +109,7 @@ with gr.Blocks() as outputgallery_web:
|
||||
value="",
|
||||
interactive=True,
|
||||
elem_classes="dropdown_no_container",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Column(
|
||||
scale=1,
|
||||
|
||||
@@ -32,42 +32,47 @@ model_map = {
|
||||
# NOTE: Each `model_name` should have its own start message
|
||||
start_message = {
|
||||
"llama2_7b": (
|
||||
"System: You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
|
||||
"content. Please ensure that your responses are socially unbiased and positive "
|
||||
"in nature. If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. If you don't know the "
|
||||
"answer to a question, please don't share false information."
|
||||
"You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
|
||||
"illegal content. Please ensure that your responses are socially "
|
||||
"unbiased and positive in nature. If a question does not make any "
|
||||
"sense, or is not factually coherent, explain why instead of "
|
||||
"answering something not correct. If you don't know the answer "
|
||||
"to a question, please don't share false information."
|
||||
),
|
||||
"llama2_13b": (
|
||||
"System: You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
|
||||
"content. Please ensure that your responses are socially unbiased and positive "
|
||||
"in nature. If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. If you don't know the "
|
||||
"answer to a question, please don't share false information."
|
||||
"You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
|
||||
"illegal content. Please ensure that your responses are socially "
|
||||
"unbiased and positive in nature. If a question does not make any "
|
||||
"sense, or is not factually coherent, explain why instead of "
|
||||
"answering something not correct. If you don't know the answer "
|
||||
"to a question, please don't share false information."
|
||||
),
|
||||
"llama2_70b": (
|
||||
"System: You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
|
||||
"content. Please ensure that your responses are socially unbiased and positive "
|
||||
"in nature. If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. If you don't know the "
|
||||
"answer to a question, please don't share false information."
|
||||
"You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
|
||||
"illegal content. Please ensure that your responses are socially "
|
||||
"unbiased and positive in nature. If a question does not make any "
|
||||
"sense, or is not factually coherent, explain why instead of "
|
||||
"answering something not correct. If you don't know the answer "
|
||||
"to a question, please don't share false information."
|
||||
),
|
||||
"vicuna": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
"A chat between a curious user and an artificial intelligence "
|
||||
"assistant. The assistant gives helpful, detailed, and "
|
||||
"polite answers to the user's questions.\n"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def create_prompt(model_name, history):
|
||||
system_message = start_message[model_name]
|
||||
def create_prompt(model_name, history, prompt_prefix):
|
||||
system_message = ""
|
||||
if prompt_prefix:
|
||||
system_message = start_message[model_name]
|
||||
|
||||
if "llama2" in model_name:
|
||||
B_INST, E_INST = "[INST]", "[/INST]"
|
||||
@@ -75,7 +80,10 @@ def create_prompt(model_name, history):
|
||||
conversation = "".join(
|
||||
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
|
||||
)
|
||||
msg = f"{B_INST} {B_SYS} {system_message} {E_SYS} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
|
||||
if prompt_prefix:
|
||||
msg = f"{B_INST} {B_SYS}{system_message}{E_SYS}{history[0][0]} {E_INST} {history[0][1]} {conversation}"
|
||||
else:
|
||||
msg = f"{B_INST} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
|
||||
elif model_name in ["vicuna"]:
|
||||
conversation = "".join(
|
||||
[
|
||||
@@ -129,7 +137,7 @@ model_vmfb_key = ""
|
||||
|
||||
# TODO: Make chat reusable for UI and API
|
||||
def chat(
|
||||
curr_system_message,
|
||||
prompt_prefix,
|
||||
history,
|
||||
model,
|
||||
device,
|
||||
@@ -145,6 +153,7 @@ def chat(
|
||||
|
||||
device_id = None
|
||||
model_name, model_path = list(map(str.strip, model.split("=>")))
|
||||
device = device if "=>" not in device else device.split("=>")[1].strip()
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
elif "sync" in device:
|
||||
@@ -156,6 +165,8 @@ def chat(
|
||||
device = "vulkan"
|
||||
elif "rocm" in device:
|
||||
device = "rocm"
|
||||
elif "metal" in device:
|
||||
device = "metal"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
|
||||
@@ -208,8 +219,14 @@ def chat(
|
||||
assert (
|
||||
device_id
|
||||
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
|
||||
print(f"Will use vulkan target triple : {vulkan_target_triple}")
|
||||
|
||||
print(f"Will use target triple : {vulkan_target_triple}")
|
||||
elif "rocm" in device:
|
||||
# add iree rocm flags
|
||||
_extra_args.append(
|
||||
f"--iree-rocm-target-chip={args.iree_rocm_target_chip}"
|
||||
)
|
||||
print(f"extra args = {_extra_args}")
|
||||
|
||||
if model_name == "vicuna4":
|
||||
vicuna_model = ShardedVicuna(
|
||||
@@ -240,7 +257,7 @@ def chat(
|
||||
if vicuna_model is None:
|
||||
sys.exit("Unable to instantiate the model object, exiting.")
|
||||
|
||||
prompt = create_prompt(model_name, history)
|
||||
prompt = create_prompt(model_name, history, prompt_prefix)
|
||||
|
||||
partial_text = ""
|
||||
token_count = 0
|
||||
@@ -317,6 +334,8 @@ def llm_chat_api(InputData: dict):
|
||||
elif "vulkan" in device:
|
||||
device_id = int(device.split("://")[1])
|
||||
device = "vulkan"
|
||||
elif "metal" in device:
|
||||
device = "metal"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
|
||||
@@ -393,6 +412,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
label="Select Model",
|
||||
value=model_choices[0],
|
||||
choices=model_choices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
supported_devices = available_devices
|
||||
enabled = len(supported_devices) > 0
|
||||
@@ -406,6 +426,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
else "Only CUDA Supported for now",
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
allow_custom_value=True,
|
||||
# multiselect=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
@@ -419,11 +440,17 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
visible=False,
|
||||
)
|
||||
tokens_time = gr.Textbox(label="Tokens generated per second")
|
||||
download_vmfb = gr.Checkbox(
|
||||
label="Download vmfb from Shark tank if available",
|
||||
value=True,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column():
|
||||
download_vmfb = gr.Checkbox(
|
||||
label="Download vmfb from Shark tank if available",
|
||||
value=True,
|
||||
interactive=True,
|
||||
)
|
||||
prompt_prefix = gr.Checkbox(
|
||||
label="Add System Prompt",
|
||||
value=False,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Row(visible=False):
|
||||
with gr.Group():
|
||||
@@ -450,9 +477,6 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
submit = gr.Button("Submit", interactive=enabled)
|
||||
stop = gr.Button("Stop", interactive=enabled)
|
||||
clear = gr.Button("Clear", interactive=enabled)
|
||||
system_msg = gr.Textbox(
|
||||
start_message, label="System Message", interactive=False, visible=False
|
||||
)
|
||||
|
||||
submit_event = msg.submit(
|
||||
fn=user,
|
||||
@@ -463,7 +487,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[
|
||||
system_msg,
|
||||
prompt_prefix,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
@@ -484,7 +508,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[
|
||||
system_msg,
|
||||
prompt_prefix,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
|
||||
@@ -5,9 +5,6 @@ import sys
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from math import ceil
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from fastapi.exceptions import HTTPException
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -52,8 +49,7 @@ def txt2img_inf(
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
@@ -91,21 +87,17 @@ def txt2img_inf(
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
args.custom_vae = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, "
|
||||
"both must not be empty",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
|
||||
# .safetensor or .chkpt on the custom model path
|
||||
if model_id in get_custom_model_files():
|
||||
args.ckpt_loc = get_custom_model_pathfile(model_id)
|
||||
# civitai download
|
||||
elif "civitai" in model_id:
|
||||
args.ckpt_loc = model_id
|
||||
# either predefined or huggingface
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
args.hf_model_id = model_id
|
||||
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
@@ -145,6 +137,11 @@ def txt2img_inf(
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
args.width = width
|
||||
args.use_hiresfix = use_hiresfix
|
||||
args.hiresfix_height = hiresfix_height
|
||||
args.hiresfix_width = hiresfix_width
|
||||
args.hiresfix_strength = hiresfix_strength
|
||||
args.resample_type = resample_type
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
|
||||
args.iree_metal_target_platform = init_iree_metal_target_platform
|
||||
@@ -301,75 +298,6 @@ def txt2img_inf(
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def encode_pil_to_base64(images):
|
||||
encoded_imgs = []
|
||||
for image in images:
|
||||
with BytesIO() as output_bytes:
|
||||
if args.output_img_format.lower() == "png":
|
||||
image.save(output_bytes, format="PNG")
|
||||
|
||||
elif args.output_img_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Invalid image format"
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
encoded_imgs.append(base64.b64encode(bytes_data))
|
||||
return encoded_imgs
|
||||
|
||||
|
||||
# Text2Img Rest API.
|
||||
def txt2img_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, '
|
||||
f'Negative Prompt: {InputData["negative_prompt"]}, '
|
||||
f'Seed: {InputData["seed"]}.'
|
||||
)
|
||||
res = txt2img_inf(
|
||||
InputData["prompt"],
|
||||
InputData["negative_prompt"],
|
||||
InputData["height"],
|
||||
InputData["width"],
|
||||
InputData["steps"],
|
||||
InputData["cfg_scale"],
|
||||
InputData["seed"],
|
||||
batch_count=1,
|
||||
batch_size=1,
|
||||
scheduler="EulerDiscrete",
|
||||
custom_model="None",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
max_length=64,
|
||||
save_metadata_to_json=False,
|
||||
save_metadata_to_png=False,
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
use_hiresfix=False,
|
||||
hiresfix_height=512,
|
||||
hiresfix_width=512,
|
||||
hiresfix_strength=0.6,
|
||||
resample_type="Nearest Neighbor",
|
||||
)
|
||||
|
||||
# Convert Generator to Subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
"info": res[1],
|
||||
}
|
||||
|
||||
|
||||
with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
@@ -389,32 +317,18 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=10):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
t2i_model_info = (
|
||||
str(get_custom_model_path())
|
||||
).replace("\\", "\n\\")
|
||||
t2i_model_info = (
|
||||
f"Custom Model Path: {t2i_model_info}"
|
||||
)
|
||||
t2i_model_info = f"Custom Model Path: {str(get_custom_model_path())}"
|
||||
txt2img_custom_model = gr.Dropdown(
|
||||
label=f"Models",
|
||||
info=t2i_model_info,
|
||||
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
choices=get_custom_model_files()
|
||||
+ predefined_models,
|
||||
)
|
||||
txt2img_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the dropdown "
|
||||
"on the left and enter model ID here.",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL.",
|
||||
lines=3,
|
||||
allow_custom_value=True,
|
||||
scale=2,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
t2i_vae_info = (
|
||||
@@ -430,6 +344,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
else "None",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
scale=1,
|
||||
)
|
||||
with gr.Column(scale=1, min_width=170):
|
||||
txt2img_png_info_img = gr.Image(
|
||||
@@ -466,6 +382,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
@@ -484,6 +401,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
label="Scheduler",
|
||||
value=args.scheduler,
|
||||
choices=scheduler_list,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Column():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -568,6 +486,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
"Cubic",
|
||||
],
|
||||
label="Resample Type",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
hiresfix_height = gr.Slider(
|
||||
384,
|
||||
@@ -624,6 +543,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Accordion(label="Prompt Examples!", open=False):
|
||||
ex = gr.Examples(
|
||||
@@ -643,7 +563,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
object_fit="contain",
|
||||
)
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at "
|
||||
value=f"{t2i_model_info}\n"
|
||||
f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
@@ -686,7 +607,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
batch_size,
|
||||
scheduler,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
@@ -736,7 +656,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
width,
|
||||
height,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
custom_vae,
|
||||
@@ -752,7 +671,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
width,
|
||||
height,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
custom_vae,
|
||||
|
||||
@@ -3,9 +3,6 @@ import torch
|
||||
import time
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from fastapi.exceptions import HTTPException
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -46,8 +43,7 @@ def upscaler_inf(
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
@@ -85,21 +81,17 @@ def upscaler_inf(
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
args.custom_vae = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, "
|
||||
"both must not be empty.",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
|
||||
# .safetensor or .chkpt on the custom model path
|
||||
if model_id in get_custom_model_files(custom_checkpoint_type="upscaler"):
|
||||
args.ckpt_loc = get_custom_model_pathfile(model_id)
|
||||
# civitai download
|
||||
elif "civitai" in model_id:
|
||||
args.ckpt_loc = model_id
|
||||
# either predefined or huggingface
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
args.hf_model_id = model_id
|
||||
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
@@ -252,83 +244,6 @@ def upscaler_inf(
|
||||
yield generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
    """Decode base64 text (optionally a ``data:image/...`` URI) into an image.

    Args:
        encoding: base64 text, or a data URI whose payload follows the first
            ``;`` and the first ``,`` after it.

    Returns:
        The object returned by ``Image.open`` on the decoded bytes.

    Raises:
        HTTPException: status 500 with detail "Invalid encoded image" when the
            input cannot be parsed, decoded or opened. The original error is
            printed and attached as the cause.
    """
    try:
        if encoding.startswith("data:image/"):
            # Strip the "data:image/<type>;<encoding>," header. This is done
            # inside the try so a malformed header is reported the same way
            # as a bad payload instead of escaping as a bare IndexError.
            encoding = encoding.split(";", 1)[1].split(",", 1)[1]
        return Image.open(BytesIO(base64.b64decode(encoding)))
    except Exception as err:
        print(err)
        raise HTTPException(
            status_code=500, detail="Invalid encoded image"
        ) from err
|
||||
|
||||
|
||||
def encode_pil_to_base64(images):
    """Serialise each image and return the base64-encoded bytes as a list.

    The output format comes from ``args.output_img_format`` (case-insensitive):
    ``png`` saves as PNG, ``jpg``/``jpeg`` save as JPEG.

    Raises:
        HTTPException: status 500 with detail "Invalid image format" when the
            configured format is anything else and at least one image is given.
    """
    save_formats = {"png": "PNG", "jpg": "JPEG", "jpeg": "JPEG"}
    encoded_imgs = []
    for image in images:
        save_format = save_formats.get(args.output_img_format.lower())
        if save_format is None:
            raise HTTPException(
                status_code=500, detail="Invalid image format"
            )
        with BytesIO() as output_bytes:
            image.save(output_bytes, format=save_format)
            encoded_imgs.append(base64.b64encode(output_bytes.getvalue()))
    return encoded_imgs
|
||||
|
||||
|
||||
# Upscaler Rest API.
|
||||
def upscaler_api(
    InputData: dict,
):
    """Handle one Upscaler REST request.

    Reads ``prompt``, ``negative_prompt``, ``init_images`` (first entry only),
    ``height``, ``width``, ``steps``, ``noise_level``, ``cfg_scale`` and
    ``seed`` from ``InputData`` plus the optional ``hf_model_id``, runs
    ``upscaler_inf`` with fixed values for every other option, and returns the
    first result it yields.

    Returns:
        dict with ``"images"`` (base64-encoded images), an empty
        ``"parameters"`` dict and ``"info"`` (second element of the result).
    """
    prompt = InputData["prompt"]
    negative_prompt = InputData["negative_prompt"]
    seed = InputData["seed"]
    print(
        f"Prompt: {prompt}, "
        f"Negative Prompt: {negative_prompt}, "
        f"Seed: {seed}"
    )

    # Only the first submitted image is upscaled.
    init_image = decode_base64_to_image(InputData["init_images"][0])

    # Fall back to the stock SD 2.1 base model when the request names none.
    model_id = InputData.get(
        "hf_model_id", "stabilityai/stable-diffusion-2-1-base"
    )

    inference = upscaler_inf(
        prompt,
        negative_prompt,
        init_image,
        InputData["height"],
        InputData["width"],
        InputData["steps"],
        InputData["noise_level"],
        InputData["cfg_scale"],
        seed,
        batch_count=1,
        batch_size=1,
        scheduler="EulerDiscrete",
        custom_model="None",
        hf_model_id=model_id,
        custom_vae="None",
        precision="fp16",
        device=available_devices[0],
        max_length=64,
        save_metadata_to_json=False,
        save_metadata_to_png=False,
        lora_weights="None",
        lora_hf_id="",
        ondemand=False,
        repeatable_seeds=False,
    )

    # upscaler_inf is consumed as an iterator; only its first item is used.
    first_result = next(inference)

    return {
        "images": encode_pil_to_base64(first_result[0]),
        "parameters": {},
        "info": first_result[1],
    }
|
||||
|
||||
|
||||
with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
@@ -346,36 +261,22 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
upscaler_model_info = (
|
||||
str(get_custom_model_path())
|
||||
).replace("\\", "\n\\")
|
||||
upscaler_model_info = (
|
||||
f"Custom Model Path: {upscaler_model_info}"
|
||||
f"Custom Model Path: {str(get_custom_model_path())}"
|
||||
)
|
||||
upscaler_custom_model = gr.Dropdown(
|
||||
label=f"Models",
|
||||
info=upscaler_model_info,
|
||||
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "stabilityai/stable-diffusion-x4-upscaler",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files(
|
||||
choices=get_custom_model_files(
|
||||
custom_checkpoint_type="upscaler"
|
||||
)
|
||||
+ predefined_upscaler_models,
|
||||
)
|
||||
upscaler_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown "
|
||||
"on the left and enter model ID here "
|
||||
"e.g: SG161222/Realistic_Vision_V1.3, "
|
||||
"https://civitai.com/api/download/models/15236",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL",
|
||||
lines=3,
|
||||
allow_custom_value=True,
|
||||
scale=2,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
upscaler_vae_info = (
|
||||
@@ -390,6 +291,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
scale=1,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
@@ -425,6 +328,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
@@ -443,6 +347,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
label="Scheduler",
|
||||
value="DDIM",
|
||||
choices=scheduler_list_cpu_only,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -547,17 +452,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -569,14 +465,26 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
object_fit="contain",
|
||||
)
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at "
|
||||
value=f"{upscaler_model_info}\n"
|
||||
f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
lines=2,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
upscaler_status = gr.Textbox(visible=False)
|
||||
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
blank_thing_for_row = None
|
||||
with gr.Row():
|
||||
upscaler_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
upscaler_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
@@ -600,7 +508,6 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
batch_size,
|
||||
scheduler,
|
||||
upscaler_custom_model,
|
||||
upscaler_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
|
||||
@@ -170,4 +170,5 @@ def cancel_sd():
|
||||
|
||||
|
||||
nodlogo_loc = resource_path("logos/nod-logo.png")
|
||||
nodicon_loc = resource_path("logos/nod-icon.png")
|
||||
available_devices = get_available_devices()
|
||||
|
||||
105
apps/stable_diffusion/web/utils/app.py
Normal file
105
apps/stable_diffusion/web/utils/app.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import os
|
||||
import sys
|
||||
import webview
|
||||
import webview.util
|
||||
import socket
|
||||
|
||||
from contextlib import closing
|
||||
from multiprocessing import Process
|
||||
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
|
||||
def webview2_installed():
    """Report whether a Microsoft Edge WebView2 runtime appears to be installed.

    Returns:
        bool: always ``False`` when ``sys.platform`` is not ``"win32"``.
        On Windows, ``True`` when any probed registry key has a ``pv`` value
        that is not empty and not ``"0.0.0.0"``; ``False`` otherwise.
    """
    if sys.platform != "win32":
        return False

    # On windows we want to ensure we have MS webview2 available so we don't
    # fall back to MSHTML (aka ye olde Internet Explorer) which is deprecated
    # by pywebview, and apparently causes SHARK not to load in properly.

    # Checking these registry entries is how Microsoft says to detect a
    # webview2 installation:
    # https://learn.microsoft.com/en-us/microsoft-edge/webview2/concepts/distribution
    import winreg

    wow_path = r"SOFTWARE\WOW6432Node\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}"
    plain_path = r"SOFTWARE\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}"

    candidates = (
        # all-user install, then current-user install (the original probes)
        (winreg.HKEY_LOCAL_MACHINE, wow_path),
        (winreg.HKEY_CURRENT_USER, wow_path),
        # NOTE(review): the distribution guide appears to list the per-user
        # key (and the 32-bit Windows machine key) without the WOW6432Node
        # segment, so those are probed too - confirm against the guide.
        (winreg.HKEY_LOCAL_MACHINE, plain_path),
        (winreg.HKEY_CURRENT_USER, plain_path),
    )

    for hive, path in candidates:
        # The only way to find out whether a registry entry exists is to try
        # to open it; a missing key or value raises OSError.
        try:
            with winreg.OpenKey(
                hive,
                path,
                0,
                winreg.KEY_QUERY_VALUE | winreg.KEY_WOW64_64KEY,
            ) as registry_key:
                value, _value_type = winreg.QueryValueEx(registry_key, "pv")
        except OSError:
            continue

        # An empty or all-zero version does not count as installed, so keep
        # checking the remaining locations instead of giving up here.
        if value is not None and value != "" and value != "0.0.0.0":
            return True

    return False
|
||||
|
||||
|
||||
def window(address):
    """Open ``address`` in a pywebview window sized relative to the screen.

    Args:
        address: URL passed to ``webview.create_window``.

    The call hands control to ``webview.start``; when that call returns is
    up to pywebview.
    """
    from tkinter import Tk

    # Tk is used only to measure the display. Destroy the root as soon as the
    # numbers are read so no stray, empty Tk window is left behind.
    root = Tk()
    try:
        # get screen width and height of display and make it more reasonably
        # sized as we aren't making it full-screen or maximized
        width = int(root.winfo_screenwidth() * 0.81)
        height = int(root.winfo_screenheight() * 0.91)
    finally:
        root.destroy()

    webview.create_window(
        "SHARK AI Studio",
        url=address,
        width=width,
        height=height,
        text_select=True,
    )
    webview.start(private_mode=False, storage_path=os.getcwd())
|
||||
|
||||
|
||||
def usable_port():
    """Return a TCP port number that could be bound on ``0.0.0.0``.

    Make sure we can actually use the port given in ``args.server_port``. If
    not, ask the OS for a port and return that as our port to use.

    Returns:
        int: ``args.server_port`` when a test bind succeeds, otherwise the
        port number the OS assigned to a socket bound to port 0.

    Note:
        Each probe socket is closed before returning, so the port is only
        known to have been free at the moment of the check.
    """
    port = args.server_port

    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as probe:
        try:
            probe.bind(("0.0.0.0", port))
        except OSError:
            requested_port_free = False
        else:
            requested_port_free = True

    if requested_port_free:
        return port

    # The requested port is taken: bind to port 0 so the OS picks a free one.
    # The first probe socket is already closed here, and no socket options
    # are set because changing them after bind() would not affect the bind.
    with closing(
        socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    ) as fallback:
        fallback.bind(("0.0.0.0", 0))
        return fallback.getsockname()[1]
|
||||
|
||||
|
||||
def launch(port):
    """Try to start the UI as a desktop app; return whether it was started.

    App mode is attempted only when ``args.ui`` is ``"app"`` and, on Windows,
    ``webview2_installed()`` is true. The window is opened by running
    ``window`` in a separate ``Process`` pointed at
    ``http://localhost:<port>``.

    Returns:
        bool: ``True`` once the process has been started, ``False`` when app
        mode was not requested, is unavailable, or a ``WebViewException`` was
        raised while starting.
    """
    app_mode_available = args.ui == "app" and (
        sys.platform != "win32" or webview2_installed()
    )
    if not app_mode_available:
        return False

    # NOTE(review): window() runs in the child process, so exceptions raised
    # inside it would presumably not reach this handler - confirm.
    try:
        ui_process = Process(target=window, args=[f"http://localhost:{port}"])
        ui_process.start()
    except webview.util.WebViewException:
        return False
    return True
|
||||
@@ -149,7 +149,6 @@ def import_png_metadata(
|
||||
width,
|
||||
height,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
custom_lora,
|
||||
hf_lora_id,
|
||||
custom_vae,
|
||||
@@ -175,10 +174,8 @@ def import_png_metadata(
|
||||
|
||||
if "Model" in metadata and png_custom_model:
|
||||
custom_model = png_custom_model
|
||||
hf_model_id = ""
|
||||
if "Model" in metadata and png_hf_model_id:
|
||||
custom_model = "None"
|
||||
hf_model_id = png_hf_model_id
|
||||
elif "Model" in metadata and png_hf_model_id:
|
||||
custom_model = png_hf_model_id
|
||||
|
||||
if "LoRA" in metadata and lora_custom_model:
|
||||
custom_lora = lora_custom_model
|
||||
@@ -217,7 +214,6 @@ def import_png_metadata(
|
||||
width,
|
||||
height,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
custom_lora,
|
||||
hf_lora_id,
|
||||
custom_vae,
|
||||
|
||||
@@ -129,12 +129,12 @@ pytest_benchmark_param = pytest.mark.parametrize(
|
||||
pytest.param(True, "cpu", marks=pytest.mark.skip),
|
||||
pytest.param(
|
||||
False,
|
||||
"gpu",
|
||||
"cuda",
|
||||
marks=pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("cuda"), reason="nvidia-smi not found"
|
||||
),
|
||||
),
|
||||
pytest.param(True, "gpu", marks=pytest.mark.skip),
|
||||
pytest.param(True, "cuda", marks=pytest.mark.skip),
|
||||
pytest.param(
|
||||
False,
|
||||
"vulkan",
|
||||
|
||||
140
docs/shark_sd_koboldcpp.md
Normal file
140
docs/shark_sd_koboldcpp.md
Normal file
@@ -0,0 +1,140 @@
|
||||
# Overview
|
||||
|
||||
In [1.47.2](https://github.com/LostRuins/koboldcpp/releases/tag/v1.47.2) [Koboldcpp](https://github.com/LostRuins/koboldcpp) added AUTOMATIC1111 integration for image generation. Since SHARK implements a small subset of the A1111 REST api, you can also use SHARK for this. This document gives a starting point for how to get this working.
|
||||
|
||||
## In Action
|
||||
|
||||

|
||||
|
||||
## Memory considerations
|
||||
|
||||
Since both Koboldcpp and SHARK will use VRAM on your graphics card(s), running both at the same time on the same card will impose extra limitations on the model size you can fully offload to the video card in Koboldcpp. For me, on an RX 7900 XTX on Windows with 24 GiB of VRAM, the limit was about a 13 billion parameter model with Q5_K_M quantisation.
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
When using SHARK for image generation, especially with Koboldcpp, you need to be aware that it is currently designed to pay a large upfront cost in time compiling and tuning the model you select, to get an optimal individual image generation time. You need to be the judge as to whether this trade-off is going to be worth it for your OS and hardware combination.
|
||||
|
||||
It means that the first time you run a particular Stable Diffusion model for a particular combination of image size, LoRA, and VAE, SHARK will spend *many minutes* - even on a beefy machine with a very fast graphics card and lots of memory - building that model combination just so it can save it to disk. It may even have to go away and download the model if it doesn't already have it locally. Once it has built a model combination for your hardware, it shouldn't need to do it again until you upgrade to a newer SHARK version, install different drivers or change your graphics hardware. It will just upload the files it generated the first time to your graphics card and proceed from there.
|
||||
|
||||
This does mean, however, that on a brand new install of SHARK, or whenever you select a model you haven't used before, the first image Koboldcpp requests may look like it is *never* going to finish and that the whole process has broken. Be forewarned, make yourself a cup of coffee, and expect a lot of messages about compilation and tuning from SHARK in the terminal you ran it from.
|
||||
|
||||
## Setup SHARK and prerequisites:
|
||||
|
||||
* Make sure you have suitable drivers for your graphics card installed. See the prerequisites section of the [README](https://github.com/nod-ai/SHARK#readme).
|
||||
* Download the latest SHARK studio .exe from [here](https://github.com/nod-ai/SHARK/releases) or follow the instructions in the [README](https://github.com/nod-ai/SHARK#readme) for an advanced, Linux or Mac install.
|
||||
* Run SHARK from terminal/PowerShell with the `--api` flag. Since koboldcpp also expects both CORS support and the image generator to be running on port `7860` rather than SHARK's default of `8080`, also include both the `--api_cors_origin` flag with a suitable origin (use `="*"` to enable all origins) and `--server_port=7860` on the command line. (See [Connecting to SHARK on a different address or port](#connecting-to-shark-on-a-different-address-or-port) if you want to run SHARK on a different port.)
|
||||
|
||||
```powershell
|
||||
## Run the .exe in API mode, with CORS support, on the A1111 endpoint port:
|
||||
.\node_ai_shark_studio_<date>_<ver>.exe --api --api_cors_origin="*" --server_port=7860
|
||||
|
||||
## Run from the base directory of a source clone of SHARK on Windows:
|
||||
.\setup_venv.ps1
|
||||
python .\apps\stable_diffusion\web\index.py --api --api_cors_origin="*" --server_port=7860
|
||||
|
||||
## Run from the base directory of a source clone of SHARK on Linux:
|
||||
./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
python ./apps/stable_diffusion/web/index.py --api --api_cors_origin="*" --server_port=7860
|
||||
|
||||
## An example giving improved performance on AMD cards using vulkan, that runs on the same port as A1111
|
||||
.\node_ai_shark_studio_20320901_2525.exe --api --api_cors_origin="*" --device_allocator="caching" --server_port=7860
|
||||
|
||||
## Since the api respects most applicable SHARK command line arguments for options not specified,
|
||||
## or currently unimplemented by API, there might be some you want to set, as listed in `--help`
|
||||
.\node_ai_shark_studio_20320901_2525.exe --help
|
||||
|
||||
## For instance, the example above, but with a custom VAE specified
|
||||
.\node_ai_shark_studio_20320901_2525.exe --api --api_cors_origin="*" --device_allocator="caching" --server_port=7860 --custom_vae="clearvae_v23.safetensors"
|
||||
|
||||
## An example with multiple specific CORS origins
|
||||
python apps/stable_diffusion/web/index.py --api --api_cors_origin="koboldcpp.example.com:7001" --api_cors_origin="koboldcpp.example.com:7002" --server_port=7860
|
||||
```
|
||||
|
||||
SHARK should start in server mode, and you should see something like this:
|
||||
|
||||

|
||||
|
||||
* Note: When running in api mode with `--api`, the .exe will not function as a webUI. Thus, the address or port shown in the terminal output will only be useful for API requests.
|
||||
|
||||
|
||||
## Configure Koboldcpp for local image generation:
|
||||
|
||||
* Get the latest [Koboldcpp](https://github.com/LostRuins/koboldcpp/releases) if you don't already have it. If you have a recent AMD card that has ROCm HIP [support for Windows](https://rocmdocs.amd.com/en/latest/release/windows_support.html#windows-supported-gpus) or [support for Linux](https://rocmdocs.amd.com/en/latest/release/gpu_os_support.html#linux-supported-gpus), you'll likely prefer [YellowRoseCx's ROCm fork](https://github.com/YellowRoseCx/koboldcpp-rocm).
|
||||
* Start Koboldcpp in another terminal/Powershell and setup your model configuration. Refer to the [Koboldcpp README](https://github.com/YellowRoseCx/koboldcpp-rocm) for more details on how to do this if this is your first time using Koboldcpp.
|
||||
* Once the main UI has loaded into your browser click the settings button, go to the advanced tab, and then choose *Local A1111* from the generate images dropdown:
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
*if you get an error here, see the next section [below](#connecting-to-shark-on-a-different-address-or-port)*
|
||||
|
||||
* A list of Stable Diffusion models available to your SHARK instance should now be listed in the box below *generate images*. The default value will usually be set to `stabilityai/stable-diffusion-2-1-base`. Choose the model you want to use for image generation from the list (but see [performance considerations](#performance-considerations)).
|
||||
* You should now be ready to generate images, either by clicking the 'Add Img' button above the text entry box:
|
||||
|
||||

|
||||
|
||||
...or by selecting the 'Autogenerate' option in the settings:
|
||||
|
||||

|
||||
|
||||
*I often find that even if I have selected autogenerate I have to do an 'add img' to get things started off*
|
||||
|
||||
* There is one final piece of image generation configuration within Koboldcpp you might want to do. This is also in the generate images section of advanced settings. Here there is, not very obviously, a 'style' button:
|
||||
|
||||

|
||||
|
||||
This will bring up a dialog box where you can enter a short text that will be sent as a prefix to the Prompt sent to SHARK:
|
||||
|
||||

|
||||
|
||||
|
||||
## Connecting to SHARK on a different address or port
|
||||
|
||||
If you didn't set the port to `--server_port=7860` when starting SHARK, or you are running it on a different machine on your network than the one running Koboldcpp or koboldcpp's kdlite client frontend, then you very likely got the following error:
|
||||
|
||||

|
||||
|
||||
As long as SHARK is running correctly, this means you need to set the url and port to the correct values in Koboldcpp. For instance, to set the port where Koboldcpp looks for an image generator to SHARK's default port of 8080:
|
||||
|
||||
* Select the cog icon in the Generate Images section of Advanced settings:
|
||||
|
||||

|
||||
|
||||
* Then edit the port number at the end of the url in the 'A1111 Endpoint Selection' dialog box to read 8080:
|
||||
|
||||

|
||||
|
||||
* Similarly, when running SHARK on a different machine you will need to change the host part of the endpoint url to the hostname or ip address where SHARK is running:
|
||||
|
||||

|
||||
|
||||
## Examples
|
||||
|
||||
Here's how Koboldcpp shows an image being requested:
|
||||
|
||||

|
||||
|
||||
The generated image in context in story mode:
|
||||
|
||||

|
||||
|
||||
And the same image when clicked on:
|
||||
|
||||

|
||||
|
||||
|
||||
## Where to find the images in SHARK
|
||||
|
||||
Even though Koboldcpp requests images at a size of 512x512, it resizes them to 256x256, converts them to `.jpeg`, and only shows them at 200x200 in the main text window. It does this so it can save them compactly embedded in your story as a `data://` uri.
|
||||
|
||||
However, the images at the original size are saved by SHARK in its `output_dir`, which is usually a folder named for the current date inside the `generated_imgs` folder in the SHARK installation directory.
|
||||
|
||||
You can browse these, either using the Output Gallery tab from within the SHARK web ui:
|
||||
|
||||

|
||||
|
||||
...or by browsing to the `output_dir` in your operating system's file manager:
|
||||
|
||||

|
||||
@@ -8,19 +8,8 @@ torchvision
|
||||
tqdm
|
||||
|
||||
#iree-compiler | iree-runtime should already be installed
|
||||
#these dont work ok osx
|
||||
#iree-tools-tflite
|
||||
#iree-tools-xla
|
||||
#iree-tools-tf
|
||||
|
||||
# TensorFlow and JAX.
|
||||
gin-config
|
||||
tensorflow-macos
|
||||
tensorflow-metal
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
transformers
|
||||
tensorflow-probability
|
||||
#jax[cpu]
|
||||
|
||||
# tflitehub dependencies.
|
||||
|
||||
@@ -9,23 +9,13 @@ tabulate
|
||||
tqdm
|
||||
|
||||
#iree-compiler | iree-runtime should already be installed
|
||||
iree-tools-tflite
|
||||
iree-tools-xla
|
||||
iree-tools-tf
|
||||
|
||||
# TensorFlow and JAX.
|
||||
# Modelling and JAX.
|
||||
gin-config
|
||||
tf-nightly
|
||||
keras-nightly
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
transformers
|
||||
diffusers
|
||||
#tensorflow-probability
|
||||
#jax[cpu]
|
||||
|
||||
|
||||
# tflitehub dependencies.
|
||||
Pillow
|
||||
|
||||
# Testing and support.
|
||||
|
||||
@@ -41,10 +41,12 @@ tiktoken # for codegen
|
||||
joblib # for langchain
|
||||
timm # for MiniGPT4
|
||||
langchain
|
||||
einops # for zoedepth
|
||||
pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions
|
||||
|
||||
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
|
||||
pefile
|
||||
pyinstaller
|
||||
|
||||
# vicuna quantization
|
||||
brevitas @ git+https://github.com/Xilinx/brevitas.git@dev
|
||||
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea
|
||||
|
||||
@@ -4,7 +4,7 @@ import base64
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
def upscaler_test():
|
||||
def upscaler_test(verbose=False):
|
||||
# Define values here
|
||||
prompt = ""
|
||||
negative_prompt = ""
|
||||
@@ -44,10 +44,17 @@ def upscaler_test():
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"response from server was : {res.status_code}")
|
||||
print(
|
||||
f"[upscaler] response from server was : {res.status_code} {res.reason}"
|
||||
)
|
||||
|
||||
if verbose or res.status_code != 200:
|
||||
print(
|
||||
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
|
||||
)
|
||||
|
||||
|
||||
def img2img_test():
|
||||
def img2img_test(verbose=False):
|
||||
# Define values here
|
||||
prompt = "Paint a rabbit riding on the dog"
|
||||
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
|
||||
@@ -87,7 +94,16 @@ def img2img_test():
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"response from server was : {res.status_code}")
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(
|
||||
f"[img2img] response from server was : {res.status_code} {res.reason}"
|
||||
)
|
||||
|
||||
if verbose or res.status_code != 200:
|
||||
print(
|
||||
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
|
||||
)
|
||||
|
||||
# NOTE Uncomment below to save the picture
|
||||
|
||||
@@ -103,7 +119,7 @@ def img2img_test():
|
||||
# response_img.save(r"rest_api_tests/response_img.png")
|
||||
|
||||
|
||||
def inpainting_test():
|
||||
def inpainting_test(verbose=False):
|
||||
prompt = "Paint a rabbit riding on the dog"
|
||||
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
|
||||
seed = 2121991605
|
||||
@@ -150,10 +166,17 @@ def inpainting_test():
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"[Inpainting] response from server was : {res.status_code}")
|
||||
print(
|
||||
f"[inpaint] response from server was : {res.status_code} {res.reason}"
|
||||
)
|
||||
|
||||
if verbose or res.status_code != 200:
|
||||
print(
|
||||
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
|
||||
)
|
||||
|
||||
|
||||
def outpainting_test():
|
||||
def outpainting_test(verbose=False):
|
||||
prompt = "Paint a rabbit riding on the dog"
|
||||
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
|
||||
seed = 2121991605
|
||||
@@ -200,10 +223,17 @@ def outpainting_test():
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"[Outpaint] response from server was : {res.status_code}")
|
||||
print(
|
||||
f"[outpaint] response from server was : {res.status_code} {res.reason}"
|
||||
)
|
||||
|
||||
if verbose or res.status_code != 200:
|
||||
print(
|
||||
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
|
||||
)
|
||||
|
||||
|
||||
def txt2img_test():
|
||||
def txt2img_test(verbose=False):
|
||||
prompt = "Paint a rabbit in a top hate"
|
||||
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
|
||||
seed = 2121991605
|
||||
@@ -232,12 +262,119 @@ def txt2img_test():
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"[txt2img] response from server was : {res.status_code}")
|
||||
print(
|
||||
f"[txt2img] response from server was : {res.status_code} {res.reason}"
|
||||
)
|
||||
|
||||
if verbose or res.status_code != 200:
|
||||
print(
|
||||
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
|
||||
)
|
||||
|
||||
|
||||
def sd_models_test(verbose=False):
    """Query the sd-models endpoint of the local API server and report the result.

    Sends a GET request to ``/sdapi/v1/sd-models`` on 127.0.0.1:8080 and
    prints the HTTP status code and reason phrase.

    Args:
        verbose: when true, also print the decoded JSON body of a
            successful (200) response. The raw body of a non-200 response
            is printed regardless of this flag.
    """
    endpoint = "http://127.0.0.1:8080/sdapi/v1/sd-models"
    request_headers = {
        "User-Agent": "PythonTest",
        "Accept": "*/*",
        "Accept-Encoding": "gzip, deflate, br",
    }

    response = requests.get(
        url=endpoint, headers=request_headers, timeout=1000
    )
    succeeded = response.status_code == 200

    print(
        f"[sd_models] response from server was : {response.status_code} {response.reason}"
    )

    # A successful call has nothing more to show unless details were requested.
    if succeeded and not verbose:
        return

    details = response.json() if succeeded else response.content
    print(f"\n{details}\n")
|
||||
|
||||
|
||||
def sd_samplers_test(verbose=False):
    """Query the samplers endpoint of the local API server and report the result.

    Sends a GET request to ``/sdapi/v1/samplers`` on 127.0.0.1:8080 and
    prints the HTTP status code and reason phrase.

    Args:
        verbose: when true, also print the decoded JSON body of a
            successful (200) response. The raw body of a non-200 response
            is printed regardless of this flag.
    """
    endpoint = "http://127.0.0.1:8080/sdapi/v1/samplers"
    request_headers = {
        "User-Agent": "PythonTest",
        "Accept": "*/*",
        "Accept-Encoding": "gzip, deflate, br",
    }

    response = requests.get(
        url=endpoint, headers=request_headers, timeout=1000
    )
    succeeded = response.status_code == 200

    print(
        f"[sd_samplers] response from server was : {response.status_code} {response.reason}"
    )

    # A successful call has nothing more to show unless details were requested.
    if succeeded and not verbose:
        return

    details = response.json() if succeeded else response.content
    print(f"\n{details}\n")
|
||||
|
||||
|
||||
def options_test(verbose=False):
    """Query the options endpoint of the local API server and report the result.

    Sends a GET request to ``/sdapi/v1/options`` on 127.0.0.1:8080 and
    prints the HTTP status code and reason phrase.

    Args:
        verbose: when true, also print the decoded JSON body of a
            successful (200) response. The raw body of a non-200 response
            is printed regardless of this flag.
    """
    endpoint = "http://127.0.0.1:8080/sdapi/v1/options"
    request_headers = {
        "User-Agent": "PythonTest",
        "Accept": "*/*",
        "Accept-Encoding": "gzip, deflate, br",
    }

    response = requests.get(
        url=endpoint, headers=request_headers, timeout=1000
    )
    succeeded = response.status_code == 200

    print(
        f"[options] response from server was : {response.status_code} {response.reason}"
    )

    # A successful call has nothing more to show unless details were requested.
    if succeeded and not verbose:
        return

    details = response.json() if succeeded else response.content
    print(f"\n{details}\n")
|
||||
|
||||
|
||||
def cmd_flags_test(verbose=False):
    """Query the cmd-flags endpoint of the local API server and report the result.

    Sends a GET request to ``/sdapi/v1/cmd-flags`` on 127.0.0.1:8080 and
    prints the HTTP status code and reason phrase.

    Args:
        verbose: when true, also print the decoded JSON body of a
            successful (200) response. The raw body of a non-200 response
            is printed regardless of this flag.
    """
    endpoint = "http://127.0.0.1:8080/sdapi/v1/cmd-flags"
    request_headers = {
        "User-Agent": "PythonTest",
        "Accept": "*/*",
        "Accept-Encoding": "gzip, deflate, br",
    }

    response = requests.get(
        url=endpoint, headers=request_headers, timeout=1000
    )
    succeeded = response.status_code == 200

    print(
        f"[cmd-flags] response from server was : {response.status_code} {response.reason}"
    )

    # A successful call has nothing more to show unless details were requested.
    if succeeded and not verbose:
        return

    details = response.json() if succeeded else response.content
    print(f"\n{details}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: parse the command line, then run every REST API
    # test exactly once, forwarding the verbosity flag to each of them.
    import argparse

    parser = argparse.ArgumentParser(
        description=(
            "Exercises the Stable Diffusion REST API of Shark. Make sure "
            # The trailing space matters: adjacent literals are concatenated
            # verbatim, and without it the help text read "runningthis".
            "Shark is running in API mode on 127.0.0.1:8080 before running "
            "this script."
        ),
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help=(
            "also display selected info from the JSON response for "
            "successful requests"
        ),
    )
    args = parser.parse_args()

    # The legacy argument-less calls (txt2img_test(), img2img_test(), ...)
    # are intentionally not repeated here; each test runs once, below.
    sd_models_test(args.verbose)
    sd_samplers_test(args.verbose)
    options_test(args.verbose)
    cmd_flags_test(args.verbose)
    txt2img_test(args.verbose)
    img2img_test(args.verbose)
    upscaler_test(args.verbose)
    inpainting_test(args.verbose)
    outpainting_test(args.verbose)
|
||||
|
||||
@@ -86,6 +86,7 @@ $PYTHON -m pip install --upgrade -r "$TD/requirements.txt"
|
||||
if [ "$torch_mlir_bin" = true ]; then
|
||||
if [[ $(uname -s) = 'Darwin' ]]; then
|
||||
echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
|
||||
$PYTHON -m pip uninstall -y timm #TEMP FIX FOR MAC
|
||||
$PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
|
||||
else
|
||||
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
|
||||
@@ -128,7 +129,13 @@ if [[ ! -z "${IMPORTER}" ]]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/cpu/
|
||||
if [[ $(uname -s) = 'Darwin' ]]; then
|
||||
PYTORCH_URL=https://download.pytorch.org/whl/nightly/torch/
|
||||
else
|
||||
PYTORCH_URL=https://download.pytorch.org/whl/nightly/cpu/
|
||||
fi
|
||||
|
||||
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f ${PYTORCH_URL}
|
||||
|
||||
if [[ $(uname -s) = 'Linux' && ! -z "${IMPORTER}" ]]; then
|
||||
T_VER=$($PYTHON -m pip show torch | grep Version)
|
||||
|
||||
@@ -177,7 +177,7 @@ def compile_through_fx(model, inputs, mlir_loc=None):
|
||||
mlir_model = str(module)
|
||||
func_name = "forward"
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
|
||||
mlir_model, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ if __name__ == "__main__":
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=False, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(minilm_mlir, func_name, mlir_dialect="mhlo")
|
||||
shark_module = SharkInference(minilm_mlir, mlir_dialect="mhlo")
|
||||
shark_module.compile()
|
||||
output_idx = 0
|
||||
data_idx = 1
|
||||
|
||||
@@ -6,7 +6,7 @@ mlir_model, func_name, inputs, golden_out = download_model(
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device="cpu", mlir_dialect="tm_tensor"
|
||||
mlir_model, device="cpu", mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(inputs)
|
||||
|
||||
@@ -13,9 +13,7 @@ arg0 = np.ones((1, 4)).astype(np.float32)
|
||||
arg1 = np.ones((4, 1)).astype(np.float32)
|
||||
|
||||
print("Running shark on cpu backend")
|
||||
shark_module = SharkInference(
|
||||
mhlo_ir, function_name="forward", device="cpu", mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module = SharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo")
|
||||
|
||||
# Generate the random inputs and feed into the graph.
|
||||
x = shark_module.generate_random_inputs()
|
||||
@@ -23,15 +21,11 @@ shark_module.compile()
|
||||
print(shark_module.forward(x))
|
||||
|
||||
print("Running shark on cuda backend")
|
||||
shark_module = SharkInference(
|
||||
mhlo_ir, function_name="forward", device="cuda", mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module = SharkInference(mhlo_ir, device="cuda", mlir_dialect="mhlo")
|
||||
shark_module.compile()
|
||||
print(shark_module.forward(x))
|
||||
|
||||
print("Running shark on vulkan backend")
|
||||
shark_module = SharkInference(
|
||||
mhlo_ir, function_name="forward", device="vulkan", mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module = SharkInference(mhlo_ir, device="vulkan", mlir_dialect="mhlo")
|
||||
shark_module.compile()
|
||||
print(shark_module.forward(x))
|
||||
|
||||
@@ -8,9 +8,7 @@ mlir_model, func_name, inputs, golden_out = download_model(
|
||||
)
|
||||
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device="cpu", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module = SharkInference(mlir_model, device="cpu", mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(inputs)
|
||||
print("The obtained result via shark is: ", result)
|
||||
|
||||
@@ -33,7 +33,7 @@ mlir_importer = SharkImporter(
|
||||
|
||||
print(golden_out)
|
||||
|
||||
shark_module = SharkInference(vision_mlir, func_name, mlir_dialect="linalg")
|
||||
shark_module = SharkInference(vision_mlir, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((input,))
|
||||
print("Obtained result", result)
|
||||
|
||||
@@ -49,9 +49,7 @@ module = torch_mlir.compile(
|
||||
mlir_model = module
|
||||
func_name = "forward"
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device="cuda", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module = SharkInference(mlir_model, device="cuda", mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
|
||||
|
||||
|
||||
@@ -360,7 +360,7 @@ mlir_importer = SharkImporter(
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
dlrm_mlir, func_name, device="vulkan", mlir_dialect="linalg"
|
||||
dlrm_mlir, device="vulkan", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(input_dlrm)
|
||||
|
||||
@@ -294,7 +294,7 @@ def test_dlrm() -> None:
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
dlrm_mlir, func_name, device="cpu", mlir_dialect="linalg"
|
||||
dlrm_mlir, device="cpu", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(inputs)
|
||||
|
||||
@@ -33,7 +33,7 @@ mlir_importer = SharkImporter(
|
||||
tracing_required=False
|
||||
)
|
||||
|
||||
shark_module = SharkInference(vision_mlir, func_name, mlir_dialect="linalg")
|
||||
shark_module = SharkInference(vision_mlir, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((input,))
|
||||
np.testing.assert_allclose(golden_out, result, rtol=1e-02, atol=1e-03)
|
||||
|
||||
@@ -7,7 +7,7 @@ mlir_model, func_name, inputs, golden_out = download_model(
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device="vulkan", mlir_dialect="linalg"
|
||||
mlir_model, device="vulkan", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(inputs)
|
||||
|
||||
@@ -19,9 +19,12 @@ import sys
|
||||
import subprocess
|
||||
|
||||
|
||||
def run_cmd(cmd, debug=False):
|
||||
def run_cmd(cmd, debug=False, raise_err=False):
|
||||
"""
|
||||
Inputs: cli command string.
|
||||
Inputs:
|
||||
cmd : cli command string.
|
||||
debug : if True, prints debug info
|
||||
raise_err : if True, raise exception to caller
|
||||
"""
|
||||
if debug:
|
||||
print("IREE run command: \n\n")
|
||||
@@ -39,8 +42,11 @@ def run_cmd(cmd, debug=False):
|
||||
stderr = result.stderr.decode()
|
||||
return stdout, stderr
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e.output)
|
||||
sys.exit(f"Exiting program due to error running {cmd}")
|
||||
if raise_err:
|
||||
raise Exception from e
|
||||
else:
|
||||
print(e.output)
|
||||
sys.exit(f"Exiting program due to error running {cmd}")
|
||||
|
||||
|
||||
def iree_device_map(device):
|
||||
@@ -95,38 +101,31 @@ _IREE_TARGET_MAP = {
|
||||
# Finds whether the required drivers are installed for the given device.
|
||||
@functools.cache
|
||||
def check_device_drivers(device):
|
||||
"""Checks necessary drivers present for gpu and vulkan devices"""
|
||||
"""
|
||||
Checks necessary drivers present for gpu and vulkan devices
|
||||
False => drivers present!
|
||||
"""
|
||||
if "://" in device:
|
||||
device = device.split("://")[0]
|
||||
|
||||
if device == "cuda":
|
||||
try:
|
||||
subprocess.check_output("nvidia-smi")
|
||||
except Exception:
|
||||
return True
|
||||
elif device in ["vulkan"]:
|
||||
try:
|
||||
subprocess.check_output("vulkaninfo")
|
||||
except Exception:
|
||||
return True
|
||||
elif device == "metal":
|
||||
return False
|
||||
elif device in ["intel-gpu"]:
|
||||
try:
|
||||
subprocess.check_output(["dpkg", "-L", "intel-level-zero-gpu"])
|
||||
return False
|
||||
except Exception:
|
||||
return True
|
||||
elif device == "cpu":
|
||||
return False
|
||||
elif device == "rocm":
|
||||
try:
|
||||
if sys.platform == "win32":
|
||||
subprocess.check_output("hipinfo")
|
||||
else:
|
||||
subprocess.check_output("rocminfo")
|
||||
except Exception:
|
||||
return True
|
||||
from iree.runtime import get_driver
|
||||
|
||||
device_mapped = iree_device_map(device)
|
||||
|
||||
try:
|
||||
_ = get_driver(device_mapped)
|
||||
except ValueError as ve:
|
||||
print(
|
||||
f"[ERR] device `{device}` not registered with IREE. "
|
||||
"Ensure IREE is configured for use with this device.\n"
|
||||
f"Full Error: \n {repr(ve)}"
|
||||
)
|
||||
return True
|
||||
except RuntimeError as re:
|
||||
print(
|
||||
f"[ERR] Failed to get driver for {device} with error:\n{repr(re)}"
|
||||
)
|
||||
return True
|
||||
|
||||
# Unknown device. We assume drivers are installed.
|
||||
return False
|
||||
@@ -134,11 +133,32 @@ def check_device_drivers(device):
|
||||
|
||||
# Installation info for the missing device drivers.
|
||||
def device_driver_info(device):
    """Return a help message for a device whose drivers were not found.

    Args:
        device: device name; "cuda", "vulkan", "metal" and "rocm" have
            dedicated guidance.

    Returns:
        For a known device, a message made of a diagnostic hint, where to
        get the drivers, and where to ask for help. For any other name,
        the string "<device> is not supported.".
    """
    # Per-device text fragments. "debug" is a full sentence; "solution" is
    # appended directly after "Please install the required drivers", which
    # is why each value starts with a space or is just the closing period.
    device_driver_err_map = {
        "cuda": {
            "debug": "Try `nvidia-smi` on system to check.",
            "solution": " from https://www.nvidia.in/Download/index.aspx?lang=en-in for your system.",
        },
        "vulkan": {
            "debug": "Try `vulkaninfo` on system to check.",
            "solution": " from https://vulkan.lunarg.com/sdk/home for your distribution.",
        },
        "metal": {
            "debug": "Check if Bare metal is supported and enabled on your system.",
            "solution": ".",
        },
        "rocm": {
            # The diagnostic tool is `hipinfo` on Windows, `rocminfo` elsewhere.
            "debug": f"Try `{'hip' if sys.platform == 'win32' else 'rocm'}info` on system to check.",
            "solution": " from https://rocm.docs.amd.com/en/latest/rocm.html for your system.",
        },
    }

    # The table is the single source of truth. The earlier if/elif chain
    # returned before this point for every known device, which made the
    # table unreachable and told metal users to run `vulkaninfo`.
    if device not in device_driver_err_map:
        return f"{device} is not supported."

    hints = device_driver_err_map[device]
    return (
        f"Required drivers for {device} not found. {hints['debug']} "
        f"Please install the required drivers{hints['solution']} "
        "For further assistance please reach out to the community on discord [https://discord.com/invite/RUqY2h2s9u]"
        " and/or file a bug at https://github.com/nod-ai/SHARK/issues"
    )
|
||||
|
||||
@@ -39,11 +39,19 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
f"Specific device selection only supported for vulkan now."
|
||||
f"Proceeding with {device} as device."
|
||||
)
|
||||
device_num = device_uri[1]
|
||||
# device_uri can be device_num or device_path.
|
||||
# assuming the number of devices for a single driver will not be >99
|
||||
if len(device_uri[1]) <= 2:
|
||||
# expected to be device index in range 0 - 99
|
||||
device_num = int(device_uri[1])
|
||||
else:
|
||||
# expected to be device path
|
||||
device_num = device_uri[1]
|
||||
|
||||
else:
|
||||
device_num = 0
|
||||
|
||||
if device_uri[0] == "cpu":
|
||||
if "cpu" in device:
|
||||
from shark.iree_utils.cpu_utils import get_iree_cpu_args
|
||||
|
||||
data_tiling_flag = ["--iree-opt-data-tiling"]
|
||||
@@ -55,6 +63,8 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
+ data_tiling_flag
|
||||
+ u_kernel_flag
|
||||
+ stack_size_flag
|
||||
+ ["--iree-flow-enable-quantized-matmul-reassociation"]
|
||||
+ ["--iree-llvmcpu-enable-quantized-matmul-reassociation"]
|
||||
)
|
||||
if device_uri[0] == "cuda":
|
||||
from shark.iree_utils.gpu_utils import get_iree_gpu_args
|
||||
@@ -73,7 +83,7 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
if device_uri[0] == "rocm":
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
return get_iree_rocm_args()
|
||||
return get_iree_rocm_args(extra_args=extra_args)
|
||||
return []
|
||||
|
||||
|
||||
@@ -292,9 +302,10 @@ def compile_module_to_flatbuffer(
|
||||
extra_args,
|
||||
model_name="None",
|
||||
debug=False,
|
||||
compile_str=False,
|
||||
):
|
||||
# Setup Compile arguments wrt to frontends.
|
||||
input_type = ""
|
||||
input_type = "auto"
|
||||
args = get_iree_frontend_args(frontend)
|
||||
args += get_iree_device_args(device, extra_args)
|
||||
args += get_iree_common_args(debug=debug)
|
||||
@@ -311,10 +322,7 @@ def compile_module_to_flatbuffer(
|
||||
elif frontend in ["tm_tensor"]:
|
||||
input_type = ireec.InputType.TM_TENSOR
|
||||
|
||||
# TODO: make it simpler.
|
||||
# Compile according to the input type, else just try compiling.
|
||||
if input_type != "":
|
||||
# Currently for MHLO/TOSA.
|
||||
if compile_str:
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
module,
|
||||
target_backends=[iree_target_map(device)],
|
||||
@@ -322,9 +330,10 @@ def compile_module_to_flatbuffer(
|
||||
input_type=input_type,
|
||||
)
|
||||
else:
|
||||
# Currently for Torch.
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
module,
|
||||
assert os.path.isfile(module)
|
||||
flatbuffer_blob = ireec.compile_file(
|
||||
str(module),
|
||||
input_type=input_type,
|
||||
target_backends=[iree_target_map(device)],
|
||||
extra_args=args,
|
||||
)
|
||||
@@ -332,8 +341,12 @@ def compile_module_to_flatbuffer(
|
||||
return flatbuffer_blob
|
||||
|
||||
|
||||
def get_iree_module(flatbuffer_blob, device, device_idx=None):
|
||||
def get_iree_module(
|
||||
flatbuffer_blob, device, device_idx=None, rt_flags: list = []
|
||||
):
|
||||
# Returns the compiled module and the configs.
|
||||
for flag in rt_flags:
|
||||
ireert.flags.parse_flag(flag)
|
||||
if device_idx is not None:
|
||||
device = iree_device_map(device)
|
||||
print("registering device id: ", device_idx)
|
||||
@@ -355,9 +368,22 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
|
||||
|
||||
|
||||
def load_vmfb_using_mmap(
|
||||
flatbuffer_blob_or_path, device: str, device_idx: int = None
|
||||
flatbuffer_blob_or_path,
|
||||
device: str,
|
||||
device_idx: int = None,
|
||||
rt_flags: list = [],
|
||||
):
|
||||
print(f"Loading module {flatbuffer_blob_or_path}...")
|
||||
if "task" in device:
|
||||
print(
|
||||
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
|
||||
)
|
||||
for flag in get_iree_cpu_rt_args():
|
||||
rt_flags.append(flag)
|
||||
for flag in rt_flags:
|
||||
print(flag)
|
||||
ireert.flags.parse_flags(flag)
|
||||
|
||||
if "rocm" in device:
|
||||
device = "rocm"
|
||||
with DetailLogger(timeout=2.5) as dl:
|
||||
@@ -374,6 +400,9 @@ def load_vmfb_using_mmap(
|
||||
)
|
||||
dl.log(f"ireert.create_device()")
|
||||
config = ireert.Config(device=haldevice)
|
||||
config.id = haldriver.query_available_devices()[device_idx][
|
||||
"device_id"
|
||||
]
|
||||
dl.log(f"ireert.Config()")
|
||||
else:
|
||||
config = get_iree_runtime_config(device)
|
||||
@@ -384,6 +413,7 @@ def load_vmfb_using_mmap(
|
||||
)
|
||||
for flag in get_iree_cpu_rt_args():
|
||||
ireert.flags.parse_flags(flag)
|
||||
|
||||
# Now load vmfb.
|
||||
# Two scenarios we have here :-
|
||||
# 1. We either have the vmfb already saved and therefore pass the path of it.
|
||||
@@ -403,7 +433,14 @@ def load_vmfb_using_mmap(
|
||||
)
|
||||
dl.log(f"mmap {flatbuffer_blob_or_path}")
|
||||
ctx = ireert.SystemContext(config=config)
|
||||
for flag in shark_args.additional_runtime_args:
|
||||
ireert.flags.parse_flags(flag)
|
||||
dl.log(f"ireert.SystemContext created")
|
||||
if "vulkan" in device:
|
||||
# Vulkan pipeline creation consumes significant amount of time.
|
||||
print(
|
||||
"\tCompiling Vulkan shaders. This may take a few minutes."
|
||||
)
|
||||
ctx.add_vm_module(mmaped_vmfb)
|
||||
dl.log(f"module initialized")
|
||||
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
|
||||
@@ -424,13 +461,21 @@ def get_iree_compiled_module(
|
||||
frontend: str = "torch",
|
||||
model_config_path: str = None,
|
||||
extra_args: list = [],
|
||||
rt_flags: list = [],
|
||||
device_idx: int = None,
|
||||
mmap: bool = False,
|
||||
debug: bool = False,
|
||||
compile_str: bool = False,
|
||||
):
|
||||
"""Given a module returns the compiled .vmfb and configs"""
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module, device, frontend, model_config_path, extra_args, debug
|
||||
module=module,
|
||||
device=device,
|
||||
frontend=frontend,
|
||||
model_config_path=model_config_path,
|
||||
extra_args=extra_args,
|
||||
debug=debug,
|
||||
compile_str=compile_str,
|
||||
)
|
||||
temp_file_to_unlink = None
|
||||
# TODO: Currently mmap=True control flow path has been switched off for mmap.
|
||||
@@ -439,11 +484,14 @@ def get_iree_compiled_module(
|
||||
# I'm getting hold of the name of the temporary file in `temp_file_to_unlink`.
|
||||
if mmap:
|
||||
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
|
||||
flatbuffer_blob, device, device_idx
|
||||
flatbuffer_blob, device, device_idx, rt_flags
|
||||
)
|
||||
else:
|
||||
vmfb, config = get_iree_module(
|
||||
flatbuffer_blob, device, device_idx=device_idx
|
||||
flatbuffer_blob,
|
||||
device,
|
||||
device_idx=device_idx,
|
||||
rt_flags=rt_flags,
|
||||
)
|
||||
ret_params = {
|
||||
"vmfb": vmfb,
|
||||
@@ -458,17 +506,21 @@ def load_flatbuffer(
|
||||
device: str,
|
||||
device_idx: int = None,
|
||||
mmap: bool = False,
|
||||
rt_flags: list = [],
|
||||
):
|
||||
temp_file_to_unlink = None
|
||||
if mmap:
|
||||
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
|
||||
flatbuffer_path, device, device_idx
|
||||
flatbuffer_path, device, device_idx, rt_flags
|
||||
)
|
||||
else:
|
||||
with open(os.path.join(flatbuffer_path), "rb") as f:
|
||||
flatbuffer_blob = f.read()
|
||||
vmfb, config = get_iree_module(
|
||||
flatbuffer_blob, device, device_idx=device_idx
|
||||
flatbuffer_blob,
|
||||
device,
|
||||
device_idx=device_idx,
|
||||
rt_flags=rt_flags,
|
||||
)
|
||||
ret_params = {
|
||||
"vmfb": vmfb,
|
||||
@@ -487,10 +539,17 @@ def export_iree_module_to_vmfb(
|
||||
module_name: str = None,
|
||||
extra_args: list = [],
|
||||
debug: bool = False,
|
||||
compile_str: bool = False,
|
||||
):
|
||||
# Compiles the module given specs and saves it as .vmfb file.
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module, device, mlir_dialect, model_config_path, extra_args, debug
|
||||
module=module,
|
||||
device=device,
|
||||
frontend=mlir_dialect,
|
||||
model_config_path=model_config_path,
|
||||
extra_args=extra_args,
|
||||
debug=debug,
|
||||
compile_str=compile_str,
|
||||
)
|
||||
if module_name is None:
|
||||
device_name = (
|
||||
@@ -526,10 +585,17 @@ def get_results(
|
||||
frontend="torch",
|
||||
send_to_host=True,
|
||||
debug_timeout: float = 5.0,
|
||||
device: str = None,
|
||||
):
|
||||
"""Runs a .vmfb file given inputs and config and returns output."""
|
||||
with DetailLogger(debug_timeout) as dl:
|
||||
device_inputs = []
|
||||
if device == "rocm" and hasattr(config, "id"):
|
||||
haldriver = ireert.get_driver("rocm")
|
||||
haldevice = haldriver.create_device(
|
||||
config.id,
|
||||
allocators=shark_args.device_allocator,
|
||||
)
|
||||
for input_array in input:
|
||||
dl.log(f"Load to device: {input_array.shape}")
|
||||
device_inputs.append(
|
||||
@@ -566,7 +632,7 @@ def get_results(
|
||||
def get_iree_runtime_config(device):
|
||||
device = iree_device_map(device)
|
||||
haldriver = ireert.get_driver(device)
|
||||
if device == "metal" and shark_args.device_allocator == "caching":
|
||||
if "metal" in device and shark_args.device_allocator == "caching":
|
||||
print(
|
||||
"[WARNING] metal devices can not have a `caching` allocator."
|
||||
"\nUsing default allocator `None`"
|
||||
@@ -574,7 +640,9 @@ def get_iree_runtime_config(device):
|
||||
haldevice = haldriver.create_device_by_uri(
|
||||
device,
|
||||
# metal devices have a failure with caching allocators atm. blocking this until it gets fixed upstream.
|
||||
allocators=shark_args.device_allocator if device != "metal" else None,
|
||||
allocators=shark_args.device_allocator
|
||||
if "metal" not in device
|
||||
else None,
|
||||
)
|
||||
config = ireert.Config(device=haldevice)
|
||||
return config
|
||||
|
||||
@@ -18,7 +18,9 @@ import functools
|
||||
import iree.runtime as ireert
|
||||
import ctypes
|
||||
import sys
|
||||
from subprocess import CalledProcessError
|
||||
from shark.parser import shark_args
|
||||
from shark.iree_utils._common import run_cmd
|
||||
|
||||
|
||||
# Get the default gpu args given the architecture.
|
||||
@@ -40,55 +42,74 @@ def get_iree_gpu_args():
|
||||
|
||||
|
||||
# Get the default gpu args given the architecture.
|
||||
@functools.cache
|
||||
def get_iree_rocm_args():
|
||||
def get_iree_rocm_args(device_num=0, extra_args=[]):
|
||||
ireert.flags.FUNCTION_INPUT_VALIDATION = False
|
||||
# get arch from hipinfo.
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
rocm_flags = ["--iree-rocm-link-bc=true"]
|
||||
|
||||
if sys.platform == "win32":
|
||||
if "HIP_PATH" in os.environ:
|
||||
rocm_path = os.environ["HIP_PATH"]
|
||||
print(f"Found a ROCm installation at {rocm_path}.")
|
||||
else:
|
||||
print("Failed to find ROCM_PATH. Defaulting to C:\\AMD\\ROCM\\5.5")
|
||||
rocm_path = "C:\\AMD\\ROCM\\5.5"
|
||||
else:
|
||||
if "ROCM_PATH" in os.environ:
|
||||
rocm_path = os.environ["ROCM_PATH"]
|
||||
print(f"Found a ROCm installation at {rocm_path}.")
|
||||
else:
|
||||
print("Failed to find ROCM_PATH. Defaulting to /opt/rocm")
|
||||
rocm_path = "/opt/rocm/"
|
||||
# ROCM Device Arch selection:
|
||||
# 1 : User given device arch using `--iree-rocm-target-chip` flag
|
||||
# 2 : Device arch from `iree-run-module --dump_devices=rocm` for device on index <device_num>
|
||||
# 3 : default arch : gfx1100
|
||||
|
||||
try:
|
||||
if sys.platform == "win32":
|
||||
rocm_arch = re.search(
|
||||
r"gfx\d{3,}",
|
||||
subprocess.check_output("hipinfo", shell=True, text=True),
|
||||
).group(0)
|
||||
else:
|
||||
rocm_arch = re.match(
|
||||
r".*(gfx\w+)",
|
||||
subprocess.check_output(
|
||||
"rocminfo | grep -i 'gfx'", shell=True, text=True
|
||||
),
|
||||
).group(1)
|
||||
print(f"Found rocm arch {rocm_arch}...")
|
||||
except:
|
||||
default_rocm_arch = "gfx_1100"
|
||||
# Check if the target arch flag for rocm device present in extra_flags
|
||||
flag_present = False
|
||||
for flag in extra_args:
|
||||
if "iree-rocm-target-chip" in flag:
|
||||
flag_present = True
|
||||
print(
|
||||
f"User Specified rocm target device arch from flag : {flag.split('=')[1]} will be used"
|
||||
)
|
||||
|
||||
arch_in_device_dump = None
|
||||
if not flag_present:
|
||||
# get rocm arch from iree dump devices
|
||||
def get_devices_info_from_dump(dump):
|
||||
from os import linesep
|
||||
|
||||
dump_clean = list(
|
||||
filter(
|
||||
lambda s: "--device=rocm" in s or "gpu-arch-name:" in s,
|
||||
dump.split(linesep),
|
||||
)
|
||||
)
|
||||
arch_pairs = [
|
||||
(
|
||||
dump_clean[i].split("=")[1].strip(),
|
||||
dump_clean[i + 1].split(":")[1].strip(),
|
||||
)
|
||||
for i in range(0, len(dump_clean), 2)
|
||||
]
|
||||
return arch_pairs
|
||||
|
||||
dump_device_info = None
|
||||
try:
|
||||
dump_device_info = run_cmd(
|
||||
"iree-run-module --dump_devices=rocm", raise_err=True
|
||||
)
|
||||
except Exception as e:
|
||||
print("could not execute `iree-run-module --dump_devices=rocm`")
|
||||
|
||||
if dump_device_info is not None:
|
||||
device_arch_pairs = get_devices_info_from_dump(dump_device_info[0])
|
||||
if (
|
||||
len(device_arch_pairs) > device_num
|
||||
): # can find arch in the list
|
||||
arch_in_device_dump = device_arch_pairs[device_num][1]
|
||||
|
||||
if arch_in_device_dump is not None:
|
||||
print(f"Found ROCm device arch : {arch_in_device_dump}")
|
||||
rocm_flags.append(f"--iree-rocm-target-chip={arch_in_device_dump}")
|
||||
|
||||
if not flag_present and arch_in_device_dump is None:
|
||||
print(
|
||||
"Failed to find ROCm architecture from hipinfo / rocminfo. Defaulting to gfx1100."
|
||||
"Did not find ROCm architecture from `--iree-rocm-target-chip` flag"
|
||||
"\n or from `iree-run-module --dump_devices=rocm` command."
|
||||
f"\nUsing {default_rocm_arch} as ROCm arch for compilation."
|
||||
)
|
||||
rocm_arch = "gfx1100"
|
||||
rocm_flags.append(f"--iree-rocm-target-chip={default_rocm_arch}")
|
||||
|
||||
bc_path = os.path.join(rocm_path, "amdgcn", "bitcode")
|
||||
return [
|
||||
f"--iree-rocm-target-chip={rocm_arch}",
|
||||
"--iree-rocm-link-bc=true",
|
||||
f"--iree-rocm-bc-dir={bc_path}",
|
||||
]
|
||||
return rocm_flags
|
||||
|
||||
|
||||
# Some constants taken from cuda.h
|
||||
|
||||
@@ -89,24 +89,10 @@ def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
|
||||
|
||||
|
||||
def get_iree_metal_args(device_num=0, extra_args=[]):
|
||||
# res_metal_flag = ["--iree-flow-demote-i64-to-i32"]
|
||||
|
||||
# Add any metal specific compilation flags here
|
||||
res_metal_flag = []
|
||||
metal_triple_flag = None
|
||||
for arg in extra_args:
|
||||
if "-iree-metal-target-platform=" in arg:
|
||||
print(f"Using target triple {arg} from command line args")
|
||||
metal_triple_flag = arg
|
||||
break
|
||||
|
||||
if metal_triple_flag is None:
|
||||
metal_triple_flag = get_metal_triple_flag(extra_args=extra_args)
|
||||
|
||||
if metal_triple_flag is not None:
|
||||
vulkan_target_env = get_vulkan_target_env_flag(
|
||||
"-iree-vulkan-target-triple=m1-moltenvk-macos"
|
||||
)
|
||||
res_metal_flag.append(vulkan_target_env)
|
||||
if len(extra_args) > 0:
|
||||
res_metal_flag.extend(extra_args)
|
||||
return res_metal_flag
|
||||
|
||||
|
||||
|
||||
@@ -27,9 +27,12 @@ from shark.parser import shark_args
|
||||
def get_all_vulkan_devices():
    """Return the names of the devices reported by IREE's Vulkan driver.

    Returns:
        A list with the "name" entry of every device the driver reports,
        in the order the driver reports them. The list is empty when the
        driver cannot be obtained or queried.
    """
    from iree.runtime import get_driver

    # Query exactly once, inside the guard. An unguarded query ahead of the
    # try block would raise before the fallback could ever apply.
    try:
        driver = get_driver("vulkan")
        device_list_src = driver.query_available_devices()
    except Exception:
        # Best effort: a missing or broken Vulkan driver means "no devices".
        # `Exception` rather than a bare `except` so KeyboardInterrupt and
        # SystemExit still propagate; a list (not a dict) keeps the type
        # consistent with the success path.
        device_list_src = []

    return [d["name"] for d in device_list_src]
|
||||
|
||||
|
||||
@@ -68,6 +71,8 @@ def get_vulkan_target_triple(device_name):
|
||||
Returns:
|
||||
str or None: target triple or None if no match found for given name
|
||||
"""
|
||||
|
||||
# TODO: Replace this with a dict or something smarter.
|
||||
system_os = get_os_name()
|
||||
# Apple Targets
|
||||
if all(x in device_name for x in ("Apple", "M1")):
|
||||
@@ -117,8 +122,12 @@ def get_vulkan_target_triple(device_name):
|
||||
# Amd Targets
|
||||
# Linux: Radeon RX 7900 XTX
|
||||
# Windows: AMD Radeon RX 7900 XTX
|
||||
elif all(x in device_name for x in ("RX", "7800")):
|
||||
triple = f"rdna3-7800-{system_os}"
|
||||
elif all(x in device_name for x in ("RX", "7900")):
|
||||
triple = f"rdna3-7900-{system_os}"
|
||||
elif all(x in device_name for x in ("Radeon", "780M")):
|
||||
triple = f"rdna3-780m-{system_os}"
|
||||
elif all(x in device_name for x in ("AMD", "PRO", "W7900")):
|
||||
triple = f"rdna3-w7900-{system_os}"
|
||||
elif any(x in device_name for x in ("AMD", "Radeon")):
|
||||
|
||||
@@ -26,7 +26,7 @@ class SplitStrToListAction(argparse.Action):
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
del parser, option_string
|
||||
setattr(namespace, self.dest, shlex.split(values[0]))
|
||||
setattr(namespace, self.dest, shlex.split(" "))
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="SHARK runner.")
|
||||
@@ -44,6 +44,13 @@ parser.add_argument(
|
||||
action=SplitStrToListAction,
|
||||
help="Additional arguments to pass to the compiler. These are appended as the last arguments.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--additional_runtime_args",
|
||||
default=list(),
|
||||
nargs=1,
|
||||
action=SplitStrToListAction,
|
||||
help="Additional arguments to pass to the IREE runtime. These are appended as the last arguments.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_tf32",
|
||||
type=bool,
|
||||
|
||||
@@ -84,6 +84,13 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
self.extra_args = extra_args
|
||||
self.import_args = {}
|
||||
self.temp_file_to_unlink = None
|
||||
if not os.path.isfile(mlir_module):
|
||||
print(
|
||||
"Warning: Initializing SharkRunner with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to a MLIR module on your hard disk instead."
|
||||
)
|
||||
self.compile_str = True
|
||||
else:
|
||||
self.compile_str = False
|
||||
SharkRunner.__init__(
|
||||
self,
|
||||
mlir_module,
|
||||
@@ -98,6 +105,7 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
".",
|
||||
self.mlir_dialect,
|
||||
extra_args=self.extra_args,
|
||||
compile_str=self.compile_str,
|
||||
)
|
||||
params = load_flatbuffer(
|
||||
self.vmfb_file,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import tempfile
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
import torch
|
||||
import torch_mlir
|
||||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||
@@ -130,10 +130,17 @@ def compile_int_precision(
|
||||
mlir_module = mlir_module.encode("UTF-8")
|
||||
mlir_module = BytesIO(mlir_module)
|
||||
bytecode = mlir_module.read()
|
||||
bytecode_path = os.path.join(
|
||||
os.getcwd(), f"{extended_model_name}_linalg.mlirbc"
|
||||
)
|
||||
with open(bytecode_path, "wb") as f:
|
||||
f.write(bytecode)
|
||||
del bytecode
|
||||
del mlir_module
|
||||
print(f"Elided IR written for {extended_model_name}")
|
||||
return bytecode
|
||||
return bytecode_path
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
|
||||
mlir_module=bytecode_path, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
extra_args = [
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
@@ -148,7 +155,7 @@ def compile_int_precision(
|
||||
generate_vmfb=generate_vmfb,
|
||||
extra_args=extra_args,
|
||||
),
|
||||
bytecode,
|
||||
bytecode_path,
|
||||
)
|
||||
|
||||
|
||||
@@ -201,7 +208,7 @@ def shark_compile_through_fx(
|
||||
]
|
||||
else:
|
||||
(
|
||||
mlir_module,
|
||||
bytecode,
|
||||
_,
|
||||
) = import_with_fx(
|
||||
model=model,
|
||||
@@ -212,6 +219,11 @@ def shark_compile_through_fx(
|
||||
model_name=extended_model_name,
|
||||
save_dir=save_dir,
|
||||
)
|
||||
mlir_module = save_mlir(
|
||||
mlir_module=bytecode,
|
||||
model_name=extended_model_name,
|
||||
mlir_dialect=mlir_dialect,
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
|
||||
@@ -275,11 +275,11 @@ def download_model(
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
tuned_str = "" if tuned is None else "_" + tuned
|
||||
suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
|
||||
filename = os.path.join(model_dir, model_name + suffix)
|
||||
mlir_filename = os.path.join(model_dir, model_name + suffix)
|
||||
print(
|
||||
f"Verifying that model artifacts were downloaded successfully to {filename}..."
|
||||
f"Verifying that model artifacts were downloaded successfully to {mlir_filename}..."
|
||||
)
|
||||
if not os.path.exists(filename):
|
||||
if not os.path.exists(mlir_filename):
|
||||
from tank.generate_sharktank import gen_shark_files
|
||||
|
||||
print(
|
||||
@@ -287,13 +287,11 @@ def download_model(
|
||||
)
|
||||
gen_shark_files(model_name, frontend, WORKDIR, import_args)
|
||||
|
||||
assert os.path.exists(filename), f"MLIR not found at {filename}"
|
||||
with open(filename, mode="rb") as f:
|
||||
mlir_file = f.read()
|
||||
assert os.path.exists(mlir_filename), f"MLIR not found at {mlir_filename}"
|
||||
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
|
||||
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
|
||||
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
|
||||
|
||||
inputs_tuple = tuple([inputs[key] for key in inputs])
|
||||
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
|
||||
return mlir_file, function_name, inputs_tuple, golden_out_tuple
|
||||
return mlir_filename, function_name, inputs_tuple, golden_out_tuple
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from collections import defaultdict
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
import torchvision.models as models
|
||||
import copy
|
||||
import io
|
||||
@@ -20,10 +20,16 @@ def shark_backend(fx_g: torch.fx.GraphModule, inputs, device: str = "cpu"):
|
||||
bytecode_stream = io.BytesIO()
|
||||
mlir_module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
bytecode_path = save_mlir(
|
||||
bytecode,
|
||||
model_name="shark_eager_module",
|
||||
frontend="torch",
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode,
|
||||
mlir_module=bytecode_path,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
|
||||
@@ -3,8 +3,8 @@ import json
|
||||
import numpy as np
|
||||
|
||||
import torch_mlir
|
||||
from iree.compiler import compile_str
|
||||
from shark.shark_importer import import_with_fx, get_f16_inputs
|
||||
from iree.compiler import compile_file
|
||||
from shark.shark_importer import import_with_fx, get_f16_inputs, save_mlir
|
||||
|
||||
|
||||
class GenerateConfigFile:
|
||||
@@ -54,9 +54,15 @@ class GenerateConfigFile:
|
||||
verbose=False,
|
||||
)
|
||||
module = module.operation.get_asm(large_elements_limit=4)
|
||||
module_file = save_mlir(
|
||||
module,
|
||||
model_name="module_pre_split",
|
||||
frontend="torch",
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
compiled_module_str = str(
|
||||
compile_str(
|
||||
str(module),
|
||||
compile_file(
|
||||
module_file,
|
||||
target_backends=[backend],
|
||||
extra_args=[
|
||||
"--compile-to=flow",
|
||||
|
||||
@@ -451,6 +451,108 @@ def transform_fx(fx_g, quantized=False):
|
||||
fx_g.graph.lint()
|
||||
|
||||
|
||||
def gptq_transforms(fx_g):
|
||||
import torch
|
||||
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
torch.ops.aten.arange,
|
||||
torch.ops.aten.empty,
|
||||
torch.ops.aten.ones,
|
||||
torch.ops.aten._to_copy,
|
||||
]:
|
||||
if node.kwargs.get("device") == torch.device(device="cuda:0"):
|
||||
updated_kwargs = node.kwargs.copy()
|
||||
updated_kwargs["device"] = torch.device(device="cpu")
|
||||
node.kwargs = updated_kwargs
|
||||
|
||||
if node.target in [
|
||||
torch.ops.aten._to_copy,
|
||||
]:
|
||||
if node.kwargs.get("dtype") == torch.bfloat16:
|
||||
updated_kwargs = node.kwargs.copy()
|
||||
updated_kwargs["dtype"] = torch.float16
|
||||
node.kwargs = updated_kwargs
|
||||
|
||||
# Inputs of aten.native_layer_norm should be upcasted to fp32.
|
||||
if node.target in [torch.ops.aten.native_layer_norm]:
|
||||
with fx_g.graph.inserting_before(node):
|
||||
new_node_arg0 = fx_g.graph.call_function(
|
||||
torch.ops.prims.convert_element_type,
|
||||
args=(node.args[0], torch.float32),
|
||||
kwargs={},
|
||||
)
|
||||
node.args = (
|
||||
new_node_arg0,
|
||||
node.args[1],
|
||||
node.args[2],
|
||||
node.args[3],
|
||||
node.args[4],
|
||||
)
|
||||
|
||||
# Inputs of aten.mm should be upcasted to fp32.
|
||||
if node.target in [torch.ops.aten.mm]:
|
||||
with fx_g.graph.inserting_before(node):
|
||||
new_node_arg0 = fx_g.graph.call_function(
|
||||
torch.ops.prims.convert_element_type,
|
||||
args=(node.args[0], torch.float32),
|
||||
kwargs={},
|
||||
)
|
||||
new_node_arg1 = fx_g.graph.call_function(
|
||||
torch.ops.prims.convert_element_type,
|
||||
args=(node.args[1], torch.float32),
|
||||
kwargs={},
|
||||
)
|
||||
node.args = (new_node_arg0, new_node_arg1)
|
||||
|
||||
# Outputs of aten.mm should be downcasted to fp16.
|
||||
if type(node.args[0]) == torch.fx.node.Node and node.args[
|
||||
0
|
||||
].target in [torch.ops.aten.mm]:
|
||||
with fx_g.graph.inserting_before(node):
|
||||
tmp = node.args[0]
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten._to_copy,
|
||||
args=(node.args[0],),
|
||||
kwargs={"dtype": torch.float16},
|
||||
)
|
||||
node.args[0].append(new_node)
|
||||
node.args[0].replace_all_uses_with(new_node)
|
||||
new_node.args = (tmp,)
|
||||
new_node.kwargs = {"dtype": torch.float16}
|
||||
|
||||
# Inputs of aten._softmax should be upcasted to fp32.
|
||||
if node.target in [torch.ops.aten._softmax]:
|
||||
with fx_g.graph.inserting_before(node):
|
||||
new_node_arg0 = fx_g.graph.call_function(
|
||||
torch.ops.prims.convert_element_type,
|
||||
args=(node.args[0], torch.float32),
|
||||
kwargs={},
|
||||
)
|
||||
node.args = (new_node_arg0, node.args[1], node.args[2])
|
||||
|
||||
# Outputs of aten._softmax should be downcasted to fp16.
|
||||
if (
|
||||
type(node.args[0]) == torch.fx.node.Node
|
||||
and node.args[0].target in [torch.ops.aten._softmax]
|
||||
and node.target in [torch.ops.aten.expand]
|
||||
):
|
||||
with fx_g.graph.inserting_before(node):
|
||||
tmp = node.args[0]
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten._to_copy,
|
||||
args=(node.args[0],),
|
||||
kwargs={"dtype": torch.float16},
|
||||
)
|
||||
node.args[0].append(new_node)
|
||||
node.args[0].replace_all_uses_with(new_node)
|
||||
new_node.args = (tmp,)
|
||||
new_node.kwargs = {"dtype": torch.float16}
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
|
||||
# Doesn't replace the None type.
|
||||
def change_fx_graph_return_to_tuple(fx_g):
|
||||
for node in fx_g.graph.nodes:
|
||||
@@ -504,6 +606,7 @@ def import_with_fx(
|
||||
is_dynamic=False,
|
||||
tracing_required=False,
|
||||
precision="fp32",
|
||||
is_gptq=False,
|
||||
):
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
@@ -584,7 +687,7 @@ def import_with_fx(
|
||||
torch.ops.aten.index_add,
|
||||
torch.ops.aten.index_add_,
|
||||
]
|
||||
if precision in ["int4", "int8"]:
|
||||
if precision in ["int4", "int8"] and not is_gptq:
|
||||
from brevitas_examples.llm.llm_quant.export import (
|
||||
block_quant_layer_level_manager,
|
||||
)
|
||||
@@ -653,6 +756,10 @@ def import_with_fx(
|
||||
add_upcast(fx_g)
|
||||
fx_g.recompile()
|
||||
|
||||
if is_gptq:
|
||||
gptq_transforms(fx_g)
|
||||
fx_g.recompile()
|
||||
|
||||
if mlir_type == "fx":
|
||||
return fx_g
|
||||
|
||||
@@ -685,3 +792,25 @@ def import_with_fx(
|
||||
|
||||
mlir_module, func_name = mlir_importer.import_mlir(mlir_type=mlir_type)
|
||||
return mlir_module, func_name
|
||||
|
||||
|
||||
# Saves a .mlir module python object to the directory 'dir' with 'model_name' and returns a path to the saved file.
|
||||
def save_mlir(
|
||||
mlir_module,
|
||||
model_name,
|
||||
mlir_dialect="linalg",
|
||||
frontend="torch",
|
||||
dir=tempfile.gettempdir(),
|
||||
):
|
||||
model_name_mlir = (
|
||||
model_name + "_" + frontend + "_" + mlir_dialect + ".mlir"
|
||||
)
|
||||
if dir == "":
|
||||
dir = tempfile.gettempdir()
|
||||
mlir_path = os.path.join(dir, model_name_mlir)
|
||||
print(f"saving {model_name_mlir} to {dir}")
|
||||
if frontend == "torch":
|
||||
with open(mlir_path, "wb") as mlir_file:
|
||||
mlir_file.write(mlir_module)
|
||||
|
||||
return mlir_path
|
||||
|
||||
@@ -39,7 +39,7 @@ class SharkInference:
|
||||
Attributes
|
||||
----------
|
||||
mlir_module : str
|
||||
mlir_module represented in string; modules from torch-mlir are serialized in bytecode format.
|
||||
mlir_module or path represented in string; modules from torch-mlir are serialized in bytecode format.
|
||||
device : str
|
||||
device to execute the mlir_module on.
|
||||
currently supports cpu, cuda, vulkan, and metal backends.
|
||||
@@ -65,7 +65,7 @@ class SharkInference:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mlir_module: bytes,
|
||||
mlir_module,
|
||||
device: str = "none",
|
||||
mlir_dialect: str = "linalg",
|
||||
is_benchmark: bool = False,
|
||||
@@ -73,8 +73,17 @@ class SharkInference:
|
||||
dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
|
||||
device_idx: int = None,
|
||||
mmap: bool = True,
|
||||
rt_flags: list = [],
|
||||
):
|
||||
self.mlir_module = mlir_module
|
||||
if mlir_module is not None:
|
||||
if mlir_module and not os.path.isfile(mlir_module):
|
||||
print(
|
||||
"Warning: Initializing SharkInference with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to a MLIR module on your hard disk instead."
|
||||
)
|
||||
self.compile_str = True
|
||||
else:
|
||||
self.compile_str = False
|
||||
self.device = shark_args.device if device == "none" else device
|
||||
self.mlir_dialect = mlir_dialect
|
||||
self.is_benchmark = is_benchmark
|
||||
@@ -92,6 +101,7 @@ class SharkInference:
|
||||
|
||||
self.shark_runner = None
|
||||
self.mmap = mmap
|
||||
self.rt_flags = rt_flags
|
||||
|
||||
def compile(self, extra_args=[]):
|
||||
if self.dispatch_benchmarks is not None:
|
||||
@@ -126,6 +136,7 @@ class SharkInference:
|
||||
self.mlir_dialect,
|
||||
extra_args=extra_args,
|
||||
device_idx=self.device_idx,
|
||||
rt_flags=self.rt_flags,
|
||||
)
|
||||
|
||||
if self.dispatch_benchmarks is not None:
|
||||
@@ -139,11 +150,15 @@ class SharkInference:
|
||||
|
||||
# inputs are considered to be tuple of np.array.
|
||||
def __call__(self, function_name: str, inputs: tuple, send_to_host=True):
|
||||
return self.shark_runner.run(function_name, inputs, send_to_host)
|
||||
return self.shark_runner.run(
|
||||
function_name, inputs, send_to_host, device=self.device
|
||||
)
|
||||
|
||||
# forward function.
|
||||
def forward(self, inputs: tuple, send_to_host=True):
|
||||
return self.shark_runner.run("forward", inputs, send_to_host)
|
||||
return self.shark_runner.run(
|
||||
"forward", inputs, send_to_host, device=self.device
|
||||
)
|
||||
|
||||
# Get all function names defined within the compiled module.
|
||||
def get_functions_in_module(self):
|
||||
@@ -203,6 +218,7 @@ class SharkInference:
|
||||
module_name=module_name,
|
||||
extra_args=extra_args,
|
||||
debug=debug,
|
||||
compile_str=self.compile_str,
|
||||
)
|
||||
|
||||
# load and return the module.
|
||||
@@ -211,12 +227,14 @@ class SharkInference:
|
||||
device=self.device,
|
||||
compile_vmfb=False,
|
||||
extra_args=extra_args,
|
||||
rt_flags=self.rt_flags,
|
||||
)
|
||||
params = load_flatbuffer(
|
||||
path,
|
||||
self.device,
|
||||
self.device_idx,
|
||||
mmap=self.mmap,
|
||||
rt_flags=self.rt_flags,
|
||||
)
|
||||
self.shark_runner.iree_compilation_module = params["vmfb"]
|
||||
self.shark_runner.iree_config = params["config"]
|
||||
|
||||
@@ -45,7 +45,7 @@ class SharkRunner:
|
||||
Attributes
|
||||
----------
|
||||
mlir_module : str
|
||||
mlir_module represented in string.
|
||||
mlir_module path, string, or bytecode.
|
||||
device : str
|
||||
device to execute the mlir_module on.
|
||||
currently supports cpu, cuda, vulkan, and metal backends.
|
||||
@@ -72,12 +72,22 @@ class SharkRunner:
|
||||
extra_args: list = [],
|
||||
compile_vmfb: bool = True,
|
||||
device_idx: int = None,
|
||||
rt_flags: list = [],
|
||||
):
|
||||
self.mlir_module = mlir_module
|
||||
if self.mlir_module is not None:
|
||||
if not os.path.isfile(mlir_module):
|
||||
print(
|
||||
"Warning: Initializing SharkRunner with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to a MLIR module on your hard disk instead."
|
||||
)
|
||||
self.compile_str = True
|
||||
else:
|
||||
self.compile_str = False
|
||||
self.device = shark_args.device if device == "none" else device
|
||||
self.mlir_dialect = mlir_dialect
|
||||
self.extra_args = extra_args
|
||||
self.device_idx = device_idx
|
||||
self.rt_flags = rt_flags
|
||||
|
||||
if check_device_drivers(self.device):
|
||||
print(device_driver_info(self.device))
|
||||
@@ -91,13 +101,17 @@ class SharkRunner:
|
||||
self.mlir_dialect,
|
||||
extra_args=self.extra_args,
|
||||
device_idx=self.device_idx,
|
||||
rt_flags=self.rt_flags,
|
||||
compile_str=self.compile_str,
|
||||
)
|
||||
self.iree_compilation_module = params["vmfb"]
|
||||
self.iree_config = params["config"]
|
||||
self.temp_file_to_unlink = params["temp_file_to_unlink"]
|
||||
del params
|
||||
|
||||
def run(self, function_name, inputs: tuple, send_to_host=False):
|
||||
def run(
|
||||
self, function_name, inputs: tuple, send_to_host=False, device=None
|
||||
):
|
||||
return get_results(
|
||||
self.iree_compilation_module,
|
||||
function_name,
|
||||
@@ -105,6 +119,7 @@ class SharkRunner:
|
||||
self.iree_config,
|
||||
self.mlir_dialect,
|
||||
send_to_host,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Get all function names defined within the compiled module.
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
from shark.parser import shark_args
|
||||
from shark.shark_runner import SharkRunner
|
||||
from shark.backward_makefx import MakeFxModule
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import sys
|
||||
@@ -84,6 +84,12 @@ class SharkTrainer:
|
||||
training=True,
|
||||
mlir_type=mlir_type,
|
||||
)
|
||||
mlir_module = save_mlir(
|
||||
mlir_module,
|
||||
model_name="shark_model",
|
||||
frontend="torch",
|
||||
mlir_dialect=mlir_type,
|
||||
)
|
||||
self.shark_runner = SharkRunner(
|
||||
mlir_module,
|
||||
self.device,
|
||||
|
||||
@@ -1,24 +1,6 @@
|
||||
resnet50,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
albert-base-v2,stablehlo,tf,1e-2,1e-2,default,None,False,False,False,"",""
|
||||
roberta-base,stablehlo,tf,1e-02,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
|
||||
bert-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","enabled_windows"
|
||||
camembert-base,stablehlo,tf,1e-2,1e-3,default,None,True,True,True,"",""
|
||||
dbmdz/convbert-base-turkish-cased,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,True,True,False,"https://github.com/iree-org/iree/issues/9971",""
|
||||
distilbert-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
facebook/convnext-tiny-224,stablehlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,True,True,False,"https://github.com/nod-ai/SHARK/issues/311 & https://github.com/nod-ai/SHARK/issues/342","macos"
|
||||
funnel-transformer/small,stablehlo,tf,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/201",""
|
||||
google/electra-small-discriminator,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
google/mobilebert-uncased,stablehlo,tf,1e-2,1e-3,default,None,True,False,False,"Fails during iree-compile","macos"
|
||||
google/vit-base-patch16-224,stablehlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,False,False,False,"",""
|
||||
microsoft/MiniLM-L12-H384-uncased,stablehlo,tf,1e-2,1e-3,tf_hf,None,True,False,False,"Fails during iree-compile.",""
|
||||
microsoft/layoutlm-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
microsoft/mpnet-base,stablehlo,tf,1e-2,1e-2,default,None,True,True,True,"",""
|
||||
alexnet,linalg,torch,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/879",""
|
||||
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,True,True,"",""
|
||||
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
bert-large-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails during iree-compile.",""
|
||||
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311",""
|
||||
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
|
||||
@@ -32,14 +14,8 @@ resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,True,True,"Numerics issues, awaiting cuda-independent fp16 integration",""
|
||||
squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
|
||||
efficientnet-v2-s,stablehlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
|
||||
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
|
||||
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
|
||||
efficientnet_b0,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
|
||||
efficientnet_b7,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
|
||||
gpt2,stablehlo,tf,1e-2,1e-3,default,None,True,False,False,"","macos"
|
||||
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.","macos"
|
||||
t5-base,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"
|
||||
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported","macos"
|
||||
t5-large,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"
|
||||
|
||||
|
@@ -1,4 +1,3 @@
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_model
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -1,30 +1,25 @@
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from shark_opt_wrapper import OPTForCausalLMModel
|
||||
from shark.iree_utils._common import (
|
||||
check_device_drivers,
|
||||
device_driver_info,
|
||||
)
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from transformers import AutoTokenizer, OPTForCausalLM
|
||||
|
||||
OPT_MODEL = "opt-1.3b"
|
||||
OPT_FS_NAME = "opt-1_3b"
|
||||
MAX_SEQUENCE_LENGTH = 128
|
||||
MAX_NEW_TOKENS = 60
|
||||
from typing import Iterable
|
||||
|
||||
|
||||
def create_module(model_name, tokenizer, device):
|
||||
opt_base_model = OPTForCausalLM.from_pretrained("facebook/" + model_name)
|
||||
def create_module(model_name, tokenizer, device, args):
|
||||
opt_base_model = OPTForCausalLM.from_pretrained(
|
||||
model_name, allow_mismatched_sizes=True
|
||||
)
|
||||
opt_base_model.eval()
|
||||
opt_model = OPTForCausalLMModel(opt_base_model)
|
||||
encoded_inputs = tokenizer(
|
||||
"What is the meaning of life?",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
max_length=args.max_seq_len,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = (
|
||||
@@ -33,32 +28,34 @@ def create_module(model_name, tokenizer, device):
|
||||
)
|
||||
# np.save("model_inputs_0.npy", inputs[0])
|
||||
# np.save("model_inputs_1.npy", inputs[1])
|
||||
opt_fs_name = "-".join(
|
||||
"_".join(args.model_name.split("/")[1].split("-")).split(".")
|
||||
)
|
||||
|
||||
mlir_path = f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
|
||||
mlir_path = f"./{opt_fs_name}_causallm_{args.max_seq_len}_torch.mlir"
|
||||
if os.path.isfile(mlir_path):
|
||||
with open(mlir_path, "r") as f:
|
||||
model_mlir = f.read()
|
||||
print(f"Loaded .mlir from {mlir_path}")
|
||||
print(f"Found .mlir from {mlir_path}")
|
||||
else:
|
||||
(model_mlir, func_name) = import_with_fx(
|
||||
model=opt_model,
|
||||
inputs=inputs,
|
||||
is_f16=False,
|
||||
model_name=OPT_FS_NAME,
|
||||
model_name=opt_fs_name,
|
||||
return_str=True,
|
||||
)
|
||||
with open(mlir_path, "w") as f:
|
||||
f.write(model_mlir)
|
||||
print(f"Saved mlir at {mlir_path}")
|
||||
del model_mlir
|
||||
|
||||
shark_module = SharkInference(
|
||||
model_mlir,
|
||||
mlir_path,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
is_benchmark=False,
|
||||
)
|
||||
|
||||
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
|
||||
vmfb_name = f"{opt_fs_name}_causallm_{args.max_seq_len}_torch_cpu"
|
||||
shark_module.save_module(module_name=vmfb_name, debug=False)
|
||||
vmfb_path = vmfb_name + ".vmfb"
|
||||
return vmfb_path
|
||||
@@ -72,11 +69,11 @@ def shouldStop(tokens):
|
||||
return False
|
||||
|
||||
|
||||
def generate_new_token(shark_model, tokenizer, new_text):
|
||||
def generate_new_token(shark_module, tokenizer, new_text, max_seq_len: int):
|
||||
model_inputs = tokenizer(
|
||||
new_text,
|
||||
padding="max_length",
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
max_length=max_seq_len,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
@@ -85,7 +82,7 @@ def generate_new_token(shark_model, tokenizer, new_text):
|
||||
model_inputs["attention_mask"],
|
||||
)
|
||||
sum_attentionmask = torch.sum(model_inputs.attention_mask)
|
||||
output = shark_model("forward", inputs)
|
||||
output = shark_module("forward", inputs)
|
||||
output = torch.FloatTensor(output[0])
|
||||
next_toks = torch.topk(output, 1)
|
||||
stop_generation = False
|
||||
@@ -105,39 +102,96 @@ def generate_new_token(shark_model, tokenizer, new_text):
|
||||
return ret_dict
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--max-seq-len", type=int, default=32)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
help="Model name",
|
||||
type=str,
|
||||
choices=[
|
||||
"facebook/opt-125m",
|
||||
"facebook/opt-350m",
|
||||
"facebook/opt-1.3b",
|
||||
"facebook/opt-6.7b",
|
||||
"mit-han-lab/opt-125m-smoothquant",
|
||||
"mit-han-lab/opt-1.3b-smoothquant",
|
||||
"mit-han-lab/opt-2.7b-smoothquant",
|
||||
"mit-han-lab/opt-6.7b-smoothquant",
|
||||
"mit-han-lab/opt-13b-smoothquant",
|
||||
],
|
||||
default="facebook/opt-1.3b",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recompile",
|
||||
help="If set, recompiles MLIR -> .vmfb",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plugin-path",
|
||||
help="path to executable plugin",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print("args={}".format(args))
|
||||
return args
|
||||
|
||||
|
||||
def generate_tokens(
|
||||
opt_shark_module: "SharkInference",
|
||||
tokenizer,
|
||||
input_text: str,
|
||||
max_output_len: int,
|
||||
print_intermediate_results: True,
|
||||
) -> Iterable[str]:
|
||||
words_list = []
|
||||
new_text = input_text
|
||||
try:
|
||||
for _ in range(max_output_len):
|
||||
generated_token_op = generate_new_token(
|
||||
opt_shark_module, tokenizer, new_text, max_output_len
|
||||
)
|
||||
detok = generated_token_op["detok"]
|
||||
if generated_token_op["stop_generation"]:
|
||||
break
|
||||
if print_intermediate_results:
|
||||
print(detok, end="", flush=True)
|
||||
words_list.append(detok)
|
||||
if detok == "":
|
||||
break
|
||||
new_text += detok
|
||||
except KeyboardInterrupt as e:
|
||||
print("Exiting token generation.")
|
||||
return words_list
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"facebook/" + OPT_MODEL, use_fast=False
|
||||
args = parse_args()
|
||||
if "smoothquant" in args.model_name:
|
||||
token_model_name = f"facebook/opt-{args.model_name.split('-')[3]}"
|
||||
else:
|
||||
token_model_name = args.model_name
|
||||
tokenizer = AutoTokenizer.from_pretrained(token_model_name, use_fast=False)
|
||||
opt_fs_name = "-".join(
|
||||
"_".join(args.model_name.split("/")[1].split("-")).split(".")
|
||||
)
|
||||
vmfb_path = (
|
||||
f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_cpu-task.vmfb"
|
||||
vmfb_path = f"./{opt_fs_name}_causallm_{args.max_seq_len}_torch_cpu.vmfb"
|
||||
if args.plugin_path is not None:
|
||||
rt_flags = [f"--executable_plugin={args.plugin_path}"]
|
||||
else:
|
||||
rt_flags = []
|
||||
opt_shark_module = SharkInference(
|
||||
mlir_module=None, device="cpu-task", rt_flags=rt_flags
|
||||
)
|
||||
opt_shark_module = SharkInference(mlir_module=None, device="cpu-task")
|
||||
if os.path.isfile(vmfb_path):
|
||||
opt_shark_module.load_module(vmfb_path)
|
||||
else:
|
||||
vmfb_path = create_module(OPT_MODEL, tokenizer, "cpu-task")
|
||||
vmfb_path = create_module(args.model_name, tokenizer, "cpu-task", args)
|
||||
opt_shark_module.load_module(vmfb_path)
|
||||
while True:
|
||||
try:
|
||||
new_text = input("Give me a sentence to complete:")
|
||||
new_text_init = new_text
|
||||
words_list = []
|
||||
|
||||
for i in range(MAX_NEW_TOKENS):
|
||||
generated_token_op = generate_new_token(
|
||||
opt_shark_module, tokenizer, new_text
|
||||
)
|
||||
detok = generated_token_op["detok"]
|
||||
stop_generation = generated_token_op["stop_generation"]
|
||||
if stop_generation:
|
||||
break
|
||||
print(detok, end="", flush=True)
|
||||
words_list.append(detok)
|
||||
if detok == "":
|
||||
break
|
||||
new_text = new_text + detok
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Exiting program.")
|
||||
break
|
||||
input_text = input("Give me a sentence to complete:")
|
||||
generate_tokens(
|
||||
opt_shark_module, tokenizer, input_text, args.max_seq_len
|
||||
)
|
||||
|
||||
74
tank/examples/opt/opt_causallm_samples.py
Normal file
74
tank/examples/opt/opt_causallm_samples.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import opt_causallm
|
||||
import opt_util
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
from transformers import AutoTokenizer, OPTForCausalLM
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--max-seq-len", type=int, default=32)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
help="Model name",
|
||||
type=str,
|
||||
choices=[
|
||||
"facebook/opt-125m",
|
||||
"facebook/opt-350m",
|
||||
"facebook/opt-1.3b",
|
||||
"facebook/opt-6.7b",
|
||||
],
|
||||
default="facebook/opt-1.3b",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recompile",
|
||||
help="If set, recompiles MLIR -> .vmfb",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plugin-path",
|
||||
help="path to executable plugin",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print("args={}".format(args))
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False)
|
||||
opt_fs_name = "-".join(
|
||||
"_".join(args.model_name.split("/")[1].split("-")).split(".")
|
||||
)
|
||||
vmfb_path = f"./{opt_fs_name}_causallm_{args.max_seq_len}_torch_cpu.vmfb"
|
||||
if args.plugin_path is not None:
|
||||
rt_flags = [f"--executable_plugin={args.plugin_path}"]
|
||||
else:
|
||||
rt_flags = []
|
||||
opt_shark_module = SharkInference(
|
||||
mlir_module=None, device="cpu-task", rt_flags=rt_flags
|
||||
)
|
||||
if os.path.isfile(vmfb_path):
|
||||
opt_shark_module.load_module(vmfb_path)
|
||||
else:
|
||||
vmfb_path = opt_causallm.create_module(
|
||||
args.model_name, tokenizer, "cpu-task", args
|
||||
)
|
||||
opt_shark_module.load_module(vmfb_path)
|
||||
|
||||
for prompt in opt_util.PROMPTS:
|
||||
print("\n\nprompt: {}".format(prompt))
|
||||
response = opt_causallm.generate_tokens(
|
||||
opt_shark_module,
|
||||
tokenizer,
|
||||
prompt,
|
||||
args.max_seq_len,
|
||||
print_intermediate_results=False,
|
||||
)
|
||||
print("response: {}".format("".join(response)))
|
||||
@@ -6,7 +6,7 @@ import numpy as np
|
||||
from shark_opt_wrapper import OPTForCausalLMModel
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
from transformers import AutoTokenizer, OPTForCausalLM
|
||||
|
||||
OPT_MODEL = "facebook/opt-1.3b"
|
||||
@@ -57,9 +57,10 @@ class OPTModuleTester:
|
||||
with open(mlir_path, "w") as f:
|
||||
f.write(mlir_module)
|
||||
print(f"Saved mlir at {mlir_path}")
|
||||
del mlir_module
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
mlir_path,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
is_benchmark=self.benchmark,
|
||||
|
||||
@@ -18,14 +18,17 @@ import collections
|
||||
import json
|
||||
import os
|
||||
import psutil
|
||||
import resource
|
||||
import time
|
||||
import numpy as np
|
||||
from typing import Tuple
|
||||
|
||||
from opt_util import PROMPTS
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from transformers import AutoTokenizer, OPTForCausalLM
|
||||
from shark_opt_wrapper import OPTForCausalLMModel
|
||||
from shark.parser import shark_args
|
||||
import iree.compiler as ireec
|
||||
|
||||
DEVICE = "cpu"
|
||||
PLATFORM_SHARK = "shark"
|
||||
@@ -42,19 +45,6 @@ REPORT_LOAD_VIRTUAL_MEMORY_MB = "load_virtual_MB"
|
||||
REPORT_RUN_PHYSICAL_MEMORY_MB = "run_physical_MB"
|
||||
REPORT_RUN_VIRTUAL_MEMORY_MB = "run_virtual_MB"
|
||||
|
||||
PROMPTS = [
|
||||
"What is the meaning of life?",
|
||||
"Tell me something you don't know.",
|
||||
"What does Xilinx do?",
|
||||
"What is the mass of earth?",
|
||||
"What is a poem?",
|
||||
"What is recursion?",
|
||||
"Tell me a one line joke.",
|
||||
"Who is Gilgamesh?",
|
||||
"Tell me something about cryptocurrency.",
|
||||
"How did it all begin?",
|
||||
]
|
||||
|
||||
ModelWrapper = collections.namedtuple("ModelWrapper", ["model", "tokenizer"])
|
||||
|
||||
|
||||
@@ -64,14 +54,15 @@ def get_memory_info():
|
||||
return process.memory_info()
|
||||
|
||||
|
||||
def create_vmfb_module(
|
||||
def import_mlir_module(
|
||||
model_name: str,
|
||||
tokenizer,
|
||||
device: str,
|
||||
max_seq_len: int,
|
||||
recompile_shark: bool,
|
||||
):
|
||||
opt_base_model = OPTForCausalLM.from_pretrained(model_name)
|
||||
opt_base_model = OPTForCausalLM.from_pretrained(
|
||||
model_name, ignore_mismatched_sizes=True
|
||||
)
|
||||
opt_base_model.eval()
|
||||
opt_model = OPTForCausalLMModel(opt_base_model)
|
||||
encoded_inputs = tokenizer(
|
||||
@@ -88,6 +79,27 @@ def create_vmfb_module(
|
||||
# np.save("model_inputs_0.npy", inputs[0])
|
||||
# np.save("model_inputs_1.npy", inputs[1])
|
||||
|
||||
opt_fs_name = get_opt_fs_name(model_name)
|
||||
mlir_path = f"./{opt_fs_name}_causallm_{max_seq_len}_torch.mlir"
|
||||
(model_mlir, func_name) = import_with_fx(
|
||||
model=opt_model,
|
||||
inputs=inputs,
|
||||
is_f16=False,
|
||||
model_name=opt_fs_name,
|
||||
return_str=True,
|
||||
)
|
||||
with open(mlir_path, "w") as f:
|
||||
f.write(model_mlir)
|
||||
print(f"Saved mlir at {mlir_path}")
|
||||
|
||||
|
||||
def create_vmfb_module(
|
||||
model_name: str,
|
||||
tokenizer,
|
||||
device: str,
|
||||
max_seq_len: int,
|
||||
recompile_shark: bool,
|
||||
):
|
||||
opt_fs_name = get_opt_fs_name(model_name)
|
||||
mlir_path = f"./{opt_fs_name}_causallm_{max_seq_len}_torch.mlir"
|
||||
# If MLIR has already been loaded and recompilation is not requested, use
|
||||
@@ -97,49 +109,49 @@ def create_vmfb_module(
|
||||
# compilation time can be correctly measured only when MLIR has already been
|
||||
# loaded.
|
||||
assert not recompile_shark or has_mlir
|
||||
if has_mlir:
|
||||
with open(mlir_path, "r") as f:
|
||||
model_mlir = f.read()
|
||||
print(f"Loaded .mlir from {mlir_path}")
|
||||
else:
|
||||
(model_mlir, func_name) = import_with_fx(
|
||||
model=opt_model,
|
||||
inputs=inputs,
|
||||
is_f16=False,
|
||||
model_name=opt_fs_name,
|
||||
return_str=True,
|
||||
if not has_mlir:
|
||||
import_mlir_module(
|
||||
model_name,
|
||||
tokenizer,
|
||||
device,
|
||||
max_seq_len,
|
||||
)
|
||||
with open(mlir_path, "w") as f:
|
||||
f.write(model_mlir)
|
||||
print(f"Saved mlir at {mlir_path}")
|
||||
|
||||
shark_module = SharkInference(
|
||||
model_mlir,
|
||||
mlir_path,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
is_benchmark=False,
|
||||
rt_flags=[],
|
||||
)
|
||||
|
||||
vmfb_name = (
|
||||
f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}_tiled_ukernels"
|
||||
)
|
||||
vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}"
|
||||
shark_module.save_module(module_name=vmfb_name)
|
||||
vmfb_path = vmfb_name + ".vmfb"
|
||||
return vmfb_path
|
||||
|
||||
|
||||
def load_shark_model(
|
||||
model_name: str, max_seq_len: int, recompile_shark: bool
|
||||
model_name: str,
|
||||
token_model_name: str,
|
||||
max_seq_len: int,
|
||||
recompile_shark: bool,
|
||||
plugin_path: str = [],
|
||||
) -> ModelWrapper:
|
||||
opt_fs_name = get_opt_fs_name(model_name)
|
||||
vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}_tiled_ukernels.vmfb"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
||||
vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}.vmfb"
|
||||
tokenizer = AutoTokenizer.from_pretrained(token_model_name, use_fast=False)
|
||||
if recompile_shark or not os.path.isfile(vmfb_name):
|
||||
print(f"vmfb not found. compiling and saving to {vmfb_name}")
|
||||
create_vmfb_module(
|
||||
model_name, tokenizer, DEVICE, max_seq_len, recompile_shark
|
||||
)
|
||||
shark_module = SharkInference(mlir_module=None, device="cpu-task")
|
||||
if plugin_path is not None:
|
||||
rt_flags = [f"--executable_plugin={plugin_path}"]
|
||||
else:
|
||||
rt_flags = []
|
||||
shark_module = SharkInference(
|
||||
mlir_module=None, device="cpu-task", rt_flags=rt_flags
|
||||
)
|
||||
shark_module.load_module(vmfb_name)
|
||||
return ModelWrapper(model=shark_module, tokenizer=tokenizer)
|
||||
|
||||
@@ -149,10 +161,12 @@ def run_shark_model(model_wrapper: ModelWrapper, tokens):
|
||||
return model_wrapper.model("forward", tokens)
|
||||
|
||||
|
||||
def load_huggingface_model(model_name: str) -> ModelWrapper:
|
||||
def load_huggingface_model(
|
||||
model_name: str, token_model_name: str
|
||||
) -> ModelWrapper:
|
||||
return ModelWrapper(
|
||||
model=OPTForCausalLM.from_pretrained(model_name),
|
||||
tokenizer=AutoTokenizer.from_pretrained(model_name),
|
||||
tokenizer=AutoTokenizer.from_pretrained(token_model_name),
|
||||
)
|
||||
|
||||
|
||||
@@ -168,11 +182,14 @@ def save_json(data, filename):
|
||||
|
||||
|
||||
def collect_huggingface_logits(
|
||||
model_name: str, max_seq_len: int, to_save_json: bool
|
||||
model_name: str,
|
||||
token_model_name: str,
|
||||
max_seq_len: int,
|
||||
to_save_json: bool,
|
||||
) -> Tuple[float, float]:
|
||||
# Load
|
||||
t0 = time.time()
|
||||
model_wrapper = load_huggingface_model(model_name)
|
||||
model_wrapper = load_huggingface_model(model_name, token_model_name)
|
||||
load_time = time.time() - t0
|
||||
print("--- Took {} seconds to load Huggingface.".format(load_time))
|
||||
load_memory_info = get_memory_info()
|
||||
@@ -216,13 +233,17 @@ def collect_huggingface_logits(
|
||||
|
||||
def collect_shark_logits(
|
||||
model_name: str,
|
||||
token_model_name: str,
|
||||
max_seq_len: int,
|
||||
recompile_shark: bool,
|
||||
to_save_json: bool,
|
||||
plugin_path: str,
|
||||
) -> Tuple[float, float]:
|
||||
# Load
|
||||
t0 = time.time()
|
||||
model_wrapper = load_shark_model(model_name, max_seq_len, recompile_shark)
|
||||
model_wrapper = load_shark_model(
|
||||
model_name, token_model_name, max_seq_len, recompile_shark, plugin_path
|
||||
)
|
||||
load_time = time.time() - t0
|
||||
print("--- Took {} seconds to load Shark.".format(load_time))
|
||||
load_memory_info = get_memory_info()
|
||||
@@ -303,6 +324,11 @@ def parse_args():
|
||||
"facebook/opt-350m",
|
||||
"facebook/opt-1.3b",
|
||||
"facebook/opt-6.7b",
|
||||
"mit-han-lab/opt-125m-smoothquant",
|
||||
"mit-han-lab/opt-1.3b-smoothquant",
|
||||
"mit-han-lab/opt-2.7b-smoothquant",
|
||||
"mit-han-lab/opt-6.7b-smoothquant",
|
||||
"mit-han-lab/opt-13b-smoothquant",
|
||||
],
|
||||
default="facebook/opt-1.3b",
|
||||
)
|
||||
@@ -319,6 +345,18 @@ def parse_args():
|
||||
choices=[PLATFORM_SHARK, PLATFORM_HUGGINGFACE],
|
||||
default=PLATFORM_SHARK,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plugin-path",
|
||||
help="path to executable plugin",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token-model-name",
|
||||
help="HF ID to create tokenizer.",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print("args={}".format(args))
|
||||
return args
|
||||
@@ -326,16 +364,28 @@ def parse_args():
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
if args.token_model_name == None:
|
||||
if "smoothquant" in args.model_name:
|
||||
args.token_model_name = (
|
||||
f"facebook/opt-{args.model_name.split('-')[3]}"
|
||||
)
|
||||
else:
|
||||
args.token_model_name = args.model_name
|
||||
if args.platform == PLATFORM_SHARK:
|
||||
shark_report = collect_shark_logits(
|
||||
args.model_name,
|
||||
args.token_model_name,
|
||||
args.max_seq_len,
|
||||
args.recompile_shark,
|
||||
args.save_json,
|
||||
args.plugin_path,
|
||||
)
|
||||
print("# Summary: {}".format(json.dumps(shark_report)))
|
||||
else:
|
||||
huggingface_report = collect_huggingface_logits(
|
||||
args.model_name, args.max_seq_len, args.save_json
|
||||
args.model_name,
|
||||
args.token_model_name,
|
||||
args.max_seq_len,
|
||||
args.save_json,
|
||||
)
|
||||
print("# Summary: {}".format(json.dumps(huggingface_report)))
|
||||
|
||||
12
tank/examples/opt/opt_util.py
Normal file
12
tank/examples/opt/opt_util.py
Normal file
@@ -0,0 +1,12 @@
|
||||
PROMPTS = [
|
||||
"What is the meaning of life?",
|
||||
"Tell me something you don't know.",
|
||||
"What does Xilinx do?",
|
||||
"What is the mass of earth?",
|
||||
"What is a poem?",
|
||||
"What is recursion?",
|
||||
"Tell me a one line joke.",
|
||||
"Who is Gilgamesh?",
|
||||
"Tell me something about cryptocurrency.",
|
||||
"How did it all begin?",
|
||||
]
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
import torch
|
||||
from transformers import AutoTokenizer, OPTForCausalLM
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
from shark_opt_wrapper import OPTForCausalLMModel
|
||||
|
||||
model_name = "facebook/opt-1.3b"
|
||||
@@ -25,11 +25,13 @@ inputs = (
|
||||
model=model,
|
||||
inputs=inputs,
|
||||
is_f16=False,
|
||||
debug=True,
|
||||
model_name=model_name.split("/")[1],
|
||||
save_dir=".",
|
||||
)
|
||||
|
||||
mlir_module = save_mlir(
|
||||
mlir_module,
|
||||
model_name=model_name.split("/")[1],
|
||||
frontend="torch",
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device="cpu-sync",
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_model
|
||||
from tank.test_utils import get_valid_test_params, shark_test_name_func
|
||||
|
||||
@@ -44,7 +44,7 @@ class TapasBaseModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("cuda"), reason=device_driver_info("gpu")
|
||||
check_device_drivers("cuda"), reason=device_driver_info("cuda")
|
||||
)
|
||||
def test_module_static_cuda(self):
|
||||
dynamic = False
|
||||
|
||||
@@ -36,7 +36,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
|
||||
get_hf_img_cls_model,
|
||||
get_fp16_model,
|
||||
)
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
|
||||
with open(torch_model_list) as csvfile:
|
||||
torch_reader = csv.reader(csvfile, delimiter=",")
|
||||
@@ -130,133 +130,6 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
|
||||
)
|
||||
|
||||
|
||||
def save_tf_model(tf_model_list, local_tank_cache, import_args):
|
||||
from tank.model_utils_tf import (
|
||||
get_causal_image_model,
|
||||
get_masked_lm_model,
|
||||
get_causal_lm_model,
|
||||
get_keras_model,
|
||||
get_TFhf_model,
|
||||
get_tfhf_seq2seq_model,
|
||||
)
|
||||
import os
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
|
||||
import tensorflow as tf
|
||||
|
||||
visible_default = tf.config.list_physical_devices("GPU")
|
||||
try:
|
||||
tf.config.set_visible_devices([], "GPU")
|
||||
visible_devices = tf.config.get_visible_devices()
|
||||
for device in visible_devices:
|
||||
assert device.device_type != "GPU"
|
||||
except:
|
||||
# Invalid device or cannot modify virtual devices once initialized.
|
||||
pass
|
||||
|
||||
with open(tf_model_list) as csvfile:
|
||||
tf_reader = csv.reader(csvfile, delimiter=",")
|
||||
fields = next(tf_reader)
|
||||
for row in tf_reader:
|
||||
tf_model_name = row[0]
|
||||
model_type = row[1]
|
||||
|
||||
model = None
|
||||
input = None
|
||||
print(f"Generating artifacts for model {tf_model_name}")
|
||||
if model_type == "hf":
|
||||
model, input, _ = get_masked_lm_model(
|
||||
tf_model_name, import_args
|
||||
)
|
||||
elif model_type == "img":
|
||||
model, input, _ = get_causal_image_model(
|
||||
tf_model_name, import_args
|
||||
)
|
||||
elif model_type == "keras":
|
||||
model, input, _ = get_keras_model(tf_model_name, import_args)
|
||||
elif model_type == "TFhf":
|
||||
model, input, _ = get_TFhf_model(tf_model_name, import_args)
|
||||
elif model_type == "tfhf_seq2seq":
|
||||
model, input, _ = get_tfhf_seq2seq_model(
|
||||
tf_model_name, import_args
|
||||
)
|
||||
elif model_type == "hf_causallm":
|
||||
model, input, _ = get_causal_lm_model(
|
||||
tf_model_name, import_args
|
||||
)
|
||||
|
||||
tf_model_name = tf_model_name.replace("/", "_")
|
||||
if import_args["batch_size"] != 1:
|
||||
tf_model_dir = os.path.join(
|
||||
local_tank_cache,
|
||||
str(tf_model_name)
|
||||
+ "_tf"
|
||||
+ f"_BS{str(import_args['batch_size'])}",
|
||||
)
|
||||
else:
|
||||
tf_model_dir = os.path.join(
|
||||
local_tank_cache, str(tf_model_name) + "_tf"
|
||||
)
|
||||
os.makedirs(tf_model_dir, exist_ok=True)
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
inputs=input,
|
||||
frontend="tf",
|
||||
)
|
||||
mlir_importer.import_debug(
|
||||
is_dynamic=False,
|
||||
dir=tf_model_dir,
|
||||
model_name=tf_model_name,
|
||||
)
|
||||
|
||||
|
||||
def save_tflite_model(tflite_model_list, local_tank_cache, import_args):
|
||||
from shark.tflite_utils import TFLitePreprocessor
|
||||
|
||||
with open(tflite_model_list) as csvfile:
|
||||
tflite_reader = csv.reader(csvfile, delimiter=",")
|
||||
for row in tflite_reader:
|
||||
print("\n")
|
||||
tflite_model_name = row[0]
|
||||
tflite_model_link = row[1]
|
||||
print("tflite_model_name", tflite_model_name)
|
||||
print("tflite_model_link", tflite_model_link)
|
||||
tflite_model_name_dir = os.path.join(
|
||||
local_tank_cache, str(tflite_model_name) + "_tflite"
|
||||
)
|
||||
os.makedirs(tflite_model_name_dir, exist_ok=True)
|
||||
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
|
||||
|
||||
# Preprocess to get SharkImporter input import_args
|
||||
tflite_preprocessor = TFLitePreprocessor(str(tflite_model_name))
|
||||
raw_model_file_path = tflite_preprocessor.get_raw_model_file()
|
||||
inputs = tflite_preprocessor.get_inputs()
|
||||
tflite_interpreter = tflite_preprocessor.get_interpreter()
|
||||
|
||||
# Use SharkImporter to get SharkInference input import_args
|
||||
my_shark_importer = SharkImporter(
|
||||
module=tflite_interpreter,
|
||||
inputs=inputs,
|
||||
frontend="tflite",
|
||||
raw_model_file=raw_model_file_path,
|
||||
)
|
||||
my_shark_importer.import_debug(
|
||||
dir=tflite_model_name_dir,
|
||||
model_name=tflite_model_name,
|
||||
func_name="main",
|
||||
)
|
||||
mlir_hash = create_hash(
|
||||
os.path.join(
|
||||
tflite_model_name_dir,
|
||||
tflite_model_name + "_tflite" + ".mlir",
|
||||
)
|
||||
)
|
||||
np.save(
|
||||
os.path.join(tflite_model_name_dir, "hash"),
|
||||
np.array(mlir_hash),
|
||||
)
|
||||
|
||||
|
||||
def check_requirements(frontend):
|
||||
import importlib
|
||||
|
||||
@@ -265,10 +138,6 @@ def check_requirements(frontend):
|
||||
tv_spec = importlib.util.find_spec("torchvision")
|
||||
has_pkgs = tv_spec is not None
|
||||
|
||||
elif frontend in ["tensorflow", "tf"]:
|
||||
tf_spec = importlib.util.find_spec("tensorflow")
|
||||
has_pkgs = tf_spec is not None
|
||||
|
||||
return has_pkgs
|
||||
|
||||
|
||||
@@ -287,27 +156,11 @@ def gen_shark_files(modelname, frontend, tank_dir, importer_args):
|
||||
torch_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "torch_model_list.csv"
|
||||
)
|
||||
tf_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tf_model_list.csv"
|
||||
)
|
||||
custom_model_csv = tempfile.NamedTemporaryFile(
|
||||
dir=os.path.dirname(__file__),
|
||||
delete=True,
|
||||
)
|
||||
# Create a temporary .csv with only the desired entry.
|
||||
if frontend == "tf":
|
||||
with open(tf_model_csv, mode="r") as src:
|
||||
reader = csv.reader(src)
|
||||
for row in reader:
|
||||
if row[0] == modelname:
|
||||
target = row
|
||||
with open(custom_model_csv.name, mode="w") as trg:
|
||||
writer = csv.writer(trg)
|
||||
writer.writerow(["modelname", "src"])
|
||||
writer.writerow(target)
|
||||
save_tf_model(custom_model_csv.name, tank_dir, import_args)
|
||||
|
||||
elif frontend == "torch":
|
||||
if frontend == "torch":
|
||||
with open(torch_model_csv, mode="r") as src:
|
||||
reader = csv.reader(src)
|
||||
for row in reader:
|
||||
@@ -341,18 +194,6 @@ if __name__ == "__main__":
|
||||
# Please see: https://github.com/nod-ai/SHARK/blob/main/tank/torch_model_list.csv""",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--tf_model_csv",
|
||||
# type=lambda x: is_valid_file(x),
|
||||
# default="./tank/tf_model_list.csv",
|
||||
# help="Contains the file with tf model name and args.",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--tflite_model_csv",
|
||||
# type=lambda x: is_valid_file(x),
|
||||
# default="./tank/tflite/tflite_model_list.csv",
|
||||
# help="Contains the file with tf model name and args.",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--ci_tank_dir",
|
||||
# type=bool,
|
||||
# default=False,
|
||||
@@ -369,11 +210,5 @@ if __name__ == "__main__":
|
||||
torch_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "torch_model_list.csv"
|
||||
)
|
||||
tf_model_csv = os.path.join(os.path.dirname(__file__), "tf_model_list.csv")
|
||||
tflite_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tflite", "tflite_model_list.csv"
|
||||
)
|
||||
|
||||
save_torch_model(torch_model_csv, WORKDIR, import_args)
|
||||
# save_tf_model(tf_model_csv, WORKDIR, import_args)
|
||||
# save_tflite_model(tflite_model_csv, WORKDIR, import_args)
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
model_name, model_type
|
||||
albert-base-v2,hf
|
||||
bert-base-uncased,hf
|
||||
camembert-base,hf
|
||||
dbmdz/convbert-base-turkish-cased,hf
|
||||
distilbert-base-uncased,hf
|
||||
google/electra-small-discriminator,hf
|
||||
funnel-transformer/small,hf
|
||||
microsoft/layoutlm-base-uncased,hf
|
||||
google/mobilebert-uncased,hf
|
||||
microsoft/mpnet-base,hf
|
||||
roberta-base,hf
|
||||
resnet50,keras
|
||||
xlm-roberta-base,hf
|
||||
microsoft/MiniLM-L12-H384-uncased,TFhf
|
||||
funnel-transformer/small,hf
|
||||
microsoft/mpnet-base,hf
|
||||
facebook/convnext-tiny-224,img
|
||||
google/vit-base-patch16-224,img
|
||||
efficientnet-v2-s,keras
|
||||
bert-large-uncased,hf
|
||||
t5-base,tfhf_seq2seq
|
||||
t5-large,tfhf_seq2seq
|
||||
efficientnet_b0,keras
|
||||
efficientnet_b7,keras
|
||||
gpt2,hf_causallm
|
||||
t5-base,tfhf_seq2seq
|
||||
t5-large,tfhf_seq2seq
|
||||
|
Reference in New Issue
Block a user