mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
156 Commits
20230418.6
...
ean-opt-tu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dd40a3fafe | ||
|
|
bf6fcc353a | ||
|
|
918eba6524 | ||
|
|
d61b6641fb | ||
|
|
88cc2423cc | ||
|
|
ccf944c1bd | ||
|
|
0def74f520 | ||
|
|
3fb72e192e | ||
|
|
855435ee24 | ||
|
|
6f9f868fc0 | ||
|
|
fb865f1b99 | ||
|
|
3e5c50f07b | ||
|
|
a544f30a8f | ||
|
|
1fe56d460a | ||
|
|
fafd713141 | ||
|
|
015d0132c3 | ||
|
|
20ddd96ef7 | ||
|
|
ee33cfd2d1 | ||
|
|
a3cba21d5b | ||
|
|
a7b6ec4095 | ||
|
|
d80b087d95 | ||
|
|
297a209608 | ||
|
|
b204113563 | ||
|
|
f60ab1f4fa | ||
|
|
b203779462 | ||
|
|
38570a9bbb | ||
|
|
a5c882f296 | ||
|
|
eb6d11cfed | ||
|
|
46184a81ac | ||
|
|
149165a2f0 | ||
|
|
bec82a665f | ||
|
|
9551490341 | ||
|
|
49b3ecdbca | ||
|
|
f53e3594c3 | ||
|
|
5562d1dfda | ||
|
|
c7b0c2961e | ||
|
|
44273b0791 | ||
|
|
0a4c8fcb3e | ||
|
|
2fec3c8169 | ||
|
|
5e7d5930dd | ||
|
|
b6dbd20250 | ||
|
|
34f1295349 | ||
|
|
1980d7b2c3 | ||
|
|
2cfacc5051 | ||
|
|
436f58ddc4 | ||
|
|
6b29bd17c8 | ||
|
|
2c3485ca3e | ||
|
|
f206ecc635 | ||
|
|
a187e05ae6 | ||
|
|
8c21960486 | ||
|
|
be62fce676 | ||
|
|
f23b778a6c | ||
|
|
436edf900d | ||
|
|
ed58c2553f | ||
|
|
f2ca58e844 | ||
|
|
1dbcc736eb | ||
|
|
a83808ddc5 | ||
|
|
a07fe80530 | ||
|
|
d0ba3ef8fa | ||
|
|
8400529c2c | ||
|
|
7eaee9c242 | ||
|
|
8230eebce5 | ||
|
|
6296ea4be9 | ||
|
|
4151ec3a8f | ||
|
|
a2467e8d43 | ||
|
|
e677178bcc | ||
|
|
7ef1bea953 | ||
|
|
ad89bb1413 | ||
|
|
218ed78c40 | ||
|
|
6046f36ab6 | ||
|
|
5915bf7de3 | ||
|
|
f0a4e59758 | ||
|
|
1ddef26af5 | ||
|
|
ba8eddb12f | ||
|
|
47b346d428 | ||
|
|
1b4f4f5f4d | ||
|
|
73cd7e8320 | ||
|
|
19c0ae3702 | ||
|
|
54e57f7771 | ||
|
|
6d64b8e273 | ||
|
|
a8ea0326f5 | ||
|
|
58e9194553 | ||
|
|
eb360e255d | ||
|
|
a6f88d7f72 | ||
|
|
8e571d165f | ||
|
|
3cddd01b10 | ||
|
|
64c2b2d96b | ||
|
|
f5ce121988 | ||
|
|
991f144598 | ||
|
|
09bea17e59 | ||
|
|
aefcf80b48 | ||
|
|
512235892e | ||
|
|
6602a2f5ba | ||
|
|
20114deea0 | ||
|
|
9acf519078 | ||
|
|
bdf37b5311 | ||
|
|
8ee2ac89f8 | ||
|
|
60cb48be2e | ||
|
|
86a215b063 | ||
|
|
d6e3a9a236 | ||
|
|
a0097a1ead | ||
|
|
a9bae00606 | ||
|
|
4731c1a835 | ||
|
|
4c07e47e8c | ||
|
|
e0cc2871bb | ||
|
|
649f39408b | ||
|
|
c142297d73 | ||
|
|
9e07360b00 | ||
|
|
7b74c86e42 | ||
|
|
fa833f8366 | ||
|
|
fcb059aa38 | ||
|
|
517c670f82 | ||
|
|
59df14f18b | ||
|
|
6c95ac0f37 | ||
|
|
7a4a51ae73 | ||
|
|
d816cc015e | ||
|
|
54ce3d48ca | ||
|
|
0e4a8ca240 | ||
|
|
6ca1298675 | ||
|
|
bbef7a6464 | ||
|
|
cdf2d61d53 | ||
|
|
6c14847d1f | ||
|
|
68ecdd2a73 | ||
|
|
3f4d444d18 | ||
|
|
e473d0375b | ||
|
|
e38d96850f | ||
|
|
fed63dfd4b | ||
|
|
eba4d06405 | ||
|
|
4cfba153d2 | ||
|
|
307c05f38d | ||
|
|
696df349cb | ||
|
|
cb54cb1348 | ||
|
|
9bdb86637d | ||
|
|
fb6f26517f | ||
|
|
aa8ada9da9 | ||
|
|
1db906a373 | ||
|
|
9d1d1617d8 | ||
|
|
7112789cb8 | ||
|
|
d6b8be2849 | ||
|
|
822171277c | ||
|
|
a5ae9d9f02 | ||
|
|
09e3f63d5b | ||
|
|
d60a5a9396 | ||
|
|
90df0ee365 | ||
|
|
133c1bcadd | ||
|
|
caadbe14e9 | ||
|
|
5f5823ccd9 | ||
|
|
d2f7e03b7e | ||
|
|
0b01bbe479 | ||
|
|
25c5fc44ae | ||
|
|
7330729c92 | ||
|
|
ce16cd5431 | ||
|
|
598dc5f79d | ||
|
|
1f8e332cbe | ||
|
|
17b9632659 | ||
|
|
bda92a54ab |
26
.github/workflows/nightly.yml
vendored
26
.github/workflows/nightly.yml
vendored
@@ -50,27 +50,13 @@ jobs:
|
||||
shell: powershell
|
||||
run: |
|
||||
./setup_venv.ps1
|
||||
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
|
||||
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
python process_skipfiles.py
|
||||
pyinstaller .\apps\stable_diffusion\shark_sd.spec
|
||||
mv ./dist/shark_sd.exe ./dist/shark_sd_${{ env.package_version_ }}.exe
|
||||
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_${{ env.package_version_ }}.exe
|
||||
pyinstaller .\apps\stable_diffusion\shark_sd_cli.spec
|
||||
python process_skipfiles.py
|
||||
mv ./dist/shark_sd_cli.exe ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
|
||||
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
|
||||
|
||||
|
||||
# GHA windows VM OOMs so disable for now
|
||||
#- name: Build and validate the SHARK Runtime package
|
||||
# shell: powershell
|
||||
# run: |
|
||||
# $env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
|
||||
# pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
|
||||
#- uses: actions/upload-artifact@v2
|
||||
# with:
|
||||
# path: dist/*
|
||||
|
||||
mv ./dist/shark_sd.exe ./dist/nodai_shark_sd_${{ env.package_version_ }}.exe
|
||||
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_sd_${{ env.package_version_ }}.exe
|
||||
|
||||
- name: Upload Release Assets
|
||||
id: upload-release-assets
|
||||
uses: dwenegar/upload-release-assets@v1
|
||||
@@ -78,7 +64,7 @@ jobs:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
release_id: ${{ steps.create_release.outputs.id }}
|
||||
assets_path: ./dist/*
|
||||
assets_path: ./dist/nodai*
|
||||
#asset_content_type: application/vnd.microsoft.portable-executable
|
||||
|
||||
- name: Publish Release
|
||||
|
||||
6
.github/workflows/test-models.yml
vendored
6
.github/workflows/test-models.yml
vendored
@@ -61,7 +61,6 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
if: matrix.os != '7950x'
|
||||
|
||||
- name: Set Environment Variables
|
||||
if: matrix.os != '7950x'
|
||||
@@ -84,9 +83,6 @@ jobs:
|
||||
#cache-dependency-path: |
|
||||
# **/requirements-importer.txt
|
||||
# **/requirements.txt
|
||||
|
||||
- uses: actions/checkout@v2
|
||||
if: matrix.os == '7950x'
|
||||
|
||||
- name: Install dependencies
|
||||
if: matrix.suite == 'lint'
|
||||
@@ -137,7 +133,7 @@ jobs:
|
||||
export DYLD_LIBRARY_PATH=/usr/local/lib/
|
||||
echo $PATH
|
||||
pip list | grep -E "torch|iree"
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k vulkan --update_tank
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k vulkan
|
||||
|
||||
- name: Validate Vulkan Models (a100)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
|
||||
|
||||
210
apps/language_models/scripts/stablelm.py
Normal file
210
apps/language_models/scripts/stablelm.py
Normal file
@@ -0,0 +1,210 @@
|
||||
import torch
|
||||
import torch_mlir
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
StoppingCriteria,
|
||||
)
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from apps.language_models.utils import (
|
||||
get_torch_mlir_module_bytecode,
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
|
||||
|
||||
class StopOnTokens(StoppingCriteria):
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
||||
) -> bool:
|
||||
stop_ids = [50278, 50279, 50277, 1, 0]
|
||||
for stop_id in stop_ids:
|
||||
if input_ids[0][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def shouldStop(tokens):
|
||||
stop_ids = [50278, 50279, 50277, 1, 0]
|
||||
for stop_id in stop_ids:
|
||||
if tokens[0][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
MAX_SEQUENCE_LENGTH = 256
|
||||
|
||||
|
||||
def user(message, history):
|
||||
# Append the user's message to the conversation history
|
||||
return "", history + [[message, ""]]
|
||||
|
||||
|
||||
def compile_stableLM(
|
||||
model,
|
||||
model_inputs,
|
||||
model_name,
|
||||
model_vmfb_name,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
):
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
# device = "cuda" # "cpu"
|
||||
# TODO: vmfb and mlir name should include precision and device
|
||||
vmfb_path = (
|
||||
Path(model_name + f"_{device}.vmfb")
|
||||
if model_vmfb_name is None
|
||||
else Path(model_vmfb_name)
|
||||
)
|
||||
shark_module = get_vmfb_from_path(
|
||||
vmfb_path, device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
if shark_module is not None:
|
||||
return shark_module
|
||||
|
||||
mlir_path = Path(model_name + ".mlir")
|
||||
print(
|
||||
f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
if mlir_path.exists():
|
||||
with open(mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*model_inputs],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
f_ = open(model_name + ".mlir", "wb")
|
||||
f_.write(bytecode)
|
||||
print("Saved mlir")
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
path = shark_module.save_module(
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem
|
||||
)
|
||||
print("Saved vmfb at ", str(path))
|
||||
|
||||
return shark_module
|
||||
|
||||
|
||||
class StableLMModel(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
combine_input_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
output = self.model(**combine_input_dict)
|
||||
return output.logits
|
||||
|
||||
|
||||
# Initialize a StopOnTokens object
|
||||
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
||||
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
||||
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
|
||||
- StableLM will refuse to participate in anything that could harm a human.
|
||||
"""
|
||||
|
||||
|
||||
def get_tokenizer():
|
||||
model_path = "stabilityai/stablelm-tuned-alpha-3b"
|
||||
tok = AutoTokenizer.from_pretrained(model_path)
|
||||
tok.add_special_tokens({"pad_token": "<PAD>"})
|
||||
print("Sucessfully loaded the tokenizer to the memory")
|
||||
return tok
|
||||
|
||||
|
||||
# sharkStableLM = compile_stableLM
|
||||
# (
|
||||
# None,
|
||||
# tuple([input_ids, attention_mask]),
|
||||
# "stableLM_linalg_f32_seqLen256",
|
||||
# "/home/shark/vivek/stableLM_shark_f32_seqLen256"
|
||||
# )
|
||||
def generate(
|
||||
new_text,
|
||||
max_new_tokens,
|
||||
sharkStableLM,
|
||||
tokenizer=None,
|
||||
):
|
||||
if tokenizer is None:
|
||||
tokenizer = get_tokenizer()
|
||||
# Construct the input message string for the model by
|
||||
# concatenating the current system message and conversation history
|
||||
# Tokenize the messages string
|
||||
# sharkStableLM = compile_stableLM
|
||||
# (
|
||||
# None,
|
||||
# tuple([input_ids, attention_mask]),
|
||||
# "stableLM_linalg_f32_seqLen256",
|
||||
# "/home/shark/vivek/stableLM_shark_f32_seqLen256"
|
||||
# )
|
||||
words_list = []
|
||||
for i in range(max_new_tokens):
|
||||
# numWords = len(new_text.split())
|
||||
# if(numWords>220):
|
||||
# break
|
||||
params = {
|
||||
"new_text": new_text,
|
||||
}
|
||||
generated_token_op = generate_new_token(
|
||||
sharkStableLM, tokenizer, params
|
||||
)
|
||||
detok = generated_token_op["detok"]
|
||||
stop_generation = generated_token_op["stop_generation"]
|
||||
if stop_generation:
|
||||
break
|
||||
print(detok, end="", flush=True)
|
||||
words_list.append(detok)
|
||||
if detok == "":
|
||||
break
|
||||
new_text = new_text + detok
|
||||
return words_list
|
||||
|
||||
|
||||
def generate_new_token(shark_model, tokenizer, params):
|
||||
new_text = params["new_text"]
|
||||
model_inputs = tokenizer(
|
||||
[new_text],
|
||||
padding="max_length",
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
sum_attentionmask = torch.sum(model_inputs.attention_mask)
|
||||
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
|
||||
output = shark_model(
|
||||
"forward", [model_inputs.input_ids, model_inputs.attention_mask]
|
||||
)
|
||||
output = torch.from_numpy(output)
|
||||
next_toks = torch.topk(output, 1)
|
||||
stop_generation = False
|
||||
if shouldStop(next_toks.indices):
|
||||
stop_generation = True
|
||||
new_token = next_toks.indices[0][int(sum_attentionmask) - 1]
|
||||
detok = tokenizer.decode(
|
||||
new_token,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
ret_dict = {
|
||||
"new_token": new_token,
|
||||
"detok": detok,
|
||||
"stop_generation": stop_generation,
|
||||
}
|
||||
return ret_dict
|
||||
122
apps/language_models/scripts/vicuna.py
Normal file
122
apps/language_models/scripts/vicuna.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from apps.language_models.src.pipelines import vicuna_pipeline as vp
|
||||
from apps.language_models.src.pipelines import vicuna_sharded_pipeline as vsp
|
||||
import torch
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="vicuna runner",
|
||||
description="runs a vicuna model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--precision", "-p", default="fp32", help="fp32, fp16, int8, int4"
|
||||
)
|
||||
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
|
||||
parser.add_argument(
|
||||
"--first_vicuna_vmfb_path", default=None, help="path to first vicuna vmfb"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--sharded",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Run model as sharded",
|
||||
)
|
||||
# TODO: sharded config
|
||||
|
||||
parser.add_argument(
|
||||
"--second_vicuna_vmfb_path",
|
||||
default=None,
|
||||
help="path to second vicuna vmfb",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--first_vicuna_mlir_path",
|
||||
default=None,
|
||||
help="path to first vicuna mlir file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--second_vicuna_mlir_path",
|
||||
default=None,
|
||||
help="path to second vicuna mlir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--load_mlir_from_shark_tank",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="download precompile mlir from shark tank",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cli",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Run model in cli mode",
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
vic = None
|
||||
if not args.sharded:
|
||||
first_vic_mlir_path = (
|
||||
Path(f"first_vicuna_{args.precision}.mlir")
|
||||
if args.first_vicuna_mlir_path is None
|
||||
else Path(args.first_vicuna_mlir_path)
|
||||
)
|
||||
second_vic_mlir_path = (
|
||||
Path(f"second_vicuna_{args.precision}.mlir")
|
||||
if args.second_vicuna_mlir_path is None
|
||||
else Path(args.second_vicuna_mlir_path)
|
||||
)
|
||||
first_vic_vmfb_path = (
|
||||
Path(
|
||||
f"first_vicuna_{args.precision}_{args.device.replace('://', '_')}.vmfb"
|
||||
)
|
||||
if args.first_vicuna_vmfb_path is None
|
||||
else Path(args.first_vicuna_vmfb_path)
|
||||
)
|
||||
second_vic_vmfb_path = (
|
||||
Path(
|
||||
f"second_vicuna_{args.precision}_{args.device.replace('://', '_')}.vmfb"
|
||||
)
|
||||
if args.second_vicuna_vmfb_path is None
|
||||
else Path(args.second_vicuna_vmfb_path)
|
||||
)
|
||||
vic = vp.Vicuna(
|
||||
"vicuna",
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
first_vicuna_mlir_path=first_vic_mlir_path,
|
||||
second_vicuna_mlir_path=second_vic_mlir_path,
|
||||
first_vicuna_vmfb_path=first_vic_vmfb_path,
|
||||
second_vicuna_vmfb_path=second_vic_vmfb_path,
|
||||
load_mlir_from_shark_tank=args.load_mlir_from_shark_tank,
|
||||
)
|
||||
else:
|
||||
vic = vsp.Vicuna(
|
||||
"vicuna",
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
)
|
||||
prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
prologue_prompt = "ASSISTANT:\n"
|
||||
|
||||
import gc
|
||||
|
||||
while True:
|
||||
# TODO: Add break condition from user input
|
||||
user_prompt = input("User: ")
|
||||
prompt_history = (
|
||||
prompt_history + "USER:\n" + user_prompt + prologue_prompt
|
||||
)
|
||||
prompt = prompt_history.strip()
|
||||
res_str = vic.generate(prompt, cli=True)
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
print(
|
||||
"\n-----\nAssistant: Here's the complete formatted reply:\n",
|
||||
res_str,
|
||||
)
|
||||
prompt_history += f"\n{res_str}\n"
|
||||
22
apps/language_models/src/model_wrappers/falcon_model.py
Normal file
22
apps/language_models/src/model_wrappers/falcon_model.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import torch
|
||||
|
||||
|
||||
class FalconModel(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
input_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": None,
|
||||
"use_cache": True,
|
||||
}
|
||||
output = self.model(
|
||||
**input_dict,
|
||||
return_dict=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
)[0]
|
||||
return output[:, -1, :]
|
||||
15
apps/language_models/src/model_wrappers/stablelm_model.py
Normal file
15
apps/language_models/src/model_wrappers/stablelm_model.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import torch
|
||||
|
||||
|
||||
class StableLMModel(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
combine_input_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
output = self.model(**combine_input_dict)
|
||||
return output.logits
|
||||
261
apps/language_models/src/model_wrappers/vicuna_model.py
Normal file
261
apps/language_models/src/model_wrappers/vicuna_model.py
Normal file
@@ -0,0 +1,261 @@
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
|
||||
class FirstVicuna(torch.nn.Module):
|
||||
def __init__(self, model_path):
|
||||
super().__init__()
|
||||
kwargs = {"torch_dtype": torch.float32}
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
|
||||
def forward(self, input_ids):
|
||||
op = self.model(input_ids=input_ids, use_cache=True)
|
||||
return_vals = []
|
||||
return_vals.append(op.logits)
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
return_vals.append(item[1])
|
||||
return tuple(return_vals)
|
||||
|
||||
|
||||
class SecondVicuna(torch.nn.Module):
|
||||
def __init__(self, model_path):
|
||||
super().__init__()
|
||||
kwargs = {"torch_dtype": torch.float32}
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
i0,
|
||||
i1,
|
||||
i2,
|
||||
i3,
|
||||
i4,
|
||||
i5,
|
||||
i6,
|
||||
i7,
|
||||
i8,
|
||||
i9,
|
||||
i10,
|
||||
i11,
|
||||
i12,
|
||||
i13,
|
||||
i14,
|
||||
i15,
|
||||
i16,
|
||||
i17,
|
||||
i18,
|
||||
i19,
|
||||
i20,
|
||||
i21,
|
||||
i22,
|
||||
i23,
|
||||
i24,
|
||||
i25,
|
||||
i26,
|
||||
i27,
|
||||
i28,
|
||||
i29,
|
||||
i30,
|
||||
i31,
|
||||
i32,
|
||||
i33,
|
||||
i34,
|
||||
i35,
|
||||
i36,
|
||||
i37,
|
||||
i38,
|
||||
i39,
|
||||
i40,
|
||||
i41,
|
||||
i42,
|
||||
i43,
|
||||
i44,
|
||||
i45,
|
||||
i46,
|
||||
i47,
|
||||
i48,
|
||||
i49,
|
||||
i50,
|
||||
i51,
|
||||
i52,
|
||||
i53,
|
||||
i54,
|
||||
i55,
|
||||
i56,
|
||||
i57,
|
||||
i58,
|
||||
i59,
|
||||
i60,
|
||||
i61,
|
||||
i62,
|
||||
i63,
|
||||
i64,
|
||||
):
|
||||
# input_ids = input_tuple[0]
|
||||
# input_tuple = torch.unbind(pkv, dim=0)
|
||||
token = i0
|
||||
past_key_values = (
|
||||
(i1, i2),
|
||||
(
|
||||
i3,
|
||||
i4,
|
||||
),
|
||||
(
|
||||
i5,
|
||||
i6,
|
||||
),
|
||||
(
|
||||
i7,
|
||||
i8,
|
||||
),
|
||||
(
|
||||
i9,
|
||||
i10,
|
||||
),
|
||||
(
|
||||
i11,
|
||||
i12,
|
||||
),
|
||||
(
|
||||
i13,
|
||||
i14,
|
||||
),
|
||||
(
|
||||
i15,
|
||||
i16,
|
||||
),
|
||||
(
|
||||
i17,
|
||||
i18,
|
||||
),
|
||||
(
|
||||
i19,
|
||||
i20,
|
||||
),
|
||||
(
|
||||
i21,
|
||||
i22,
|
||||
),
|
||||
(
|
||||
i23,
|
||||
i24,
|
||||
),
|
||||
(
|
||||
i25,
|
||||
i26,
|
||||
),
|
||||
(
|
||||
i27,
|
||||
i28,
|
||||
),
|
||||
(
|
||||
i29,
|
||||
i30,
|
||||
),
|
||||
(
|
||||
i31,
|
||||
i32,
|
||||
),
|
||||
(
|
||||
i33,
|
||||
i34,
|
||||
),
|
||||
(
|
||||
i35,
|
||||
i36,
|
||||
),
|
||||
(
|
||||
i37,
|
||||
i38,
|
||||
),
|
||||
(
|
||||
i39,
|
||||
i40,
|
||||
),
|
||||
(
|
||||
i41,
|
||||
i42,
|
||||
),
|
||||
(
|
||||
i43,
|
||||
i44,
|
||||
),
|
||||
(
|
||||
i45,
|
||||
i46,
|
||||
),
|
||||
(
|
||||
i47,
|
||||
i48,
|
||||
),
|
||||
(
|
||||
i49,
|
||||
i50,
|
||||
),
|
||||
(
|
||||
i51,
|
||||
i52,
|
||||
),
|
||||
(
|
||||
i53,
|
||||
i54,
|
||||
),
|
||||
(
|
||||
i55,
|
||||
i56,
|
||||
),
|
||||
(
|
||||
i57,
|
||||
i58,
|
||||
),
|
||||
(
|
||||
i59,
|
||||
i60,
|
||||
),
|
||||
(
|
||||
i61,
|
||||
i62,
|
||||
),
|
||||
(
|
||||
i63,
|
||||
i64,
|
||||
),
|
||||
)
|
||||
op = self.model(
|
||||
input_ids=token, use_cache=True, past_key_values=past_key_values
|
||||
)
|
||||
return_vals = []
|
||||
return_vals.append(op.logits)
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
return_vals.append(item[1])
|
||||
return tuple(return_vals)
|
||||
|
||||
|
||||
class CombinedModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
first_vicuna_model_path="TheBloke/vicuna-7B-1.1-HF",
|
||||
second_vicuna_model_path="TheBloke/vicuna-7B-1.1-HF",
|
||||
):
|
||||
super().__init__()
|
||||
self.first_vicuna = FirstVicuna(first_vicuna_model_path)
|
||||
self.second_vicuna = SecondVicuna(second_vicuna_model_path)
|
||||
|
||||
def forward(self, input_ids):
|
||||
first_output = self.first_vicuna(input_ids=input_ids, use_cache=True)
|
||||
logits = first_output[0]
|
||||
pkv = first_output[1:]
|
||||
|
||||
token = torch.argmax(torch.tensor(logits)[:, -1, :], dim=1)
|
||||
token = token.to(torch.int64).reshape([1, 1])
|
||||
secondVicunaInput = (token,) + tuple(pkv)
|
||||
second_output = self.second_vicuna(secondVicunaInput)
|
||||
return second_output
|
||||
178
apps/language_models/src/model_wrappers/vicuna_sharded_model.py
Normal file
178
apps/language_models/src/model_wrappers/vicuna_sharded_model.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import torch
|
||||
|
||||
|
||||
class FirstVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, hidden_states, attention_mask, position_ids):
|
||||
outputs = self.model(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
use_cache=True,
|
||||
)
|
||||
next_hidden_states = outputs[0]
|
||||
past_key_value_out0, past_key_value_out1 = (
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
|
||||
return (
|
||||
next_hidden_states,
|
||||
past_key_value_out0,
|
||||
past_key_value_out1,
|
||||
)
|
||||
|
||||
|
||||
class SecondVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
):
|
||||
outputs = self.model(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=(
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
),
|
||||
use_cache=True,
|
||||
)
|
||||
next_hidden_states = outputs[0]
|
||||
past_key_value_out0, past_key_value_out1 = (
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
|
||||
return (
|
||||
next_hidden_states,
|
||||
past_key_value_out0,
|
||||
past_key_value_out1,
|
||||
)
|
||||
|
||||
|
||||
class CompiledFirstVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
),
|
||||
)
|
||||
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
|
||||
return (
|
||||
output0,
|
||||
(
|
||||
output1,
|
||||
output2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CompiledSecondVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
pkv0 = past_key_value[0].detach()
|
||||
pkv1 = past_key_value[1].detach()
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
pkv0,
|
||||
pkv1,
|
||||
),
|
||||
)
|
||||
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
|
||||
return (
|
||||
output0,
|
||||
(
|
||||
output1,
|
||||
output2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ShardedVicunaModel(torch.nn.Module):
|
||||
def __init__(self, model, layers0, layers1):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
assert len(layers0) == len(model.model.layers)
|
||||
# self.model.model.layers = torch.nn.modules.container.ModuleList(layers0)
|
||||
self.model.model.config.use_cache = True
|
||||
self.model.model.config.output_attentions = False
|
||||
self.layers0 = layers0
|
||||
self.layers1 = layers1
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
is_first=True,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
if is_first:
|
||||
self.model.model.layers = torch.nn.modules.container.ModuleList(
|
||||
self.layers0
|
||||
)
|
||||
return self.model.forward(input_ids, attention_mask=attention_mask)
|
||||
else:
|
||||
self.model.model.layers = torch.nn.modules.container.ModuleList(
|
||||
self.layers1
|
||||
)
|
||||
return self.model.forward(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
41
apps/language_models/src/pipelines/SharkLLMBase.py
Normal file
41
apps/language_models/src/pipelines/SharkLLMBase.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class SharkLLMBase(ABC):
|
||||
def __init__(
|
||||
self, model_name, hf_model_path=None, max_num_tokens=512
|
||||
) -> None:
|
||||
self.model_name = model_name
|
||||
self.hf_model_path = hf_model_path
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.shark_model = None
|
||||
self.device = "cpu"
|
||||
self.precision = "fp32"
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def compile(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def generate(self, prompt):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def generate_new_token(self, params):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_tokenizer(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_src_model(self):
|
||||
pass
|
||||
|
||||
def load_init_from_config(self):
|
||||
pass
|
||||
512
apps/language_models/src/pipelines/falcon_pipeline.py
Normal file
512
apps/language_models/src/pipelines/falcon_pipeline.py
Normal file
@@ -0,0 +1,512 @@
|
||||
from apps.language_models.src.model_wrappers.falcon_model import FalconModel
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from apps.language_models.utils import (
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from contextlib import redirect_stdout
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_inference import SharkInference
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from transformers.generation import (
|
||||
GenerationConfig,
|
||||
LogitsProcessorList,
|
||||
StoppingCriteriaList,
|
||||
)
|
||||
import copy
|
||||
|
||||
import re
|
||||
import torch
|
||||
import torch_mlir
|
||||
import os
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="falcon runner",
|
||||
description="runs a falcon model",
|
||||
)
|
||||
|
||||
parser.add_argument("--falcon_variant_to_use", default="7b", help="7b, 40b")
|
||||
parser.add_argument(
|
||||
"--precision", "-p", default="fp16", help="fp32, fp16, int8, int4"
|
||||
)
|
||||
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
|
||||
parser.add_argument(
|
||||
"--falcon_vmfb_path", default=None, help="path to falcon's vmfb"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--falcon_mlir_path",
|
||||
default=None,
|
||||
help="path to falcon's mlir file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_precompiled_model",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="use the precompiled vmfb",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_mlir_from_shark_tank",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="download precompile mlir from shark tank",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cli",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Run model in cli mode",
|
||||
)
|
||||
|
||||
|
||||
class Falcon(SharkLLMBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path,
|
||||
max_num_tokens=150,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
falcon_mlir_path=None,
|
||||
falcon_vmfb_path=None,
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_padding_length = 100
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.falcon_vmfb_path = falcon_vmfb_path
|
||||
self.falcon_mlir_path = falcon_mlir_path
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.shark_model = self.compile()
|
||||
self.src_model = self.get_src_model()
|
||||
|
||||
def get_tokenizer(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path, trust_remote_code=True
|
||||
)
|
||||
tokenizer.padding_side = "left"
|
||||
tokenizer.pad_token_id = 11
|
||||
return tokenizer
|
||||
|
||||
def get_src_model(self):
|
||||
print("Loading src model: ", self.model_name)
|
||||
kwargs = {"torch_dtype": torch.float, "trust_remote_code": True}
|
||||
falcon_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
)
|
||||
return falcon_model
|
||||
|
||||
def compile_falcon(self):
|
||||
if args.use_precompiled_model:
|
||||
if not self.falcon_vmfb_path.exists():
|
||||
# Downloading VMFB from shark_tank
|
||||
download_public_file(
|
||||
"gs://shark_tank/falcon/"
|
||||
+ "falcon_"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "_"
|
||||
+ self.precision
|
||||
+ "_"
|
||||
+ self.device
|
||||
+ ".vmfb",
|
||||
self.falcon_vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
vmfb = get_vmfb_from_path(
|
||||
self.falcon_vmfb_path, self.device, "linalg"
|
||||
)
|
||||
if vmfb is not None:
|
||||
return vmfb
|
||||
|
||||
print(
|
||||
f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}. Trying to work with"
|
||||
f"[DEBUG] mlir path { self.falcon_mlir_path} {'exists' if self.falcon_mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
if self.falcon_mlir_path.exists():
|
||||
with open(self.falcon_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
mlir_generated = False
|
||||
# Downloading MLIR from shark_tank
|
||||
download_public_file(
|
||||
"gs://shark_tank/falcon/"
|
||||
+ "falcon_"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "_"
|
||||
+ self.precision
|
||||
+ ".mlir",
|
||||
self.falcon_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if self.falcon_mlir_path.exists():
|
||||
with open(self.falcon_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
mlir_generated = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"MLIR not found at {self.falcon_mlir_path.absolute()}"
|
||||
" after downloading! Please check path and try again"
|
||||
)
|
||||
|
||||
if not mlir_generated:
|
||||
compilation_input_ids = torch.randint(
|
||||
low=1, high=10000, size=(1, 100)
|
||||
)
|
||||
compilation_attention_mask = torch.ones(
|
||||
1, 100, dtype=torch.int64
|
||||
)
|
||||
falconCompileInput = (
|
||||
compilation_input_ids,
|
||||
compilation_attention_mask,
|
||||
)
|
||||
model = FalconModel(self.src_model)
|
||||
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
falconCompileInput,
|
||||
is_f16=self.precision == "fp16",
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
del model
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*falconCompileInput],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
del ts_graph
|
||||
|
||||
print(f"[DEBUG] converting to bytecode")
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
del module
|
||||
|
||||
print(f"[DEBUG] writing mlir to file")
|
||||
with open(f"{self.model_name}.mlir", "wb") as f_:
|
||||
with redirect_stdout(f_):
|
||||
print(module.operation.get_asm())
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=self.device, mlir_dialect="linalg"
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
self.falcon_vmfb_path.parent.absolute(),
|
||||
self.falcon_vmfb_path.stem,
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
"--iree-spirv-index-bits=64",
|
||||
],
|
||||
)
|
||||
print("Saved falcon vmfb at ", str(path))
|
||||
shark_module.load_module(path)
|
||||
|
||||
return shark_module
|
||||
|
||||
def compile(self):
|
||||
falcon_shark_model = self.compile_falcon()
|
||||
return falcon_shark_model
|
||||
|
||||
def generate(self, prompt):
|
||||
model_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.max_padding_length,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
model_inputs["prompt_text"] = prompt
|
||||
|
||||
input_ids = model_inputs["input_ids"]
|
||||
attention_mask = model_inputs.get("attention_mask", None)
|
||||
|
||||
# Allow empty prompts
|
||||
if input_ids.shape[1] == 0:
|
||||
input_ids = None
|
||||
attention_mask = None
|
||||
in_b = 1
|
||||
else:
|
||||
in_b = input_ids.shape[0]
|
||||
|
||||
generate_kwargs = {
|
||||
"max_length": self.max_num_tokens,
|
||||
"do_sample": True,
|
||||
"top_k": 10,
|
||||
"num_return_sequences": 1,
|
||||
"eos_token_id": 11,
|
||||
}
|
||||
generate_kwargs["input_ids"] = input_ids
|
||||
generate_kwargs["attention_mask"] = attention_mask
|
||||
generation_config_ = GenerationConfig.from_model_config(
|
||||
self.src_model.config
|
||||
)
|
||||
generation_config = copy.deepcopy(generation_config_)
|
||||
model_kwargs = generation_config.update(**generate_kwargs)
|
||||
|
||||
logits_processor = LogitsProcessorList()
|
||||
stopping_criteria = StoppingCriteriaList()
|
||||
|
||||
eos_token_id = generation_config.eos_token_id
|
||||
generation_config.pad_token_id = eos_token_id
|
||||
|
||||
(
|
||||
inputs_tensor,
|
||||
model_input_name,
|
||||
model_kwargs,
|
||||
) = self.src_model._prepare_model_inputs(
|
||||
None, generation_config.bos_token_id, model_kwargs
|
||||
)
|
||||
batch_size = inputs_tensor.shape[0]
|
||||
|
||||
model_kwargs["output_attentions"] = generation_config.output_attentions
|
||||
model_kwargs[
|
||||
"output_hidden_states"
|
||||
] = generation_config.output_hidden_states
|
||||
model_kwargs["use_cache"] = generation_config.use_cache
|
||||
|
||||
input_ids = (
|
||||
inputs_tensor
|
||||
if model_input_name == "input_ids"
|
||||
else model_kwargs.pop("input_ids")
|
||||
)
|
||||
|
||||
self.logits_processor = self.src_model._get_logits_processor(
|
||||
generation_config=generation_config,
|
||||
input_ids_seq_length=input_ids.shape[-1],
|
||||
encoder_input_ids=inputs_tensor,
|
||||
prefix_allowed_tokens_fn=None,
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
|
||||
self.stopping_criteria = self.src_model._get_stopping_criteria(
|
||||
generation_config=generation_config,
|
||||
stopping_criteria=stopping_criteria,
|
||||
)
|
||||
|
||||
self.logits_warper = self.src_model._get_logits_warper(
|
||||
generation_config
|
||||
)
|
||||
|
||||
(
|
||||
self.input_ids,
|
||||
self.model_kwargs,
|
||||
) = self.src_model._expand_inputs_for_generation(
|
||||
input_ids=input_ids,
|
||||
expand_size=generation_config.num_return_sequences, # 1
|
||||
is_encoder_decoder=self.src_model.config.is_encoder_decoder, # False
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
self.eos_token_id_tensor = (
|
||||
torch.tensor(eos_token_id) if eos_token_id is not None else None
|
||||
)
|
||||
|
||||
self.pad_token_id = generation_config.pad_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
output_scores = generation_config.output_scores # False
|
||||
output_attentions = generation_config.output_attentions # False
|
||||
output_hidden_states = generation_config.output_hidden_states # False
|
||||
return_dict_in_generate = (
|
||||
generation_config.return_dict_in_generate # False
|
||||
)
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
self.scores = (
|
||||
() if (return_dict_in_generate and output_scores) else None
|
||||
)
|
||||
decoder_attentions = (
|
||||
() if (return_dict_in_generate and output_attentions) else None
|
||||
)
|
||||
cross_attentions = (
|
||||
() if (return_dict_in_generate and output_attentions) else None
|
||||
)
|
||||
decoder_hidden_states = (
|
||||
() if (return_dict_in_generate and output_hidden_states) else None
|
||||
)
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
self.unfinished_sequences = torch.ones(
|
||||
input_ids.shape[0], dtype=torch.long, device=input_ids.device
|
||||
)
|
||||
|
||||
all_text = prompt
|
||||
|
||||
for i in range(self.max_num_tokens - 1):
|
||||
next_token = self.generate_new_token()
|
||||
new_word = self.tokenizer.decode(
|
||||
next_token.cpu().numpy(),
|
||||
add_special_tokens=False,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
|
||||
all_text = all_text + new_word
|
||||
|
||||
print(f"{new_word}", end="", flush=True)
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if self.eos_token_id_tensor is not None:
|
||||
self.unfinished_sequences = self.unfinished_sequences.mul(
|
||||
next_token.tile(self.eos_token_id_tensor.shape[0], 1)
|
||||
.ne(self.eos_token_id_tensor.unsqueeze(1))
|
||||
.prod(dim=0)
|
||||
)
|
||||
# stop when each sentence is finished
|
||||
if (
|
||||
self.unfinished_sequences.max() == 0
|
||||
or self.stopping_criteria(input_ids, self.scores)
|
||||
):
|
||||
break
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
return all_text
|
||||
|
||||
def generate_new_token(self):
|
||||
model_inputs = self.src_model.prepare_inputs_for_generation(
|
||||
self.input_ids, **self.model_kwargs
|
||||
)
|
||||
outputs = torch.from_numpy(
|
||||
self.shark_model(
|
||||
"forward",
|
||||
(model_inputs["input_ids"], model_inputs["attention_mask"]),
|
||||
)
|
||||
)
|
||||
if self.precision == "fp16":
|
||||
outputs = outputs.to(dtype=torch.float32)
|
||||
next_token_logits = outputs
|
||||
|
||||
# pre-process distribution
|
||||
next_token_scores = self.logits_processor(
|
||||
self.input_ids, next_token_logits
|
||||
)
|
||||
next_token_scores = self.logits_warper(
|
||||
self.input_ids, next_token_scores
|
||||
)
|
||||
|
||||
# sample
|
||||
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
|
||||
|
||||
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
|
||||
# finished sentences should have their next token be a padding token
|
||||
if self.eos_token_id is not None:
|
||||
if self.pad_token_id is None:
|
||||
raise ValueError(
|
||||
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
|
||||
)
|
||||
next_token = (
|
||||
next_token * self.unfinished_sequences
|
||||
+ self.pad_token_id * (1 - self.unfinished_sequences)
|
||||
)
|
||||
|
||||
self.input_ids = torch.cat(
|
||||
[self.input_ids, next_token[:, None]], dim=-1
|
||||
)
|
||||
|
||||
self.model_kwargs["past_key_values"] = None
|
||||
if "attention_mask" in self.model_kwargs:
|
||||
attention_mask = self.model_kwargs["attention_mask"]
|
||||
self.model_kwargs["attention_mask"] = torch.cat(
|
||||
[
|
||||
attention_mask,
|
||||
attention_mask.new_ones((attention_mask.shape[0], 1)),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
self.input_ids = self.input_ids[:, 1:]
|
||||
self.model_kwargs["attention_mask"] = self.model_kwargs[
|
||||
"attention_mask"
|
||||
][:, 1:]
|
||||
|
||||
return next_token
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
falcon_mlir_path = (
|
||||
Path(
|
||||
"falcon_"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "_"
|
||||
+ args.precision
|
||||
+ ".mlir"
|
||||
)
|
||||
if args.falcon_mlir_path is None
|
||||
else Path(args.falcon_mlir_path)
|
||||
)
|
||||
falcon_vmfb_path = (
|
||||
Path(
|
||||
"falcon_"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "_"
|
||||
+ args.precision
|
||||
+ "_"
|
||||
+ args.device
|
||||
+ ".vmfb"
|
||||
)
|
||||
if args.falcon_vmfb_path is None
|
||||
else Path(args.falcon_vmfb_path)
|
||||
)
|
||||
|
||||
falcon = Falcon(
|
||||
"falcon_" + args.falcon_variant_to_use,
|
||||
hf_model_path="tiiuae/falcon-"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "-instruct",
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
falcon_mlir_path=falcon_mlir_path,
|
||||
falcon_vmfb_path=falcon_vmfb_path,
|
||||
)
|
||||
|
||||
import gc
|
||||
|
||||
default_prompt_text = "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:"
|
||||
continue_execution = True
|
||||
|
||||
print("\n-----\nScript executing for the following config: \n")
|
||||
print("Falcon Model: ", falcon.model_name)
|
||||
print("Precision: ", args.precision)
|
||||
print("Device: ", args.device)
|
||||
|
||||
while continue_execution:
|
||||
use_default_prompt = input(
|
||||
"\nDo you wish to use the default prompt text? Y/N ?: "
|
||||
)
|
||||
if use_default_prompt in ["Y", "y"]:
|
||||
prompt = default_prompt_text
|
||||
else:
|
||||
prompt = input("Please enter the prompt text: ")
|
||||
print("\nPrompt Text: ", prompt)
|
||||
|
||||
res_str = falcon.generate(prompt)
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
print(
|
||||
"\n\n-----\nHere's the complete formatted result: \n\n",
|
||||
res_str,
|
||||
)
|
||||
continue_execution = input(
|
||||
"\nDo you wish to run script one more time? Y/N ?: "
|
||||
)
|
||||
continue_execution = (
|
||||
True if continue_execution in ["Y", "y"] else False
|
||||
)
|
||||
185
apps/language_models/src/pipelines/stablelm_pipeline.py
Normal file
185
apps/language_models/src/pipelines/stablelm_pipeline.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import torch
|
||||
import torch_mlir
|
||||
from transformers import AutoTokenizer, StoppingCriteria, AutoModelForCausalLM
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from apps.language_models.utils import (
|
||||
get_torch_mlir_module_bytecode,
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from apps.language_models.src.model_wrappers.stablelm_model import (
|
||||
StableLMModel,
|
||||
)
|
||||
|
||||
|
||||
class StopOnTokens(StoppingCriteria):
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
||||
) -> bool:
|
||||
stop_ids = [50278, 50279, 50277, 1, 0]
|
||||
for stop_id in stop_ids:
|
||||
if input_ids[0][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class SharkStableLM(SharkLLMBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path="stabilityai/stablelm-tuned-alpha-3b",
|
||||
max_num_tokens=512,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_sequence_len = 256
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.shark_model = self.compile()
|
||||
|
||||
def shouldStop(self, tokens):
|
||||
stop_ids = [50278, 50279, 50277, 1, 0]
|
||||
for stop_id in stop_ids:
|
||||
if tokens[0][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_src_model(self):
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, torch_dtype=torch.float32
|
||||
)
|
||||
return model
|
||||
|
||||
def get_model_inputs(self):
|
||||
input_ids = torch.randint(3, (1, self.max_sequence_len))
|
||||
attention_mask = torch.randint(3, (1, self.max_sequence_len))
|
||||
return input_ids, attention_mask
|
||||
|
||||
def compile(self):
|
||||
tmp_model_name = (
|
||||
f"stableLM_linalg_{self.precision}_seqLen{self.max_sequence_len}"
|
||||
)
|
||||
|
||||
# device = "cuda" # "cpu"
|
||||
# TODO: vmfb and mlir name should include precision and device
|
||||
model_vmfb_name = None
|
||||
vmfb_path = (
|
||||
Path(tmp_model_name + f"_{self.device}.vmfb")
|
||||
if model_vmfb_name is None
|
||||
else Path(model_vmfb_name)
|
||||
)
|
||||
shark_module = get_vmfb_from_path(
|
||||
vmfb_path, self.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
if shark_module is not None:
|
||||
return shark_module
|
||||
|
||||
mlir_path = Path(tmp_model_name + ".mlir")
|
||||
print(
|
||||
f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
if mlir_path.exists():
|
||||
with open(mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
model = StableLMModel(self.get_src_model())
|
||||
model_inputs = self.get_model_inputs()
|
||||
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*model_inputs],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
f_ = open(tmp_model_name + ".mlir", "wb")
|
||||
f_.write(bytecode)
|
||||
print("Saved mlir")
|
||||
f_.close()
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
path = shark_module.save_module(
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem
|
||||
)
|
||||
print("Saved vmfb at ", str(path))
|
||||
|
||||
return shark_module
|
||||
|
||||
def get_tokenizer(self):
|
||||
tok = AutoTokenizer.from_pretrained(self.hf_model_path)
|
||||
tok.add_special_tokens({"pad_token": "<PAD>"})
|
||||
# print("[DEBUG] Sucessfully loaded the tokenizer to the memory")
|
||||
return tok
|
||||
|
||||
def generate(self, prompt):
|
||||
words_list = []
|
||||
for i in range(self.max_num_tokens):
|
||||
params = {
|
||||
"new_text": prompt,
|
||||
}
|
||||
|
||||
generated_token_op = self.generate_new_token(params)
|
||||
|
||||
detok = generated_token_op["detok"]
|
||||
stop_generation = generated_token_op["stop_generation"]
|
||||
|
||||
if stop_generation:
|
||||
break
|
||||
|
||||
print(detok, end="", flush=True) # this is for CLI and DEBUG
|
||||
words_list.append(detok)
|
||||
if detok == "":
|
||||
break
|
||||
prompt = prompt + detok
|
||||
return words_list
|
||||
|
||||
def generate_new_token(self, params):
|
||||
new_text = params["new_text"]
|
||||
model_inputs = self.tokenizer(
|
||||
[new_text],
|
||||
padding="max_length",
|
||||
max_length=self.max_sequence_len,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
sum_attentionmask = torch.sum(model_inputs.attention_mask)
|
||||
output = self.shark_model(
|
||||
"forward", [model_inputs.input_ids, model_inputs.attention_mask]
|
||||
)
|
||||
output = torch.from_numpy(output)
|
||||
next_toks = torch.topk(output, 1)
|
||||
stop_generation = False
|
||||
if self.shouldStop(next_toks.indices):
|
||||
stop_generation = True
|
||||
new_token = next_toks.indices[0][int(sum_attentionmask) - 1]
|
||||
detok = self.tokenizer.decode(
|
||||
new_token,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
ret_dict = {
|
||||
"new_token": new_token,
|
||||
"detok": detok,
|
||||
"stop_generation": stop_generation,
|
||||
}
|
||||
return ret_dict
|
||||
|
||||
|
||||
# Initialize a StopOnTokens object
|
||||
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
||||
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
||||
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
|
||||
- StableLM will refuse to participate in anything that could harm a human.
|
||||
"""
|
||||
569
apps/language_models/src/pipelines/vicuna_pipeline.py
Normal file
569
apps/language_models/src/pipelines/vicuna_pipeline.py
Normal file
@@ -0,0 +1,569 @@
|
||||
from apps.language_models.src.model_wrappers.vicuna_model import (
|
||||
FirstVicuna,
|
||||
SecondVicuna,
|
||||
)
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from apps.language_models.utils import (
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import import_with_fx, get_f16_inputs
|
||||
from shark.shark_inference import SharkInference
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
import re
|
||||
import torch
|
||||
import torch_mlir
|
||||
import os
|
||||
|
||||
|
||||
class Vicuna(SharkLLMBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
|
||||
max_num_tokens=512,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
first_vicuna_mlir_path=Path("first_vicuna.mlir"),
|
||||
second_vicuna_mlir_path=Path("second_vicuna.mlir"),
|
||||
first_vicuna_vmfb_path=Path("first_vicuna.vmfb"),
|
||||
second_vicuna_vmfb_path=Path("second_vicuna.vmfb"),
|
||||
load_mlir_from_shark_tank=True,
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_sequence_length = 256
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.first_vicuna_vmfb_path = first_vicuna_vmfb_path
|
||||
self.second_vicuna_vmfb_path = second_vicuna_vmfb_path
|
||||
self.first_vicuna_mlir_path = first_vicuna_mlir_path
|
||||
self.second_vicuna_mlir_path = second_vicuna_mlir_path
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.shark_model = self.compile()
|
||||
self.load_mlir_from_shark_tank = load_mlir_from_shark_tank
|
||||
|
||||
def get_tokenizer(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path, use_fast=False
|
||||
)
|
||||
return tokenizer
|
||||
|
||||
def get_src_model(self):
|
||||
kwargs = {"torch_dtype": torch.float}
|
||||
vicuna_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
)
|
||||
return vicuna_model
|
||||
|
||||
def compile_first_vicuna(self):
|
||||
vmfb = get_vmfb_from_path(
|
||||
self.first_vicuna_vmfb_path, self.device, "tm_tensor"
|
||||
)
|
||||
if vmfb is not None:
|
||||
return vmfb
|
||||
|
||||
# Compilation path needs some more work before it is functional
|
||||
|
||||
print(
|
||||
f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with"
|
||||
f"[DEBUG] mlir path { self.first_vicuna_mlir_path} {'exists' if self.first_vicuna_mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
if self.first_vicuna_mlir_path.exists():
|
||||
with open(self.first_vicuna_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
mlir_generated = False
|
||||
if self.load_mlir_from_shark_tank:
|
||||
if self.precision in ["fp32", "fp16"]:
|
||||
# download MLIR from shark_tank for fp32/fp16
|
||||
download_public_file(
|
||||
f"gs://shark_tank/vicuna/unsharded/mlir/{self.first_vicuna_mlir_path.name}",
|
||||
self.first_vicuna_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if self.first_vicuna_mlir_path.exists():
|
||||
with open(self.first_vicuna_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
mlir_generated = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"MLIR not found at {self.first_vicuna_mlir_path.absolute()}"
|
||||
" after downloading! Please check path and try again"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"Only fp32 and fp16 mlir added to tank, generating {self.precision} mlir on device."
|
||||
)
|
||||
|
||||
if not mlir_generated:
|
||||
compilation_prompt = "".join(["0" for _ in range(17)])
|
||||
compilation_input_ids = self.tokenizer(
|
||||
compilation_prompt
|
||||
).input_ids
|
||||
compilation_input_ids = torch.tensor(
|
||||
compilation_input_ids
|
||||
).reshape([1, 19])
|
||||
firstVicunaCompileInput = (compilation_input_ids,)
|
||||
model = FirstVicuna(self.hf_model_path)
|
||||
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
firstVicunaCompileInput,
|
||||
is_f16=self.precision == "fp16",
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
del model
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
|
||||
firstVicunaCompileInput = list(firstVicunaCompileInput)
|
||||
firstVicunaCompileInput[0] = torch_mlir.TensorPlaceholder.like(
|
||||
firstVicunaCompileInput[0], dynamic_axes=[1]
|
||||
)
|
||||
firstVicunaCompileInput = tuple(firstVicunaCompileInput)
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*firstVicunaCompileInput],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
del ts_graph
|
||||
|
||||
def remove_constant_dim(line):
|
||||
if "19x" in line:
|
||||
line = re.sub("19x", "?x", line)
|
||||
line = re.sub(
|
||||
"tensor.empty\(\)", "tensor.empty(%dim)", line
|
||||
)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)",
|
||||
"tensor.empty(%dim, %dim)",
|
||||
line,
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub("c19", "dim", line)
|
||||
if " 19," in line:
|
||||
line = re.sub(" 19,", " %dim,", line)
|
||||
return line
|
||||
|
||||
module = str(module)
|
||||
new_lines = []
|
||||
|
||||
print(f"[DEBUG] rewriting torch_mlir file")
|
||||
for line in module.splitlines():
|
||||
line = remove_constant_dim(line)
|
||||
if "%0 = tensor.empty(%dim) : tensor<?xi64>" in line:
|
||||
new_lines.append(
|
||||
"%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>"
|
||||
)
|
||||
if (
|
||||
"%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>"
|
||||
in line
|
||||
):
|
||||
continue
|
||||
|
||||
new_lines.append(line)
|
||||
|
||||
module = "\n".join(new_lines)
|
||||
|
||||
print(f"[DEBUG] converting to bytecode")
|
||||
del new_lines
|
||||
module = module.encode("UTF-8")
|
||||
module = BytesIO(module)
|
||||
bytecode = module.read()
|
||||
del module
|
||||
|
||||
print(f"[DEBUG] writing mlir to file")
|
||||
f_ = open(self.first_vicuna_mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
self.first_vicuna_vmfb_path.parent.absolute(),
|
||||
self.first_vicuna_vmfb_path.stem,
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
)
|
||||
print("Saved first vic vmfb at vmfb at ", str(path))
|
||||
shark_module.load_module(path)
|
||||
|
||||
return shark_module
|
||||
|
||||
def compile_second_vicuna(self):
|
||||
vmfb = get_vmfb_from_path(
|
||||
self.second_vicuna_vmfb_path, self.device, "tm_tensor"
|
||||
)
|
||||
if vmfb is not None:
|
||||
return vmfb
|
||||
|
||||
# Compilation path needs some more work before it is functional
|
||||
print(
|
||||
f"[DEBUG] mlir path {self.second_vicuna_mlir_path} {'exists' if self.second_vicuna_mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
if self.second_vicuna_mlir_path.exists():
|
||||
with open(self.second_vicuna_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
mlir_generated = False
|
||||
if self.load_mlir_from_shark_tank:
|
||||
if self.precision in ["fp32", "fp16"]:
|
||||
# download MLIR from shark_tank for fp32/fp16
|
||||
download_public_file(
|
||||
f"gs://shark_tank/vicuna/unsharded/mlir/{self.second_vicuna_mlir_path.name}",
|
||||
self.second_vicuna_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if self.second_vicuna_mlir_path.exists():
|
||||
with open(self.second_vicuna_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
mlir_generated = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"MLIR not found at {self.second_vicuna_mlir_path.absolute()}"
|
||||
" after downloading! Please check path and try again"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"Only fp32 mlir added to tank, generating mlir on device."
|
||||
)
|
||||
|
||||
if not mlir_generated:
|
||||
compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64)
|
||||
pkv = tuple(
|
||||
(torch.zeros([1, 32, 19, 128], dtype=torch.float32))
|
||||
for _ in range(64)
|
||||
)
|
||||
secondVicunaCompileInput = (compilation_input_ids,) + pkv
|
||||
model = SecondVicuna(self.hf_model_path)
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
secondVicunaCompileInput,
|
||||
is_f16=self.precision == "fp16",
|
||||
f16_input_mask=[False] + [True] * 64,
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
if self.precision == "fp16":
|
||||
secondVicunaCompileInput = get_f16_inputs(
|
||||
secondVicunaCompileInput,
|
||||
True,
|
||||
f16_input_mask=[False] + [True] * 64,
|
||||
)
|
||||
secondVicunaCompileInput = list(secondVicunaCompileInput)
|
||||
for i in range(len(secondVicunaCompileInput)):
|
||||
if i != 0:
|
||||
secondVicunaCompileInput[
|
||||
i
|
||||
] = torch_mlir.TensorPlaceholder.like(
|
||||
secondVicunaCompileInput[i], dynamic_axes=[2]
|
||||
)
|
||||
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*secondVicunaCompileInput],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
def remove_constant_dim(line):
|
||||
if "c19_i64" in line:
|
||||
line = re.sub("c19_i64", "dim_i64", line)
|
||||
if "19x" in line:
|
||||
line = re.sub("19x", "?x", line)
|
||||
line = re.sub(
|
||||
"tensor.empty\(\)", "tensor.empty(%dim)", line
|
||||
)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)",
|
||||
"tensor.empty(%dim, %dim)",
|
||||
line,
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub("c19", "dim", line)
|
||||
if " 19," in line:
|
||||
line = re.sub(" 19,", " %dim,", line)
|
||||
if "20x" in line:
|
||||
line = re.sub("20x", "?x", line)
|
||||
line = re.sub(
|
||||
"tensor.empty\(\)", "tensor.empty(%dimp1)", line
|
||||
)
|
||||
if " 20," in line:
|
||||
line = re.sub(" 20,", " %dimp1,", line)
|
||||
return line
|
||||
|
||||
module_str = str(module)
|
||||
new_lines = []
|
||||
|
||||
for line in module_str.splitlines():
|
||||
if "%c19_i64 = arith.constant 19 : i64" in line:
|
||||
new_lines.append("%c2 = arith.constant 2 : index")
|
||||
new_lines.append(
|
||||
f"%dim_4_int = tensor.dim %arg1, %c2 : tensor<1x32x?x128x{'f16' if self.precision == 'fp16' else 'f32'}>"
|
||||
)
|
||||
new_lines.append(
|
||||
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
|
||||
)
|
||||
continue
|
||||
if "%c2 = arith.constant 2 : index" in line:
|
||||
continue
|
||||
if "%c20_i64 = arith.constant 20 : i64" in line:
|
||||
new_lines.append("%c1_i64 = arith.constant 1 : i64")
|
||||
new_lines.append(
|
||||
"%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
|
||||
)
|
||||
new_lines.append(
|
||||
"%dimp1 = arith.index_cast %c20_i64 : i64 to index"
|
||||
)
|
||||
continue
|
||||
line = remove_constant_dim(line)
|
||||
new_lines.append(line)
|
||||
|
||||
module_str = "\n".join(new_lines)
|
||||
bytecode = module_str.encode("UTF-8")
|
||||
bytecode_stream = BytesIO(bytecode)
|
||||
bytecode = bytecode_stream.read()
|
||||
f_ = open(self.second_vicuna_mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
|
||||
path = shark_module.save_module(
|
||||
self.second_vicuna_vmfb_path.parent.absolute(),
|
||||
self.second_vicuna_vmfb_path.stem,
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
)
|
||||
print("Saved vmfb at ", str(path))
|
||||
shark_module.load_module(self.second_vicuna_vmfb_path)
|
||||
|
||||
# self.shark_module = shark_module
|
||||
|
||||
return shark_module
|
||||
|
||||
def compile(self):
|
||||
# Cannot load both the models in the memory at once
|
||||
# due to memory constraints, hence on demand compilation
|
||||
# is being used until the space is enough for both models
|
||||
|
||||
# Testing : DO NOT Download Vmfbs if not found. Modify later
|
||||
# download vmfbs for A100
|
||||
if (
|
||||
not self.first_vicuna_vmfb_path.exists()
|
||||
and self.device in ["cuda", "cpu"]
|
||||
and self.precision in ["fp32", "fp16"]
|
||||
):
|
||||
# combinations that are still in the works
|
||||
if not (self.device == "cuda" and self.precision == "fp16"):
|
||||
# Will generate vmfb on device
|
||||
pass
|
||||
else:
|
||||
download_public_file(
|
||||
f"gs://shark_tank/vicuna/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}",
|
||||
self.first_vicuna_vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
else:
|
||||
# get first vic
|
||||
# TODO: Remove after testing to avoid memory overload
|
||||
# fvic_shark_model = self.compile_first_vicuna()
|
||||
pass
|
||||
if (
|
||||
not self.second_vicuna_vmfb_path.exists()
|
||||
and self.device in ["cuda", "cpu"]
|
||||
and self.precision in ["fp32", "fp16"]
|
||||
):
|
||||
# combinations that are still in the works
|
||||
if not (self.device == "cuda" and self.precision == "fp16"):
|
||||
# Will generate vmfb on device
|
||||
pass
|
||||
else:
|
||||
download_public_file(
|
||||
f"gs://shark_tank/vicuna/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}",
|
||||
self.second_vicuna_vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
else:
|
||||
# get second vic
|
||||
# TODO: Remove after testing to avoid memory overload
|
||||
# svic_shark_model = self.compile_second_vicuna()
|
||||
pass
|
||||
|
||||
return None
|
||||
# return tuple of shark_modules once mem is supported
|
||||
# return fvic_shark_model, svic_shark_model
|
||||
|
||||
def generate(self, prompt, cli=False):
|
||||
# TODO: refactor for cleaner integration
|
||||
import gc
|
||||
|
||||
res = []
|
||||
res_tokens = []
|
||||
params = {
|
||||
"prompt": prompt,
|
||||
"is_first": True,
|
||||
"fv": self.compile_first_vicuna(),
|
||||
}
|
||||
|
||||
generated_token_op = self.generate_new_token(params=params)
|
||||
|
||||
token = generated_token_op["token"]
|
||||
logits = generated_token_op["logits"]
|
||||
pkv = generated_token_op["pkv"]
|
||||
detok = generated_token_op["detok"]
|
||||
|
||||
res.append(detok)
|
||||
res_tokens.append(token)
|
||||
if cli:
|
||||
print(f"Assistant: {detok}", end=" ", flush=True)
|
||||
|
||||
# Clear First Vic from Memory (main and cuda)
|
||||
del params
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
sec_vic = self.compile_second_vicuna()
|
||||
for _ in range(self.max_num_tokens - 2):
|
||||
params = {
|
||||
"prompt": None,
|
||||
"is_first": False,
|
||||
"logits": logits,
|
||||
"pkv": pkv,
|
||||
"sv": sec_vic,
|
||||
}
|
||||
|
||||
generated_token_op = self.generate_new_token(params=params)
|
||||
|
||||
token = generated_token_op["token"]
|
||||
logits = generated_token_op["logits"]
|
||||
pkv = generated_token_op["pkv"]
|
||||
detok = generated_token_op["detok"]
|
||||
|
||||
if token == 2:
|
||||
break
|
||||
res_tokens.append(token)
|
||||
if detok == "<0x0A>":
|
||||
res.append("\n")
|
||||
if cli:
|
||||
print("\n", end="", flush=True)
|
||||
else:
|
||||
res.append(detok)
|
||||
if cli:
|
||||
print(f"{detok}", end=" ", flush=True)
|
||||
del sec_vic, pkv, logits
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
for i in range(len(res_tokens)):
|
||||
if type(res_tokens[i]) != int:
|
||||
res_tokens[i] = int(res_tokens[i][0])
|
||||
|
||||
res_str = self.tokenizer.decode(res_tokens)
|
||||
# print(f"[DEBUG] final output : \n{res_str}")
|
||||
return res_str
|
||||
|
||||
def generate_new_token(self, params, debug=False):
|
||||
def forward_first(first_vic, prompt, cache_outputs=False):
|
||||
input_ids = self.tokenizer(prompt).input_ids
|
||||
input_id_len = len(input_ids)
|
||||
input_ids = torch.tensor(input_ids)
|
||||
input_ids = input_ids.reshape([1, input_id_len])
|
||||
firstVicunaInput = (input_ids,)
|
||||
assert first_vic is not None
|
||||
output_first_vicuna = first_vic("forward", firstVicunaInput)
|
||||
output_first_vicuna_tensor = torch.tensor(output_first_vicuna[1:])
|
||||
logits_first_vicuna = torch.tensor(output_first_vicuna[0])
|
||||
if cache_outputs:
|
||||
torch.save(
|
||||
logits_first_vicuna, "logits_first_vicuna_tensor.pt"
|
||||
)
|
||||
torch.save(
|
||||
output_first_vicuna_tensor, "output_first_vicuna_tensor.pt"
|
||||
)
|
||||
token = torch.argmax(
|
||||
torch.tensor(logits_first_vicuna)[:, -1, :], dim=1
|
||||
)
|
||||
return token, logits_first_vicuna, output_first_vicuna_tensor
|
||||
|
||||
def forward_second(sec_vic, inputs=None, load_inputs=False):
|
||||
if inputs is not None:
|
||||
logits = inputs[0]
|
||||
pkv = inputs[1:]
|
||||
elif load_inputs:
|
||||
pkv = torch.load("output_first_vicuna_tensor.pt")
|
||||
pkv = tuple(torch.tensor(x) for x in pkv)
|
||||
logits = torch.load("logits_first_vicuna_tensor.pt")
|
||||
else:
|
||||
print(
|
||||
"Either inputs must be given, or load_inputs must be true"
|
||||
)
|
||||
return None
|
||||
token = torch.argmax(torch.tensor(logits)[:, -1, :], dim=1)
|
||||
token = token.to(torch.int64).reshape([1, 1])
|
||||
secondVicunaInput = (token,) + tuple(pkv)
|
||||
|
||||
secondVicunaOutput = sec_vic("forward", secondVicunaInput)
|
||||
new_pkv = secondVicunaOutput[1:]
|
||||
new_logits = secondVicunaOutput[0]
|
||||
new_token = torch.argmax(torch.tensor(new_logits)[:, -1, :], dim=1)
|
||||
return new_token, new_logits, new_pkv
|
||||
|
||||
is_first = params["is_first"]
|
||||
|
||||
if is_first:
|
||||
prompt = params["prompt"]
|
||||
fv = params["fv"]
|
||||
token, logits, pkv = forward_first(
|
||||
fv, # self.shark_model[0],
|
||||
prompt=prompt,
|
||||
cache_outputs=False,
|
||||
)
|
||||
else:
|
||||
_logits = params["logits"]
|
||||
_pkv = params["pkv"]
|
||||
inputs = (_logits,) + tuple(_pkv)
|
||||
sv = params["sv"]
|
||||
token, logits, pkv = forward_second(
|
||||
sv, # self.shark_model[1],
|
||||
inputs=inputs,
|
||||
load_inputs=False,
|
||||
)
|
||||
|
||||
detok = self.tokenizer.decode(token)
|
||||
if debug:
|
||||
print(
|
||||
f"[DEBUG] is_first: {is_first} |"
|
||||
f" token : {token} | detok : {detok}"
|
||||
)
|
||||
ret_dict = {
|
||||
"token": token,
|
||||
"logits": logits,
|
||||
"pkv": pkv,
|
||||
"detok": detok,
|
||||
}
|
||||
return ret_dict
|
||||
|
||||
def autocomplete(self, prompt):
|
||||
# use First vic alone to complete a story / prompt / sentence.
|
||||
pass
|
||||
402
apps/language_models/src/pipelines/vicuna_sharded_pipeline.py
Normal file
402
apps/language_models/src/pipelines/vicuna_sharded_pipeline.py
Normal file
@@ -0,0 +1,402 @@
|
||||
from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
|
||||
FirstVicunaLayer,
|
||||
SecondVicunaLayer,
|
||||
CompiledFirstVicunaLayer,
|
||||
CompiledSecondVicunaLayer,
|
||||
ShardedVicunaModel,
|
||||
)
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from shark.shark_importer import import_with_fx
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from shark.shark_inference import SharkInference
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from tqdm import tqdm
|
||||
from torch_mlir import TensorPlaceholder
|
||||
|
||||
|
||||
import re
|
||||
import torch
|
||||
import torch_mlir
|
||||
import os
|
||||
|
||||
|
||||
class Vicuna(SharkLLMBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
|
||||
max_num_tokens=512,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_sequence_length = 256
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.shark_model = self.compile(device=device)
|
||||
|
||||
def get_tokenizer(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path, use_fast=False
|
||||
)
|
||||
return tokenizer
|
||||
|
||||
def get_src_model(self):
|
||||
kwargs = {"torch_dtype": torch.float}
|
||||
vicuna_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
)
|
||||
return vicuna_model
|
||||
|
||||
def write_in_dynamic_inputs0(self, module, dynamic_input_size):
|
||||
new_lines = []
|
||||
for line in module.splitlines():
|
||||
line = re.sub(f"{dynamic_input_size}x", "?x", line)
|
||||
if "?x" in line:
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
|
||||
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub(f"c{dynamic_input_size}", "dim", line)
|
||||
new_lines.append(line)
|
||||
new_module = "\n".join(new_lines)
|
||||
return new_module
|
||||
|
||||
def write_in_dynamic_inputs1(self, module, dynamic_input_size):
|
||||
new_lines = []
|
||||
for line in module.splitlines():
|
||||
if "dim_42 =" in line:
|
||||
continue
|
||||
if f"%c{dynamic_input_size}_i64 =" in line:
|
||||
new_lines.append(
|
||||
"%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>"
|
||||
)
|
||||
new_lines.append(
|
||||
f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64"
|
||||
)
|
||||
continue
|
||||
line = re.sub(f"{dynamic_input_size}x", "?x", line)
|
||||
if "?x" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(\)", "tensor.empty(%dim_42)", line
|
||||
)
|
||||
line = re.sub(f" {dynamic_input_size},", " %dim_42,", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim_42\)",
|
||||
"tensor.empty(%dim_42, %dim_42)",
|
||||
line,
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub(f"c{dynamic_input_size}", "dim_42", line)
|
||||
new_lines.append(line)
|
||||
new_module = "\n".join(new_lines)
|
||||
return new_module
|
||||
|
||||
def compile_vicuna_layer(
|
||||
self,
|
||||
vicuna_layer,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0=None,
|
||||
past_key_value1=None,
|
||||
):
|
||||
if past_key_value0 is None and past_key_value1 is None:
|
||||
model_inputs = (hidden_states, attention_mask, position_ids)
|
||||
else:
|
||||
model_inputs = (
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
)
|
||||
mlir_bytecode = import_with_fx(
|
||||
vicuna_layer,
|
||||
model_inputs,
|
||||
is_f16=self.precision == "fp16",
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
return mlir_bytecode
|
||||
|
||||
def compile_to_vmfb(self, inputs, layers, device="cpu", is_first=True):
|
||||
mlirs, modules = [], []
|
||||
for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"):
|
||||
if is_first:
|
||||
mlir_path = Path(f"{idx}_0.mlir")
|
||||
vmfb_path = Path(f"{idx}_0.vmfb")
|
||||
else:
|
||||
mlir_path = Path(f"{idx}_1.mlir")
|
||||
vmfb_path = Path(f"{idx}_1.vmfb")
|
||||
if vmfb_path.exists():
|
||||
continue
|
||||
if mlir_path.exists():
|
||||
# print(f"Found layer {idx} mlir")
|
||||
f_ = open(mlir_path, "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
else:
|
||||
hidden_states_placeholder = TensorPlaceholder.like(
|
||||
inputs[0], dynamic_axes=[1]
|
||||
)
|
||||
attention_mask_placeholder = TensorPlaceholder.like(
|
||||
inputs[1], dynamic_axes=[3]
|
||||
)
|
||||
position_ids_placeholder = TensorPlaceholder.like(
|
||||
inputs[2], dynamic_axes=[1]
|
||||
)
|
||||
if not is_first:
|
||||
pkv0_placeholder = TensorPlaceholder.like(
|
||||
inputs[3], dynamic_axes=[2]
|
||||
)
|
||||
pkv1_placeholder = TensorPlaceholder.like(
|
||||
inputs[4], dynamic_axes=[2]
|
||||
)
|
||||
print(f"Compiling layer {idx} mlir")
|
||||
if is_first:
|
||||
ts_g = self.compile_vicuna_layer(
|
||||
layer, inputs[0], inputs[1], inputs[2]
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
(
|
||||
hidden_states_placeholder,
|
||||
inputs[1],
|
||||
inputs[2],
|
||||
),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
else:
|
||||
ts_g = self.compile_vicuna_layer(
|
||||
layer,
|
||||
inputs[0],
|
||||
inputs[1],
|
||||
inputs[2],
|
||||
inputs[3],
|
||||
inputs[4],
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
(
|
||||
inputs[0],
|
||||
attention_mask_placeholder,
|
||||
inputs[2],
|
||||
pkv0_placeholder,
|
||||
pkv1_placeholder,
|
||||
),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# bytecode_stream = BytesIO()
|
||||
# module.operation.write_bytecode(bytecode_stream)
|
||||
# bytecode = bytecode_stream.getvalue()
|
||||
|
||||
if is_first:
|
||||
module = self.write_in_dynamic_inputs0(str(module), 137)
|
||||
bytecode = module.encode("UTF-8")
|
||||
bytecode_stream = BytesIO(bytecode)
|
||||
bytecode = bytecode_stream.read()
|
||||
|
||||
else:
|
||||
module = self.write_in_dynamic_inputs1(str(module), 138)
|
||||
|
||||
bytecode = module.encode("UTF-8")
|
||||
bytecode_stream = BytesIO(bytecode)
|
||||
bytecode = bytecode_stream.read()
|
||||
|
||||
f_ = open(mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
f_.close()
|
||||
mlirs.append(bytecode)
|
||||
|
||||
for idx, layer in tqdm(enumerate(layers), desc="compiling modules"):
|
||||
if is_first:
|
||||
vmfb_path = Path(f"{idx}_0.vmfb")
|
||||
if vmfb_path.exists():
|
||||
# print(f"Found layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
None,
|
||||
device=device,
|
||||
device_idx=idx % 1,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
else:
|
||||
print(f"Compiling layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
mlirs[idx],
|
||||
device=device,
|
||||
device_idx=idx % 1,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
module.save_module(
|
||||
module_name=f"{idx}_0",
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
modules.append(module)
|
||||
else:
|
||||
vmfb_path = Path(f"{idx}_1.vmfb")
|
||||
if vmfb_path.exists():
|
||||
# print(f"Found layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
None,
|
||||
device=device,
|
||||
device_idx=idx % 1,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
else:
|
||||
print(f"Compiling layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
mlirs[idx],
|
||||
device=device,
|
||||
device_idx=idx % 1,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
module.save_module(
|
||||
module_name=f"{idx}_1",
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
modules.append(module)
|
||||
|
||||
return mlirs, modules
|
||||
|
||||
def get_sharded_model(self, device="cpu"):
|
||||
# SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an increadibly hacky proccess
|
||||
# please don't change it
|
||||
SAMPLE_INPUT_LEN = 137
|
||||
vicuna_model = self.get_src_model()
|
||||
placeholder_input0 = (
|
||||
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
|
||||
torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
|
||||
torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64),
|
||||
)
|
||||
|
||||
placeholder_input1 = (
|
||||
torch.zeros([1, 1, 4096]),
|
||||
torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]),
|
||||
torch.zeros([1, 1], dtype=torch.int64),
|
||||
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
|
||||
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
|
||||
)
|
||||
|
||||
layers0 = [
|
||||
FirstVicunaLayer(layer) for layer in vicuna_model.model.layers
|
||||
]
|
||||
_, modules0 = self.compile_to_vmfb(
|
||||
placeholder_input0,
|
||||
layers0,
|
||||
is_first=True,
|
||||
device=device,
|
||||
)
|
||||
shark_layers0 = [CompiledFirstVicunaLayer(m) for m in modules0]
|
||||
|
||||
layers1 = [
|
||||
SecondVicunaLayer(layer) for layer in vicuna_model.model.layers
|
||||
]
|
||||
_, modules1 = self.compile_to_vmfb(
|
||||
placeholder_input1, layers1, is_first=False, device=device
|
||||
)
|
||||
shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1]
|
||||
|
||||
sharded_model = ShardedVicunaModel(
|
||||
vicuna_model, shark_layers0, shark_layers1
|
||||
)
|
||||
return sharded_model
|
||||
|
||||
def compile(self, device="cpu"):
|
||||
return self.get_sharded_model(device=device)
|
||||
|
||||
def generate(self, prompt, cli=False):
|
||||
# TODO: refactor for cleaner integration
|
||||
|
||||
tokens_generated = []
|
||||
_past_key_values = None
|
||||
_token = None
|
||||
detoks_generated = []
|
||||
for iteration in range(self.max_num_tokens):
|
||||
params = {
|
||||
"prompt": prompt,
|
||||
"is_first": iteration == 0,
|
||||
"token": _token,
|
||||
"past_key_values": _past_key_values,
|
||||
}
|
||||
|
||||
generated_token_op = self.generate_new_token(params=params)
|
||||
|
||||
_token = generated_token_op["token"]
|
||||
_past_key_values = generated_token_op["past_key_values"]
|
||||
_detok = generated_token_op["detok"]
|
||||
|
||||
if _token == 2:
|
||||
break
|
||||
detoks_generated.append(_detok)
|
||||
tokens_generated.append(_token)
|
||||
|
||||
for i in range(len(tokens_generated)):
|
||||
if type(tokens_generated[i]) != int:
|
||||
tokens_generated[i] = int(tokens_generated[i][0])
|
||||
result_output = self.tokenizer.decode(tokens_generated)
|
||||
return result_output
|
||||
|
||||
def generate_new_token(self, params):
|
||||
is_first = params["is_first"]
|
||||
if is_first:
|
||||
prompt = params["prompt"]
|
||||
input_ids = self.tokenizer(prompt).input_ids
|
||||
input_id_len = len(input_ids)
|
||||
input_ids = torch.tensor(input_ids)
|
||||
input_ids = input_ids.reshape([1, input_id_len])
|
||||
output = self.shark_model.forward(input_ids, is_first=is_first)
|
||||
else:
|
||||
token = params["token"]
|
||||
past_key_values = params["past_key_values"]
|
||||
input_ids = [token]
|
||||
input_id_len = len(input_ids)
|
||||
input_ids = torch.tensor(input_ids)
|
||||
input_ids = input_ids.reshape([1, input_id_len])
|
||||
output = self.shark_model.forward(
|
||||
input_ids, past_key_values=past_key_values, is_first=is_first
|
||||
)
|
||||
|
||||
_logits = output["logits"]
|
||||
_past_key_values = output["past_key_values"]
|
||||
_token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
|
||||
_detok = self.tokenizer.decode(_token)
|
||||
|
||||
ret_dict = {
|
||||
"token": _token,
|
||||
"detok": _detok,
|
||||
"past_key_values": _past_key_values,
|
||||
}
|
||||
|
||||
print(f" token : {_token} | detok : {_detok}")
|
||||
|
||||
return ret_dict
|
||||
|
||||
def autocomplete(self, prompt):
|
||||
# use First vic alone to complete a story / prompt / sentence.
|
||||
pass
|
||||
25
apps/language_models/utils.py
Normal file
25
apps/language_models/utils.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# expects a Path / str as arg
|
||||
# returns None if path not found or SharkInference module
|
||||
def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
|
||||
if not isinstance(vmfb_path, Path):
|
||||
vmfb_path = Path(vmfb_path)
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
if not vmfb_path.exists():
|
||||
return None
|
||||
|
||||
print("Loading vmfb from: ", vmfb_path)
|
||||
shark_module = SharkInference(
|
||||
None, device=device, mlir_dialect=mlir_dialect
|
||||
)
|
||||
shark_module.load_module(vmfb_path)
|
||||
print("Successfully loaded vmfb")
|
||||
return shark_module
|
||||
@@ -10,7 +10,7 @@ Vulkan AMD:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
|
||||
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
|
||||
# use –iree-input-type=mhlo for tf models
|
||||
# use –iree-input-type=auto or "mhlo_legacy" or "stablehlo" for TF models
|
||||
|
||||
CUDA NVIDIA:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
|
||||
@@ -1,4 +1 @@
|
||||
from apps.stable_diffusion.scripts.inpaint import inpaint_inf
|
||||
from apps.stable_diffusion.scripts.outpaint import outpaint_inf
|
||||
from apps.stable_diffusion.scripts.upscaler import upscaler_inf
|
||||
from apps.stable_diffusion.scripts.train_lora_word import lora_train
|
||||
|
||||
@@ -14,189 +14,6 @@ from apps.stable_diffusion.src import (
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_use_tuned = args.use_tuned
|
||||
init_import_mlir = args.import_mlir
|
||||
|
||||
|
||||
# Exposed to UI.
|
||||
def inpaint_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
image_dict,
|
||||
height: int,
|
||||
width: int,
|
||||
inpaint_full_res: bool,
|
||||
inpaint_full_res_padding: int,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
args.steps = steps
|
||||
args.scheduler = scheduler
|
||||
args.img_path = "not none"
|
||||
args.mask_path = "not none"
|
||||
args.ondemand = ondemand
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
args.custom_vae = custom_vae
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
new_config_obj = Config(
|
||||
"inpaint",
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
precision,
|
||||
batch_size,
|
||||
max_length,
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
not global_obj.get_sd_obj()
|
||||
or global_obj.get_cfg_obj() != new_config_obj
|
||||
):
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.precision = precision
|
||||
args.batch_count = batch_count
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
args.width = width
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
|
||||
args.use_tuned = init_use_tuned
|
||||
args.import_mlir = init_import_mlir
|
||||
set_init_device_flags()
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-inpainting"
|
||||
)
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||
global_obj.set_sd_obj(
|
||||
InpaintPipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
model_id=args.hf_model_id,
|
||||
ckpt_loc=args.ckpt_loc,
|
||||
custom_vae=args.custom_vae,
|
||||
precision=args.precision,
|
||||
max_length=args.max_length,
|
||||
batch_size=args.batch_size,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
use_tuned=args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
|
||||
global_obj.set_sd_scheduler(scheduler)
|
||||
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
image = image_dict["image"]
|
||||
mask_image = image_dict["mask"]
|
||||
text_output = ""
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image,
|
||||
mask_image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
inpaint_full_res,
|
||||
inpaint_full_res_padding,
|
||||
steps,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
seeds.append(img_seed)
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
def main():
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
@@ -11,199 +11,6 @@ from apps.stable_diffusion.src import (
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_use_tuned = args.use_tuned
|
||||
init_import_mlir = args.import_mlir
|
||||
|
||||
|
||||
# Exposed to UI.
|
||||
def outpaint_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
init_image,
|
||||
pixels: int,
|
||||
mask_blur: int,
|
||||
directions: list,
|
||||
noise_q: float,
|
||||
color_variation: float,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
args.steps = steps
|
||||
args.scheduler = scheduler
|
||||
args.img_path = "not none"
|
||||
args.ondemand = ondemand
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
args.custom_vae = custom_vae
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
new_config_obj = Config(
|
||||
"outpaint",
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
precision,
|
||||
batch_size,
|
||||
max_length,
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
not global_obj.get_sd_obj()
|
||||
or global_obj.get_cfg_obj() != new_config_obj
|
||||
):
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.precision = precision
|
||||
args.batch_count = batch_count
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
args.width = width
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
|
||||
args.use_tuned = init_use_tuned
|
||||
args.import_mlir = init_import_mlir
|
||||
set_init_device_flags()
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-inpainting"
|
||||
)
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||
global_obj.set_sd_obj(
|
||||
OutpaintPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
use_lora=args.use_lora,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
|
||||
global_obj.set_sd_scheduler(scheduler)
|
||||
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
|
||||
left = True if "left" in directions else False
|
||||
right = True if "right" in directions else False
|
||||
top = True if "up" in directions else False
|
||||
bottom = True if "down" in directions else False
|
||||
|
||||
text_output = ""
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
init_image,
|
||||
pixels,
|
||||
mask_blur,
|
||||
left,
|
||||
right,
|
||||
top,
|
||||
bottom,
|
||||
noise_q,
|
||||
color_variation,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
seeds.append(img_seed)
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -17,6 +17,10 @@ from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
|
||||
|
||||
|
||||
def load_mlir_module():
|
||||
if "upscaler" in args.hf_model_id:
|
||||
is_upscaler = True
|
||||
else:
|
||||
is_upscaler = False
|
||||
sd_model = SharkifyStableDiffusionModel(
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
@@ -27,6 +31,7 @@ def load_mlir_module():
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
is_upscaler=is_upscaler,
|
||||
use_tuned=False,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
return_mlir=True,
|
||||
@@ -54,12 +59,19 @@ def main():
|
||||
# Get device and device specific arguments
|
||||
device, device_spec_args = get_device_args()
|
||||
device_spec = ""
|
||||
vulkan_target_triple = ""
|
||||
if device_spec_args:
|
||||
device_spec = device_spec_args[-1].split("=")[-1].strip()
|
||||
if device == "vulkan":
|
||||
vulkan_target_triple = device_spec
|
||||
device_spec = device_spec.split("-")[0]
|
||||
|
||||
# Add winograd annotation for vulkan device
|
||||
use_winograd = (
|
||||
True
|
||||
if device == "vulkan" and args.annotation_model in ["unet", "vae"]
|
||||
else False
|
||||
)
|
||||
winograd_config = (
|
||||
load_winograd_configs()
|
||||
if device == "vulkan" and args.annotation_model in ["unet", "vae"]
|
||||
@@ -71,19 +83,23 @@ def main():
|
||||
input_contents=mlir_module,
|
||||
config_path=winograd_config,
|
||||
search_op="conv",
|
||||
winograd=True,
|
||||
winograd=use_winograd,
|
||||
)
|
||||
|
||||
# Dump model dispatches
|
||||
if device == "vulkan" and device_spec == "rdna3":
|
||||
device = "vulkan/RX 7900"
|
||||
generates_dir = Path.home() / "tmp"
|
||||
if not os.path.exists(generates_dir):
|
||||
os.makedirs(generates_dir)
|
||||
dump_mlir = generates_dir / "temp.mlir"
|
||||
dispatch_dir = generates_dir / f"{model_name}_{device_spec}_dispatches"
|
||||
export_module_to_mlir_file(input_module, dump_mlir)
|
||||
dump_dispatches(dump_mlir, device, dispatch_dir, False)
|
||||
dump_dispatches(
|
||||
dump_mlir,
|
||||
device,
|
||||
dispatch_dir,
|
||||
vulkan_target_triple,
|
||||
use_winograd=use_winograd,
|
||||
)
|
||||
|
||||
# Tune each dispatch
|
||||
dtype = "f16" if args.precision == "fp16" else "f32"
|
||||
@@ -106,6 +122,7 @@ def main():
|
||||
batch_size=1,
|
||||
config_filename=config_filename,
|
||||
use_dispatch=True,
|
||||
vulkan_target_triple=vulkan_target_triple,
|
||||
)
|
||||
tuner.tune()
|
||||
|
||||
|
||||
@@ -61,6 +61,7 @@ def main():
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
|
||||
@@ -13,195 +13,6 @@ from apps.stable_diffusion.src import (
|
||||
)
|
||||
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_use_tuned = args.use_tuned
|
||||
init_import_mlir = args.import_mlir
|
||||
|
||||
|
||||
# Exposed to UI.
|
||||
def upscaler_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
init_image,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
noise_level: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
args.seed = seed
|
||||
args.steps = steps
|
||||
args.scheduler = scheduler
|
||||
args.ondemand = ondemand
|
||||
|
||||
if init_image is None:
|
||||
return None, "An Initial Image is required"
|
||||
image = init_image.convert("RGB").resize((height, width))
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
args.custom_vae = custom_vae
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
args.height = 128
|
||||
args.width = 128
|
||||
new_config_obj = Config(
|
||||
"upscaler",
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
precision,
|
||||
batch_size,
|
||||
max_length,
|
||||
args.height,
|
||||
args.width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
not global_obj.get_sd_obj()
|
||||
or global_obj.get_cfg_obj() != new_config_obj
|
||||
):
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
|
||||
args.use_tuned = init_use_tuned
|
||||
args.import_mlir = init_import_mlir
|
||||
set_init_device_flags()
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-1-base"
|
||||
)
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||
global_obj.set_sd_obj(
|
||||
UpscalerPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
use_lora=args.use_lora,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
|
||||
global_obj.set_sd_scheduler(scheduler)
|
||||
global_obj.get_sd_obj().low_res_scheduler = global_obj.get_scheduler(
|
||||
"DDPM"
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
extra_info = {"NOISE LEVEL": noise_level}
|
||||
for current_batch in range(batch_count):
|
||||
if current_batch > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
low_res_img = image
|
||||
high_res_img = Image.new("RGB", (height * 4, width * 4))
|
||||
|
||||
for i in range(0, width, 128):
|
||||
for j in range(0, height, 128):
|
||||
box = (j, i, j + 128, i + 128)
|
||||
upscaled_image = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
low_res_img.crop(box),
|
||||
batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
steps,
|
||||
noise_level,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
high_res_img.paste(upscaled_image[0], (j * 4, i * 4))
|
||||
|
||||
save_output_img(high_res_img, img_seed, extra_info)
|
||||
generated_imgs.append(high_res_img)
|
||||
seeds.append(img_seed)
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, global_obj.get_sd_obj().log
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={steps}, noise_level={noise_level}, guidance_scale={guidance_scale}, seed={seeds}"
|
||||
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
|
||||
text_output += global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
yield generated_imgs, text_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
@@ -19,6 +19,7 @@ datas += copy_metadata('importlib_metadata')
|
||||
datas += copy_metadata('torch-mlir')
|
||||
datas += copy_metadata('omegaconf')
|
||||
datas += copy_metadata('safetensors')
|
||||
datas += copy_metadata('Pillow')
|
||||
datas += collect_data_files('diffusers')
|
||||
datas += collect_data_files('transformers')
|
||||
datas += collect_data_files('pytorch_lightning')
|
||||
@@ -29,6 +30,9 @@ datas += collect_data_files('gradio_client')
|
||||
datas += collect_data_files('iree')
|
||||
datas += collect_data_files('google-cloud-storage')
|
||||
datas += collect_data_files('shark')
|
||||
datas += collect_data_files('tkinter')
|
||||
datas += collect_data_files('webview')
|
||||
datas += collect_data_files('sentencepiece')
|
||||
datas += [
|
||||
( 'src/utils/resources/prompts.json', 'resources' ),
|
||||
( 'src/utils/resources/model_db.json', 'resources' ),
|
||||
@@ -44,6 +48,7 @@ block_cipher = None
|
||||
|
||||
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
|
||||
|
||||
a = Analysis(
|
||||
['web/index.py'],
|
||||
|
||||
@@ -42,6 +42,7 @@ block_cipher = None
|
||||
|
||||
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
|
||||
|
||||
a = Analysis(
|
||||
['scripts/main.py'],
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
|
||||
from transformers import CLIPTextModel
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import safetensors.torch
|
||||
import traceback
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
@@ -12,6 +14,7 @@ from apps.stable_diffusion.src.utils import (
|
||||
base_models,
|
||||
args,
|
||||
preprocessCKPT,
|
||||
convert_original_vae,
|
||||
get_path_to_diffusers_checkpoint,
|
||||
fetch_and_update_base_model_id,
|
||||
get_path_stem,
|
||||
@@ -91,10 +94,19 @@ class SharkifyStableDiffusionModel:
|
||||
self.custom_weights = custom_weights
|
||||
self.use_quantize = use_quantize
|
||||
if custom_weights != "":
|
||||
assert custom_weights.lower().endswith(
|
||||
(".ckpt", ".safetensors")
|
||||
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
|
||||
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
|
||||
if "civitai" in custom_weights:
|
||||
weights_id = custom_weights.split("/")[-1]
|
||||
# TODO: use model name and identify file type by civitai rest api
|
||||
weights_path = str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
|
||||
if not os.path.isfile(weights_path):
|
||||
subprocess.run(["wget", custom_weights, "-O", weights_path])
|
||||
custom_weights = get_path_to_diffusers_checkpoint(weights_path)
|
||||
self.custom_weights = weights_path
|
||||
else:
|
||||
assert custom_weights.lower().endswith(
|
||||
(".ckpt", ".safetensors")
|
||||
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
|
||||
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
|
||||
self.model_id = model_id if custom_weights == "" else custom_weights
|
||||
# TODO: remove the following line when stable-diffusion-2-1 works
|
||||
if self.model_id == "stabilityai/stable-diffusion-2-1":
|
||||
@@ -151,7 +163,7 @@ class SharkifyStableDiffusionModel:
|
||||
|
||||
def get_extended_name_for_all_model(self):
|
||||
model_name = {}
|
||||
sub_model_list = ["clip", "unet", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
|
||||
sub_model_list = ["clip", "unet", "unet512", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
|
||||
index = 0
|
||||
for model in sub_model_list:
|
||||
sub_model = model
|
||||
@@ -161,6 +173,8 @@ class SharkifyStableDiffusionModel:
|
||||
model_config = model_config + get_path_stem(self.custom_vae)
|
||||
if self.base_vae:
|
||||
sub_model = "base_vae"
|
||||
if "stencil_adaptor" == model and self.use_stencil is not None:
|
||||
model_config = model_config + get_path_stem(self.use_stencil)
|
||||
model_name[model] = get_extended_name(sub_model + model_config)
|
||||
index += 1
|
||||
return model_name
|
||||
@@ -265,7 +279,7 @@ class SharkifyStableDiffusionModel:
|
||||
|
||||
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
inputs = tuple(self.inputs["vae"])
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
is_f16 = True if not self.is_upscaler and self.precision == "fp16" else False
|
||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
|
||||
if self.debug:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
@@ -401,7 +415,7 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
return shark_cnet, cnet_mlir
|
||||
|
||||
def get_unet(self):
|
||||
def get_unet(self, use_large=False):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
|
||||
super().__init__()
|
||||
@@ -412,7 +426,7 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
if use_lora != "":
|
||||
update_lora_weight(self.unet, use_lora, "unet")
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.in_channels = self.unet.config.in_channels
|
||||
self.train(False)
|
||||
if(args.attention_slicing is not None and args.attention_slicing != "none"):
|
||||
if(args.attention_slicing.isdigit()):
|
||||
@@ -438,17 +452,27 @@ class SharkifyStableDiffusionModel:
|
||||
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
inputs = tuple(self.inputs["unet"])
|
||||
if(use_large):
|
||||
pad = (0, 0) * (len(inputs[2].shape) - 2)
|
||||
pad = pad + (0, 512 - inputs[2].shape[1])
|
||||
inputs = (inputs[0],
|
||||
inputs[1],
|
||||
torch.nn.functional.pad(inputs[2], pad),
|
||||
inputs[3])
|
||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet512"])
|
||||
else:
|
||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
|
||||
input_mask = [True, True, True, False]
|
||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
|
||||
if self.debug:
|
||||
os.makedirs(
|
||||
save_dir,
|
||||
exist_ok=True,
|
||||
)
|
||||
model_name = "unet512" if use_large else "unet"
|
||||
shark_unet, unet_mlir = compile_through_fx(
|
||||
unet,
|
||||
inputs,
|
||||
extended_model_name=self.model_name["unet"],
|
||||
extended_model_name=self.model_name[model_name],
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
@@ -457,13 +481,13 @@ class SharkifyStableDiffusionModel:
|
||||
save_dir=save_dir,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
model_name="unet",
|
||||
model_name=model_name,
|
||||
precision=self.precision,
|
||||
return_mlir=self.return_mlir,
|
||||
)
|
||||
return shark_unet, unet_mlir
|
||||
|
||||
def get_unet_upscaler(self):
|
||||
def get_unet_upscaler(self, use_large=False):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
|
||||
super().__init__()
|
||||
@@ -488,6 +512,13 @@ class SharkifyStableDiffusionModel:
|
||||
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
inputs = tuple(self.inputs["unet"])
|
||||
if(use_large):
|
||||
pad = (0, 0) * (len(inputs[2].shape) - 2)
|
||||
pad = pad + (0, 512 - inputs[2].shape[1])
|
||||
inputs = (inputs[0],
|
||||
inputs[1],
|
||||
torch.nn.functional.pad(inputs[2], pad),
|
||||
inputs[3])
|
||||
input_mask = [True, True, True, False]
|
||||
shark_unet, unet_mlir = compile_through_fx(
|
||||
unet,
|
||||
@@ -558,19 +589,23 @@ class SharkifyStableDiffusionModel:
|
||||
vae_checkpoint = safetensors.torch.load_file(self.custom_vae, device="cpu")
|
||||
if "state_dict" in vae_checkpoint:
|
||||
vae_checkpoint = vae_checkpoint["state_dict"]
|
||||
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||
return vae_dict
|
||||
|
||||
def compile_unet_variants(self, model):
|
||||
try:
|
||||
vae_checkpoint = convert_original_vae(vae_checkpoint)
|
||||
finally:
|
||||
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||
return vae_dict
|
||||
|
||||
def compile_unet_variants(self, model, use_large=False):
|
||||
if model == "unet":
|
||||
if self.is_upscaler:
|
||||
return self.get_unet_upscaler()
|
||||
return self.get_unet_upscaler(use_large=use_large)
|
||||
# TODO: Plug the experimental "int8" support at right place.
|
||||
elif self.use_quantize == "int8":
|
||||
from apps.stable_diffusion.src.models.opt_params import get_unet
|
||||
return get_unet()
|
||||
else:
|
||||
return self.get_unet()
|
||||
return self.get_unet(use_large=use_large)
|
||||
else:
|
||||
return self.get_controlled_unet()
|
||||
|
||||
@@ -598,7 +633,7 @@ class SharkifyStableDiffusionModel:
|
||||
except Exception as e:
|
||||
sys.exit(e)
|
||||
|
||||
def unet(self):
|
||||
def unet(self, use_large=False):
|
||||
try:
|
||||
model = "stencil_unet" if self.use_stencil is not None else "unet"
|
||||
compiled_unet = None
|
||||
@@ -606,14 +641,14 @@ class SharkifyStableDiffusionModel:
|
||||
|
||||
if self.base_model_id != "":
|
||||
self.inputs["unet"] = self.get_input_info_for(unet_inputs[self.base_model_id])
|
||||
compiled_unet, unet_mlir = self.compile_unet_variants(model)
|
||||
compiled_unet, unet_mlir = self.compile_unet_variants(model, use_large=use_large)
|
||||
else:
|
||||
for model_id in unet_inputs:
|
||||
self.base_model_id = model_id
|
||||
self.inputs["unet"] = self.get_input_info_for(unet_inputs[model_id])
|
||||
|
||||
try:
|
||||
compiled_unet, unet_mlir = self.compile_unet_variants(model)
|
||||
compiled_unet, unet_mlir = self.compile_unet_variants(model, use_large=use_large)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("Retrying with a different base model configuration")
|
||||
|
||||
@@ -81,6 +81,7 @@ class Text2ImagePipeline(StableDiffusionPipeline):
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
@@ -112,7 +113,10 @@ class Text2ImagePipeline(StableDiffusionPipeline):
|
||||
|
||||
# Get text embeddings with weight emphasis from prompts
|
||||
text_embeddings = self.encode_prompts_weight(
|
||||
prompts, neg_prompts, max_length
|
||||
prompts,
|
||||
neg_prompts,
|
||||
max_length,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
|
||||
@@ -20,6 +20,8 @@ from diffusers import (
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_IDLE,
|
||||
SD_STATE_CANCEL,
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
@@ -84,6 +86,7 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
self.low_res_scheduler = low_res_scheduler
|
||||
self.status = SD_STATE_IDLE
|
||||
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
accepts_eta = "eta" in set(
|
||||
@@ -164,6 +167,7 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
latent_history = [latents]
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
self.status = SD_STATE_IDLE
|
||||
self.load_unet()
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
@@ -210,6 +214,9 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
# )
|
||||
step_time_sum += step_time
|
||||
|
||||
if self.status == SD_STATE_CANCEL:
|
||||
break
|
||||
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
|
||||
@@ -57,6 +57,7 @@ class StableDiffusionPipeline:
|
||||
self.vae = None
|
||||
self.text_encoder = None
|
||||
self.unet = None
|
||||
self.unet_512 = None
|
||||
self.model_max_length = 77
|
||||
self.scheduler = scheduler
|
||||
# TODO: Implement using logging python utility.
|
||||
@@ -87,7 +88,8 @@ class StableDiffusionPipeline:
|
||||
else:
|
||||
try:
|
||||
self.text_encoder = get_clip()
|
||||
except:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
self.text_encoder = self.sd_model.clip()
|
||||
|
||||
@@ -104,7 +106,8 @@ class StableDiffusionPipeline:
|
||||
else:
|
||||
try:
|
||||
self.unet = get_unet()
|
||||
except:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
self.unet = self.sd_model.unet()
|
||||
|
||||
@@ -112,6 +115,24 @@ class StableDiffusionPipeline:
|
||||
del self.unet
|
||||
self.unet = None
|
||||
|
||||
def load_unet_512(self):
|
||||
if self.unet_512 is not None:
|
||||
return
|
||||
|
||||
if self.import_mlir or self.use_lora:
|
||||
self.unet_512 = self.sd_model.unet(use_large=True)
|
||||
else:
|
||||
try:
|
||||
self.unet_512 = get_unet(use_large=True)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
self.unet_512 = self.sd_model.unet(use_large=True)
|
||||
|
||||
def unload_unet_512(self):
|
||||
del self.unet_512
|
||||
self.unet_512 = None
|
||||
|
||||
def load_vae(self):
|
||||
if self.vae is not None:
|
||||
return
|
||||
@@ -121,7 +142,8 @@ class StableDiffusionPipeline:
|
||||
else:
|
||||
try:
|
||||
self.vae = get_vae()
|
||||
except:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
self.vae = self.sd_model.vae()
|
||||
|
||||
@@ -200,7 +222,10 @@ class StableDiffusionPipeline:
|
||||
latent_history = [latents]
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
self.load_unet()
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
self.load_unet()
|
||||
else:
|
||||
self.load_unet_512()
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
timestep = torch.tensor([t]).to(dtype).detach().numpy()
|
||||
@@ -219,16 +244,28 @@ class StableDiffusionPipeline:
|
||||
|
||||
# Profiling Unet.
|
||||
profile_device = start_profiling(file_path="unet.rdc")
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
else:
|
||||
noise_pred = self.unet_512(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
end_profiling(profile_device)
|
||||
|
||||
if cpu_scheduling:
|
||||
@@ -251,6 +288,7 @@ class StableDiffusionPipeline:
|
||||
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
self.unload_unet_512()
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
@@ -409,6 +447,11 @@ class StableDiffusionPipeline:
|
||||
# uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
if text_embeddings.shape[1] > model_max_length:
|
||||
pad = (0, 0) * (len(text_embeddings.shape) - 2)
|
||||
pad = pad + (0, 512 - text_embeddings.shape[1])
|
||||
text_embeddings = torch.nn.functional.pad(text_embeddings, pad)
|
||||
|
||||
# SHARK: Report clip inference time
|
||||
clip_inf_time = (time.time() - clip_inf_start) * 1000
|
||||
if self.ondemand:
|
||||
|
||||
@@ -40,6 +40,7 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
def compile(self):
|
||||
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
|
||||
BATCH_SIZE = args.batch_size
|
||||
device = args.device.split(":", 1)[0].strip()
|
||||
|
||||
model_input = {
|
||||
"euler": {
|
||||
@@ -92,7 +93,7 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
self.scaling_model, _ = compile_through_fx(
|
||||
model=scaling_model,
|
||||
inputs=(example_latent, example_sigma),
|
||||
extended_model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}"
|
||||
extended_model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
@@ -101,7 +102,7 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
self.step_model, _ = compile_through_fx(
|
||||
step_model,
|
||||
(example_output, example_sigma, example_latent, example_dt),
|
||||
extended_model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}"
|
||||
extended_model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
|
||||
@@ -24,14 +24,18 @@ from apps.stable_diffusion.src.utils.utils import (
|
||||
get_available_devices,
|
||||
get_opt_flags,
|
||||
preprocessCKPT,
|
||||
convert_original_vae,
|
||||
fetch_and_update_base_model_id,
|
||||
get_path_to_diffusers_checkpoint,
|
||||
sanitize_seed,
|
||||
get_path_stem,
|
||||
get_extended_name,
|
||||
get_generated_imgs_path,
|
||||
get_generated_imgs_todays_subdir,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
get_generation_text_info,
|
||||
update_lora_weight,
|
||||
resize_stencil,
|
||||
_compile_module,
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
"stablediffusion/untuned":"gs://shark_tank/nightly"
|
||||
},
|
||||
{
|
||||
"stablediffusion/v1_4/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
|
||||
"stablediffusion/v1_4/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
|
||||
"stablediffusion/v1_4/vae/fp16/length_64/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
|
||||
"stablediffusion/v1_4/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
|
||||
|
||||
@@ -45,12 +45,12 @@
|
||||
"untuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,6 +70,8 @@ def load_winograd_configs():
|
||||
config_bucket = "gs://shark_tank/sd_tuned/configs/"
|
||||
config_name = f"{args.annotation_model}_winograd_{device}.json"
|
||||
full_gs_url = config_bucket + config_name
|
||||
if not os.path.exists(WORKDIR):
|
||||
os.mkdir(WORKDIR)
|
||||
winograd_config_dir = os.path.join(WORKDIR, "configs", config_name)
|
||||
print("Loading Winograd config file from ", winograd_config_dir)
|
||||
download_public_file(full_gs_url, winograd_config_dir, True)
|
||||
@@ -114,7 +116,7 @@ def load_lower_configs(base_model_id=None):
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
|
||||
else:
|
||||
if not spec or spec in ["rdna3", "sm_80"]:
|
||||
if not spec or spec in ["sm_80"]:
|
||||
if (
|
||||
version in ["v2_1", "v2_1base"]
|
||||
and args.height == 768
|
||||
@@ -123,6 +125,15 @@ def load_lower_configs(base_model_id=None):
|
||||
config_name = f"{args.annotation_model}_v2_1_768_{args.precision}_{device}.json"
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
|
||||
elif spec in ["rdna3"] and version in [
|
||||
"v2_1",
|
||||
"v2_1base",
|
||||
"v1_4",
|
||||
"v1_5",
|
||||
]:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.max_length}_{args.precision}_{device}_{spec}_{args.width}x{args.height}.json"
|
||||
elif spec in ["rdna2"] and version in ["v2_1", "v2_1base", "v1_4"]:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}_{args.width}x{args.height}.json"
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
|
||||
|
||||
|
||||
@@ -108,6 +108,13 @@ p.add_argument(
|
||||
help="max length of the tokenizer output, options are 64 and 77.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--max_embeddings_multiples",
|
||||
type=int,
|
||||
default=5,
|
||||
help="The max multiple length of prompt embeddings compared to the max output length of text encoder.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--strength",
|
||||
type=float,
|
||||
@@ -493,7 +500,13 @@ p.add_argument(
|
||||
default="",
|
||||
help="Path to directory where all .ckpts are stored in order to populate them in the web UI",
|
||||
)
|
||||
|
||||
# TODO: replace API flag when these can be run together
|
||||
p.add_argument(
|
||||
"--ui",
|
||||
type=str,
|
||||
default="app" if os.name == "nt" else "web",
|
||||
help="one of: [api, app, web]",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--share",
|
||||
@@ -515,6 +528,22 @@ p.add_argument(
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for enabling rest API",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_gallery",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for removing the output gallery tab, and avoid exposing images under --output_dir in the UI",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_gallery_followlinks",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for whether the output gallery tab in the UI should follow symlinks when listing subdirectorys under --output_dir",
|
||||
)
|
||||
|
||||
|
||||
##############################################################################
|
||||
### SD model auto-annotation flags
|
||||
##############################################################################
|
||||
|
||||
@@ -126,14 +126,14 @@ def controlnet_hint_conversion(
|
||||
|
||||
|
||||
stencil_to_model_id_map = {
|
||||
"canny": "lllyasviel/sd-controlnet-canny",
|
||||
"depth": "lllyasviel/sd-controlnet-depth",
|
||||
"canny": "lllyasviel/control_v11p_sd15_canny",
|
||||
"depth": "lllyasviel/control_v11p_sd15_depth",
|
||||
"hed": "lllyasviel/sd-controlnet-hed",
|
||||
"mlsd": "lllyasviel/sd-controlnet-mlsd",
|
||||
"normal": "lllyasviel/sd-controlnet-normal",
|
||||
"openpose": "lllyasviel/sd-controlnet-openpose",
|
||||
"scribble": "lllyasviel/sd-controlnet-scribble",
|
||||
"seg": "lllyasviel/sd-controlnet-seg",
|
||||
"mlsd": "lllyasviel/control_v11p_sd15_mlsd",
|
||||
"normal": "lllyasviel/control_v11p_sd15_normalbae",
|
||||
"openpose": "lllyasviel/control_v11p_sd15_openpose",
|
||||
"scribble": "lllyasviel/control_v11p_sd15_scribble",
|
||||
"seg": "lllyasviel/control_v11p_sd15_seg",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -25,7 +25,12 @@ from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
|
||||
import sys
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
download_from_original_stable_diffusion_ckpt,
|
||||
create_vae_diffusers_config,
|
||||
convert_ldm_vae_checkpoint,
|
||||
)
|
||||
import requests
|
||||
from io import BytesIO
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
def get_extended_name(model_name):
|
||||
@@ -42,6 +47,7 @@ def get_vmfb_path_name(model_name):
|
||||
def _load_vmfb(shark_module, vmfb_path, model, precision):
|
||||
model = "vae" if "base_vae" in model or "vae_encode" in model else model
|
||||
model = "unet" if "stencil" in model else model
|
||||
model = "unet" if "unet512" in model else model
|
||||
precision = "fp32" if "clip" in model else precision
|
||||
extra_args = get_opt_flags(model, precision)
|
||||
shark_module.load_module(vmfb_path, extra_args=extra_args)
|
||||
@@ -78,7 +84,6 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
|
||||
|
||||
# Set local shark_tank cache directory.
|
||||
shark_args.local_tank_cache = args.local_tank_cache
|
||||
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
if "cuda" in args.device:
|
||||
@@ -111,6 +116,7 @@ def compile_through_fx(
|
||||
model_name=None,
|
||||
precision=None,
|
||||
return_mlir=False,
|
||||
device=None,
|
||||
):
|
||||
if not return_mlir and model_name is not None:
|
||||
vmfb_path = get_vmfb_path_name(extended_model_name)
|
||||
@@ -141,7 +147,10 @@ def compile_through_fx(
|
||||
if use_tuned:
|
||||
if "vae" in extended_model_name.split("_")[0]:
|
||||
args.annotation_model = "vae"
|
||||
if "unet" in model_name.split("_")[0]:
|
||||
if (
|
||||
"unet" in model_name.split("_")[0]
|
||||
or "unet_512" in model_name.split("_")[0]
|
||||
):
|
||||
args.annotation_model = "unet"
|
||||
mlir_module = sd_model_annotation(
|
||||
mlir_module, extended_model_name, base_model_id
|
||||
@@ -149,7 +158,7 @@ def compile_through_fx(
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device=args.device,
|
||||
device=args.device if device is None else device,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
if generate_vmfb:
|
||||
@@ -289,13 +298,18 @@ def set_init_device_flags():
|
||||
if (
|
||||
args.precision != "fp16"
|
||||
or args.height not in [512, 768]
|
||||
or (args.height == 512 and args.width != 512)
|
||||
or (args.height == 768 and args.width != 768)
|
||||
or (args.height == 512 and args.width not in [512, 768])
|
||||
or (args.height == 768 and args.width not in [512, 768])
|
||||
or args.batch_size != 1
|
||||
or ("vulkan" not in args.device and "cuda" not in args.device)
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
elif (
|
||||
args.height != args.width and "rdna2" in args.iree_vulkan_target_triple
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
elif base_model_id not in [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"dreamlike-art/dreamlike-diffusion-1.0",
|
||||
@@ -333,13 +347,25 @@ def set_init_device_flags():
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
]
|
||||
or "rdna3" not in args.iree_vulkan_target_triple
|
||||
or "rdna" not in args.iree_vulkan_target_triple
|
||||
)
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
elif "rdna2" in args.iree_vulkan_target_triple and (
|
||||
base_model_id
|
||||
not in [
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
]
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
if args.use_tuned:
|
||||
print(f"Using tuned models for {base_model_id}/fp16/{args.device}.")
|
||||
print(
|
||||
f"Using tuned models for {base_model_id}(fp16) on device {args.device}."
|
||||
)
|
||||
else:
|
||||
print("Tuned models are currently not supported for this setting.")
|
||||
|
||||
@@ -464,7 +490,7 @@ def get_path_stem(path):
|
||||
def get_path_to_diffusers_checkpoint(custom_weights):
|
||||
path = Path(custom_weights)
|
||||
diffusers_path = path.parent.absolute()
|
||||
diffusers_directory_name = path.stem
|
||||
diffusers_directory_name = os.path.join("diffusers", path.stem)
|
||||
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
|
||||
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
|
||||
path_to_diffusers = complete_path_to_diffusers.as_posix()
|
||||
@@ -503,6 +529,22 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
|
||||
print("Loading complete")
|
||||
|
||||
|
||||
def convert_original_vae(vae_checkpoint):
|
||||
vae_state_dict = {}
|
||||
for key in list(vae_checkpoint.keys()):
|
||||
vae_state_dict["first_stage_model." + key] = vae_checkpoint.get(key)
|
||||
|
||||
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
original_config_file = BytesIO(requests.get(config_url).content)
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=512)
|
||||
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
|
||||
vae_state_dict, vae_config
|
||||
)
|
||||
return converted_vae_checkpoint
|
||||
|
||||
|
||||
def processLoRA(model, use_lora, splitting_prefix):
|
||||
state_dict = ""
|
||||
if ".safetensors" in use_lora:
|
||||
@@ -673,11 +715,20 @@ def clear_all():
|
||||
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
|
||||
|
||||
|
||||
def get_generated_imgs_path() -> Path:
|
||||
return Path(
|
||||
args.output_dir if args.output_dir else Path.cwd(), "generated_imgs"
|
||||
)
|
||||
|
||||
|
||||
def get_generated_imgs_todays_subdir() -> str:
|
||||
return dt.now().strftime("%Y%m%d")
|
||||
|
||||
|
||||
# save output images and the inputs corresponding to it.
|
||||
def save_output_img(output_img, img_seed, extra_info={}):
|
||||
output_path = args.output_dir if args.output_dir else Path.cwd()
|
||||
generated_imgs_path = Path(
|
||||
output_path, "generated_imgs", dt.now().strftime("%Y%m%d")
|
||||
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
|
||||
)
|
||||
generated_imgs_path.mkdir(parents=True, exist_ok=True)
|
||||
csv_path = Path(generated_imgs_path, "imgs_details.csv")
|
||||
|
||||
@@ -1,18 +1,45 @@
|
||||
from multiprocessing import Process, freeze_support
|
||||
import os
|
||||
import sys
|
||||
import transformers
|
||||
import shutil
|
||||
import PIL, transformers # ensures inclusion in pysintaller exe generation
|
||||
from apps.stable_diffusion.src import args, clear_all
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
|
||||
if sys.platform == "darwin":
|
||||
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
|
||||
# import before IREE to avoid MLIR library issues
|
||||
import torch_mlir
|
||||
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
|
||||
def launch_app(address):
|
||||
from tkinter import Tk
|
||||
import webview
|
||||
|
||||
window = Tk()
|
||||
|
||||
# getting screen width and height of display
|
||||
width = window.winfo_screenwidth()
|
||||
height = window.winfo_screenheight()
|
||||
webview.create_window(
|
||||
"SHARK AI Studio", url=address, width=width, height=height
|
||||
)
|
||||
webview.start(private_mode=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.api:
|
||||
from apps.stable_diffusion.web.ui import txt2img_inf, img2img_api
|
||||
# required to do multiprocessing in a pyinstaller freeze
|
||||
freeze_support()
|
||||
if args.api or "api" in args.ui.split(","):
|
||||
from apps.stable_diffusion.web.ui import (
|
||||
txt2img_api,
|
||||
img2img_api,
|
||||
upscaler_api,
|
||||
inpaint_api,
|
||||
)
|
||||
from fastapi import FastAPI, APIRouter
|
||||
import uvicorn
|
||||
|
||||
@@ -20,24 +47,31 @@ if __name__ == "__main__":
|
||||
global_obj._init()
|
||||
|
||||
app = FastAPI()
|
||||
app.add_api_route("/sdapi/v1/txt2img", txt2img_inf, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
|
||||
# app.add_api_route(
|
||||
# "/sdapi/v1/outpaint", outpaint_api, methods=["post"]
|
||||
# )
|
||||
app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
|
||||
app.include_router(APIRouter())
|
||||
uvicorn.run(app, host="127.0.0.1", port=args.server_port)
|
||||
sys.exit(0)
|
||||
|
||||
import gradio as gr
|
||||
# Setup to use shark_tmp for gradio's temporary image files and clear any
|
||||
# existing temporary images there if they exist. Then we can import gradio.
|
||||
# It has to be in this order or gradio ignores what we've set up.
|
||||
from apps.stable_diffusion.web.utils.gradio_configs import (
|
||||
clear_gradio_tmp_imgs_folder,
|
||||
config_gradio_tmp_imgs_folder,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import get_custom_model_path
|
||||
|
||||
# Clear all gradio tmp images from the last session
|
||||
clear_gradio_tmp_imgs_folder()
|
||||
# Create the custom model folder if it doesn't already exist
|
||||
dir = ["models", "vae", "lora"]
|
||||
for root in dir:
|
||||
get_custom_model_path(root).mkdir(parents=True, exist_ok=True)
|
||||
config_gradio_tmp_imgs_folder()
|
||||
import gradio as gr
|
||||
|
||||
# Create custom models folders if they don't exist
|
||||
from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
|
||||
|
||||
create_custom_models_folders()
|
||||
|
||||
def resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
@@ -50,36 +84,69 @@ if __name__ == "__main__":
|
||||
|
||||
from apps.stable_diffusion.web.ui import (
|
||||
txt2img_web,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
txt2img_gallery,
|
||||
txt2img_png_info_img,
|
||||
txt2img_status,
|
||||
txt2img_sendto_img2img,
|
||||
txt2img_sendto_inpaint,
|
||||
txt2img_sendto_outpaint,
|
||||
txt2img_sendto_upscaler,
|
||||
img2img_web,
|
||||
img2img_custom_model,
|
||||
img2img_hf_model_id,
|
||||
img2img_gallery,
|
||||
img2img_init_image,
|
||||
img2img_status,
|
||||
img2img_sendto_inpaint,
|
||||
img2img_sendto_outpaint,
|
||||
img2img_sendto_upscaler,
|
||||
inpaint_web,
|
||||
inpaint_custom_model,
|
||||
inpaint_hf_model_id,
|
||||
inpaint_gallery,
|
||||
inpaint_init_image,
|
||||
inpaint_status,
|
||||
inpaint_sendto_img2img,
|
||||
inpaint_sendto_outpaint,
|
||||
inpaint_sendto_upscaler,
|
||||
outpaint_web,
|
||||
outpaint_custom_model,
|
||||
outpaint_hf_model_id,
|
||||
outpaint_gallery,
|
||||
outpaint_init_image,
|
||||
outpaint_status,
|
||||
outpaint_sendto_img2img,
|
||||
outpaint_sendto_inpaint,
|
||||
outpaint_sendto_upscaler,
|
||||
upscaler_web,
|
||||
upscaler_custom_model,
|
||||
upscaler_hf_model_id,
|
||||
upscaler_gallery,
|
||||
upscaler_init_image,
|
||||
upscaler_status,
|
||||
upscaler_sendto_img2img,
|
||||
upscaler_sendto_inpaint,
|
||||
upscaler_sendto_outpaint,
|
||||
lora_train_web,
|
||||
model_web,
|
||||
hf_models,
|
||||
modelmanager_sendto_txt2img,
|
||||
modelmanager_sendto_img2img,
|
||||
modelmanager_sendto_inpaint,
|
||||
modelmanager_sendto_outpaint,
|
||||
modelmanager_sendto_upscaler,
|
||||
stablelm_chat,
|
||||
outputgallery_web,
|
||||
outputgallery_tab_select,
|
||||
outputgallery_watch,
|
||||
outputgallery_filename,
|
||||
outputgallery_sendto_txt2img,
|
||||
outputgallery_sendto_img2img,
|
||||
outputgallery_sendto_inpaint,
|
||||
outputgallery_sendto_outpaint,
|
||||
outputgallery_sendto_upscaler,
|
||||
)
|
||||
|
||||
# init global sd pipeline and config
|
||||
@@ -95,6 +162,27 @@ if __name__ == "__main__":
|
||||
outputs,
|
||||
)
|
||||
|
||||
def register_modelmanager_button(button, selectedid, inputs, outputs):
|
||||
button.click(
|
||||
lambda x: (
|
||||
"None",
|
||||
x,
|
||||
gr.Tabs.update(selected=selectedid),
|
||||
),
|
||||
inputs,
|
||||
outputs,
|
||||
)
|
||||
|
||||
def register_outputgallery_button(button, selectedid, inputs, outputs):
|
||||
button.click(
|
||||
lambda x: (
|
||||
x,
|
||||
gr.Tabs.update(selected=selectedid),
|
||||
),
|
||||
inputs,
|
||||
outputs,
|
||||
)
|
||||
|
||||
with gr.Blocks(
|
||||
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
|
||||
) as sd_web:
|
||||
@@ -109,11 +197,29 @@ if __name__ == "__main__":
|
||||
outpaint_web.render()
|
||||
with gr.TabItem(label="Upscaler", id=4):
|
||||
upscaler_web.render()
|
||||
|
||||
with gr.Tabs(visible=False) as experimental_tabs:
|
||||
with gr.TabItem(label="LoRA Training", id=5):
|
||||
with gr.TabItem(label="Model Manager", id=5):
|
||||
model_web.render()
|
||||
with gr.TabItem(label="Chat Bot(Experimental)", id=6):
|
||||
stablelm_chat.render()
|
||||
with gr.TabItem(label="LoRA Training(Experimental)", id=7):
|
||||
lora_train_web.render()
|
||||
if args.output_gallery:
|
||||
with gr.TabItem(label="Output Gallery", id=8) as og_tab:
|
||||
outputgallery_web.render()
|
||||
|
||||
# extra output gallery configuration
|
||||
outputgallery_tab_select(og_tab.select)
|
||||
outputgallery_watch(
|
||||
[
|
||||
txt2img_status,
|
||||
img2img_status,
|
||||
inpaint_status,
|
||||
outpaint_status,
|
||||
upscaler_status,
|
||||
]
|
||||
)
|
||||
|
||||
# send to buttons
|
||||
register_button_click(
|
||||
txt2img_sendto_img2img,
|
||||
1,
|
||||
@@ -210,10 +316,77 @@ if __name__ == "__main__":
|
||||
[upscaler_gallery],
|
||||
[outpaint_init_image, tabs],
|
||||
)
|
||||
if args.output_gallery:
|
||||
register_outputgallery_button(
|
||||
outputgallery_sendto_txt2img,
|
||||
0,
|
||||
[outputgallery_filename],
|
||||
[txt2img_png_info_img, tabs],
|
||||
)
|
||||
register_outputgallery_button(
|
||||
outputgallery_sendto_img2img,
|
||||
1,
|
||||
[outputgallery_filename],
|
||||
[img2img_init_image, tabs],
|
||||
)
|
||||
register_outputgallery_button(
|
||||
outputgallery_sendto_inpaint,
|
||||
2,
|
||||
[outputgallery_filename],
|
||||
[inpaint_init_image, tabs],
|
||||
)
|
||||
register_outputgallery_button(
|
||||
outputgallery_sendto_outpaint,
|
||||
3,
|
||||
[outputgallery_filename],
|
||||
[outpaint_init_image, tabs],
|
||||
)
|
||||
register_outputgallery_button(
|
||||
outputgallery_sendto_upscaler,
|
||||
4,
|
||||
[outputgallery_filename],
|
||||
[upscaler_init_image, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_txt2img,
|
||||
0,
|
||||
[hf_models],
|
||||
[txt2img_custom_model, txt2img_hf_model_id, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_img2img,
|
||||
1,
|
||||
[hf_models],
|
||||
[img2img_custom_model, img2img_hf_model_id, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_inpaint,
|
||||
2,
|
||||
[hf_models],
|
||||
[inpaint_custom_model, inpaint_hf_model_id, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_outpaint,
|
||||
3,
|
||||
[hf_models],
|
||||
[outpaint_custom_model, outpaint_hf_model_id, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_upscaler,
|
||||
4,
|
||||
[hf_models],
|
||||
[upscaler_custom_model, upscaler_hf_model_id, tabs],
|
||||
)
|
||||
|
||||
sd_web.queue()
|
||||
if args.ui == "app":
|
||||
t = Process(
|
||||
target=launch_app, args=[f"http://localhost:{args.server_port}"]
|
||||
)
|
||||
t.start()
|
||||
sd_web.launch(
|
||||
share=args.share,
|
||||
inbrowser=True,
|
||||
inbrowser=args.ui == "web",
|
||||
server_name="0.0.0.0",
|
||||
server_port=args.server_port,
|
||||
)
|
||||
|
||||
@@ -1,44 +1,88 @@
|
||||
from apps.stable_diffusion.web.ui.txt2img_ui import (
|
||||
txt2img_inf,
|
||||
txt2img_api,
|
||||
txt2img_web,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
txt2img_gallery,
|
||||
txt2img_png_info_img,
|
||||
txt2img_status,
|
||||
txt2img_sendto_img2img,
|
||||
txt2img_sendto_inpaint,
|
||||
txt2img_sendto_outpaint,
|
||||
txt2img_sendto_upscaler,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.img2img_ui import (
|
||||
img2img_api,
|
||||
img2img_inf,
|
||||
img2img_api,
|
||||
img2img_web,
|
||||
img2img_custom_model,
|
||||
img2img_hf_model_id,
|
||||
img2img_gallery,
|
||||
img2img_init_image,
|
||||
img2img_status,
|
||||
img2img_sendto_inpaint,
|
||||
img2img_sendto_outpaint,
|
||||
img2img_sendto_upscaler,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.inpaint_ui import (
|
||||
inpaint_inf,
|
||||
inpaint_api,
|
||||
inpaint_web,
|
||||
inpaint_custom_model,
|
||||
inpaint_hf_model_id,
|
||||
inpaint_gallery,
|
||||
inpaint_init_image,
|
||||
inpaint_status,
|
||||
inpaint_sendto_img2img,
|
||||
inpaint_sendto_outpaint,
|
||||
inpaint_sendto_upscaler,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.outpaint_ui import (
|
||||
outpaint_inf,
|
||||
outpaint_api,
|
||||
outpaint_web,
|
||||
outpaint_custom_model,
|
||||
outpaint_hf_model_id,
|
||||
outpaint_gallery,
|
||||
outpaint_init_image,
|
||||
outpaint_status,
|
||||
outpaint_sendto_img2img,
|
||||
outpaint_sendto_inpaint,
|
||||
outpaint_sendto_upscaler,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.upscaler_ui import (
|
||||
upscaler_inf,
|
||||
upscaler_api,
|
||||
upscaler_web,
|
||||
upscaler_custom_model,
|
||||
upscaler_hf_model_id,
|
||||
upscaler_gallery,
|
||||
upscaler_init_image,
|
||||
upscaler_status,
|
||||
upscaler_sendto_img2img,
|
||||
upscaler_sendto_inpaint,
|
||||
upscaler_sendto_outpaint,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.model_manager import (
|
||||
model_web,
|
||||
hf_models,
|
||||
modelmanager_sendto_txt2img,
|
||||
modelmanager_sendto_img2img,
|
||||
modelmanager_sendto_inpaint,
|
||||
modelmanager_sendto_outpaint,
|
||||
modelmanager_sendto_upscaler,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.lora_train_ui import lora_train_web
|
||||
from apps.stable_diffusion.web.ui.stablelm_ui import stablelm_chat
|
||||
from apps.stable_diffusion.web.ui.outputgallery_ui import (
|
||||
outputgallery_web,
|
||||
outputgallery_tab_select,
|
||||
outputgallery_watch,
|
||||
outputgallery_filename,
|
||||
outputgallery_sendto_txt2img,
|
||||
outputgallery_sendto_img2img,
|
||||
outputgallery_sendto_inpaint,
|
||||
outputgallery_sendto_outpaint,
|
||||
outputgallery_sendto_upscaler,
|
||||
)
|
||||
|
||||
@@ -173,7 +173,30 @@ footer {
|
||||
#gallery .thumbnail-item.thumbnail-lg {
|
||||
aspect-ratio: unset;
|
||||
max-height: calc(55vh - (2 * var(--spacing-lg)));
|
||||
min-height: 390px
|
||||
}
|
||||
@media (min-width: 1921px) {
|
||||
/* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */
|
||||
#gallery .grid-wrap, #gallery .preview{
|
||||
min-height: calc(768px + 4px + var(--size-14));
|
||||
max-height: calc(768px + 4px + var(--size-14));
|
||||
}
|
||||
/* Limit height to 768px_height + 2px_margin_height for the thumbnails */
|
||||
#gallery .thumbnail-item.thumbnail-lg {
|
||||
max-height: 770px !important;
|
||||
}
|
||||
}
|
||||
/* Don't upscale when viewing in solo image mode */
|
||||
#gallery .preview img {
|
||||
object-fit: scale-down;
|
||||
}
|
||||
/* Navbar images in cover mode*/
|
||||
#gallery .preview .thumbnail-item img {
|
||||
object-fit: cover;
|
||||
}
|
||||
|
||||
/* Limit the stable diffusion text output height */
|
||||
#std_output textarea {
|
||||
max-height: 215px;
|
||||
}
|
||||
|
||||
/* Prevent progress bar to block gallery navigation while building images (Gradio V3.19.0) */
|
||||
@@ -207,3 +230,44 @@ footer {
|
||||
#top_logo .download {
|
||||
display: none;
|
||||
}
|
||||
|
||||
/* output gallery tab */
|
||||
.output_parameters_dataframe tbody td {
|
||||
font-size: small;
|
||||
line-height: var(--line-xs)
|
||||
}
|
||||
|
||||
#output_refresh_button {
|
||||
max-width: 30px;
|
||||
align-self: end;
|
||||
padding-bottom: 8px;
|
||||
}
|
||||
|
||||
.outputgallery_sendto {
|
||||
min-width: 7em !important;
|
||||
}
|
||||
|
||||
/* output gallery should take up most of the viewport height regardless of image size/number */
|
||||
#outputgallery_gallery .fixed-height {
|
||||
min-height: 89vh !important;
|
||||
}
|
||||
|
||||
/* don't stretch non-square images to be square, breaking their aspect ratio */
|
||||
#outputgallery_gallery .thumbnail-item.thumbnail-lg > img {
|
||||
object-fit: contain !important;
|
||||
}
|
||||
|
||||
/* centered logo for when there are no images */
|
||||
#top_logo.logo_centered {
|
||||
height: 100%;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
#top_logo.logo_centered img{
|
||||
object-fit: scale-down;
|
||||
position: absolute;
|
||||
width: 80%;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
}
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
import PIL
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from fastapi.exceptions import HTTPException
|
||||
from apps.stable_diffusion.src import args
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -26,10 +24,14 @@ from apps.stable_diffusion.src import (
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generation_text_info,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
import numpy as np
|
||||
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
@@ -42,7 +44,7 @@ init_import_mlir = args.import_mlir
|
||||
def img2img_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
init_image,
|
||||
image_dict,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
@@ -85,25 +87,35 @@ def img2img_inf(
|
||||
args.img_path = "not none"
|
||||
args.ondemand = ondemand
|
||||
|
||||
if init_image is None:
|
||||
if image_dict is None:
|
||||
return None, "An Initial Image is required"
|
||||
image = init_image.convert("RGB")
|
||||
if use_stencil == "scribble":
|
||||
image = image_dict["mask"].convert("RGB")
|
||||
elif isinstance(image_dict, PIL.Image.Image):
|
||||
image = image_dict.convert("RGB")
|
||||
else:
|
||||
image = image_dict["image"].convert("RGB")
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
args.custom_vae = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
args.custom_vae = custom_vae
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
@@ -130,6 +142,7 @@ def img2img_inf(
|
||||
"img2img",
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
precision,
|
||||
batch_size,
|
||||
max_length,
|
||||
@@ -247,11 +260,17 @@ def img2img_inf(
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed, extra_info)
|
||||
save_output_img(
|
||||
out_imgs[0],
|
||||
img_seed,
|
||||
extra_info,
|
||||
)
|
||||
generated_imgs.extend(out_imgs)
|
||||
# yield generated_imgs, text_output
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Image-to-Image", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
return generated_imgs, text_output
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
@@ -287,7 +306,9 @@ def encode_pil_to_base64(images):
|
||||
def img2img_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(InputData)
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
|
||||
)
|
||||
init_image = decode_base64_to_image(InputData["init_images"][0])
|
||||
res = img2img_inf(
|
||||
InputData["prompt"],
|
||||
@@ -303,17 +324,26 @@ def img2img_api(
|
||||
batch_size=1,
|
||||
scheduler="EulerDiscrete",
|
||||
custom_model="None",
|
||||
hf_model_id="stabilityai/stable-diffusion-2-1-base",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
max_length=64,
|
||||
use_stencil="None",
|
||||
use_stencil=InputData["use_stencil"]
|
||||
if "use_stencil" in InputData.keys()
|
||||
else "None",
|
||||
save_metadata_to_json=False,
|
||||
save_metadata_to_png=False,
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
)
|
||||
|
||||
# Converts generator type to subscriptable
|
||||
res = list(res)[0]
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
@@ -336,32 +366,30 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
custom_model = gr.Dropdown(
|
||||
img2img_custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "None",
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
img2img_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
label="HuggingFace Model ID or Civitai model download URL",
|
||||
lines=3,
|
||||
)
|
||||
custom_vae = gr.Dropdown(
|
||||
label=f"Custom Vae Models",
|
||||
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.custom_vae)
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
@@ -379,7 +407,10 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
)
|
||||
|
||||
img2img_init_image = gr.Image(
|
||||
label="Input Image", type="pil"
|
||||
label="Input Image",
|
||||
source="upload",
|
||||
tool="sketch",
|
||||
type="pil",
|
||||
).style(height=300)
|
||||
|
||||
with gr.Accordion(label="Stencil Options", open=False):
|
||||
@@ -390,6 +421,57 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
value="None",
|
||||
choices=["None", "canny", "openpose", "scribble"],
|
||||
)
|
||||
|
||||
def show_canvas(choice):
|
||||
if choice == "scribble":
|
||||
return (
|
||||
gr.Slider.update(visible=True),
|
||||
gr.Slider.update(visible=True),
|
||||
gr.Button.update(visible=True),
|
||||
)
|
||||
else:
|
||||
return (
|
||||
gr.Slider.update(visible=False),
|
||||
gr.Slider.update(visible=False),
|
||||
gr.Button.update(visible=False),
|
||||
)
|
||||
|
||||
def create_canvas(w, h):
|
||||
return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
|
||||
|
||||
with gr.Row():
|
||||
canvas_width = gr.Slider(
|
||||
label="Canvas Width",
|
||||
minimum=256,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=1,
|
||||
visible=False,
|
||||
)
|
||||
canvas_height = gr.Slider(
|
||||
label="Canvas Height",
|
||||
minimum=256,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=1,
|
||||
visible=False,
|
||||
)
|
||||
create_button = gr.Button(
|
||||
label="Start",
|
||||
value="Open drawing canvas!",
|
||||
visible=False,
|
||||
)
|
||||
create_button.click(
|
||||
fn=create_canvas,
|
||||
inputs=[canvas_width, canvas_height],
|
||||
outputs=[img2img_init_image],
|
||||
)
|
||||
use_stencil.change(
|
||||
fn=show_canvas,
|
||||
inputs=use_stencil,
|
||||
outputs=[canvas_width, canvas_height, create_button],
|
||||
)
|
||||
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
lora_weights = gr.Dropdown(
|
||||
@@ -507,10 +589,10 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => -1",
|
||||
queue=False,
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
@@ -523,17 +605,12 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
std_output = gr.Textbox(
|
||||
value="Nothing to show.",
|
||||
value=f"Images will be saved at {get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
output_dir = args.output_dir if args.output_dir else Path.cwd()
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
output_loc = gr.Textbox(
|
||||
label="Saving Images at",
|
||||
value=output_dir,
|
||||
interactive=False,
|
||||
)
|
||||
img2img_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
img2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
img2img_sendto_outpaint = gr.Button(
|
||||
@@ -558,8 +635,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
img2img_custom_model,
|
||||
img2img_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
@@ -571,13 +648,21 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
],
|
||||
outputs=[img2img_gallery, std_output],
|
||||
outputs=[img2img_gallery, std_output, img2img_status],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Image-to-Image", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=img2img_status,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.scripts import inpaint_inf
|
||||
from apps.stable_diffusion.src import args
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from fastapi.exceptions import HTTPException
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -13,6 +16,284 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
predefined_paint_models,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
InpaintPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generation_text_info,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_use_tuned = args.use_tuned
|
||||
init_import_mlir = args.import_mlir
|
||||
|
||||
|
||||
# Exposed to UI.
|
||||
def inpaint_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
image_dict,
|
||||
height: int,
|
||||
width: int,
|
||||
inpaint_full_res: bool,
|
||||
inpaint_full_res_padding: int,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
args.steps = steps
|
||||
args.scheduler = scheduler
|
||||
args.img_path = "not none"
|
||||
args.mask_path = "not none"
|
||||
args.ondemand = ondemand
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
args.custom_vae = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
new_config_obj = Config(
|
||||
"inpaint",
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
precision,
|
||||
batch_size,
|
||||
max_length,
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
not global_obj.get_sd_obj()
|
||||
or global_obj.get_cfg_obj() != new_config_obj
|
||||
):
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.precision = precision
|
||||
args.batch_count = batch_count
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
args.width = width
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
|
||||
args.use_tuned = init_use_tuned
|
||||
args.import_mlir = init_import_mlir
|
||||
set_init_device_flags()
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-inpainting"
|
||||
)
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||
global_obj.set_sd_obj(
|
||||
InpaintPipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
model_id=args.hf_model_id,
|
||||
ckpt_loc=args.ckpt_loc,
|
||||
custom_vae=args.custom_vae,
|
||||
precision=args.precision,
|
||||
max_length=args.max_length,
|
||||
batch_size=args.batch_size,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
use_tuned=args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
|
||||
global_obj.set_sd_scheduler(scheduler)
|
||||
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
image = image_dict["image"]
|
||||
mask_image = image_dict["mask"]
|
||||
text_output = ""
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image,
|
||||
mask_image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
inpaint_full_res,
|
||||
inpaint_full_res_padding,
|
||||
steps,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
seeds.append(img_seed)
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Inpaint", i + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
|
||||
try:
|
||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
return image
|
||||
except Exception as err:
|
||||
print(err)
|
||||
raise HTTPException(status_code=500, detail="Invalid encoded image")
|
||||
|
||||
|
||||
def encode_pil_to_base64(images):
|
||||
encoded_imgs = []
|
||||
for image in images:
|
||||
with BytesIO() as output_bytes:
|
||||
if args.output_img_format.lower() == "png":
|
||||
image.save(output_bytes, format="PNG")
|
||||
|
||||
elif args.output_img_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Invalid image format"
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
encoded_imgs.append(base64.b64encode(bytes_data))
|
||||
return encoded_imgs
|
||||
|
||||
|
||||
# Inpaint Rest API.
|
||||
def inpaint_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
|
||||
)
|
||||
init_image = decode_base64_to_image(InputData["image"])
|
||||
mask = decode_base64_to_image(InputData["mask"])
|
||||
res = inpaint_inf(
|
||||
InputData["prompt"],
|
||||
InputData["negative_prompt"],
|
||||
{"image": init_image, "mask": mask},
|
||||
InputData["height"],
|
||||
InputData["width"],
|
||||
InputData["is_full_res"],
|
||||
InputData["full_res_padding"],
|
||||
InputData["steps"],
|
||||
InputData["cfg_scale"],
|
||||
InputData["seed"],
|
||||
batch_count=1,
|
||||
batch_size=1,
|
||||
scheduler="EulerDiscrete",
|
||||
custom_model="None",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
max_length=64,
|
||||
save_metadata_to_json=False,
|
||||
save_metadata_to_png=False,
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
)
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
"info": res[1],
|
||||
}
|
||||
|
||||
|
||||
with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
@@ -30,32 +311,32 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
custom_model = gr.Dropdown(
|
||||
inpaint_custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "None",
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ get_custom_model_files(
|
||||
custom_checkpoint_type="inpainting"
|
||||
)
|
||||
+ predefined_paint_models,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
inpaint_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
label="HuggingFace Model ID or Civitai model download URL",
|
||||
lines=3,
|
||||
)
|
||||
custom_vae = gr.Dropdown(
|
||||
label=f"Custom Vae Models",
|
||||
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.custom_vae)
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_paint_models,
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
@@ -203,10 +484,10 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => -1",
|
||||
queue=False,
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
@@ -219,17 +500,13 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
std_output = gr.Textbox(
|
||||
value="Nothing to show.",
|
||||
value=f"Images will be saved at {get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
output_dir = args.output_dir if args.output_dir else Path.cwd()
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
output_loc = gr.Textbox(
|
||||
label="Saving Images at",
|
||||
value=output_dir,
|
||||
interactive=False,
|
||||
)
|
||||
inpaint_status = gr.Textbox(visible=False)
|
||||
|
||||
with gr.Row():
|
||||
inpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
inpaint_sendto_outpaint = gr.Button(
|
||||
@@ -255,8 +532,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
inpaint_custom_model,
|
||||
inpaint_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
@@ -267,13 +544,20 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
],
|
||||
outputs=[inpaint_gallery, std_output],
|
||||
outputs=[inpaint_gallery, std_output, inpaint_status],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Inpaint", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=inpaint_status,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
|
||||
@@ -159,10 +159,10 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => -1",
|
||||
queue=False,
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
train_lora = gr.Button("Train LoRA")
|
||||
|
||||
157
apps/stable_diffusion/web/ui/model_manager.py
Normal file
157
apps/stable_diffusion/web/ui/model_manager.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import os
|
||||
import gradio as gr
|
||||
import requests
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_hf_list(num_of_models=20):
|
||||
path = "https://huggingface.co/api/models"
|
||||
params = {
|
||||
"search": "stable-diffusion",
|
||||
"sort": "downloads",
|
||||
"direction": "-1",
|
||||
"limit": {num_of_models},
|
||||
"full": "true",
|
||||
}
|
||||
response = requests.get(path, params=params)
|
||||
return response.json()
|
||||
|
||||
|
||||
def get_civit_list(num_of_models=50):
|
||||
path = f"https://civitai.com/api/v1/models?limit={num_of_models}&types=Checkpoint"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
raw_json = requests.get(path, headers=headers).json()
|
||||
models = list(raw_json.items())[0][1]
|
||||
safe_models = [
|
||||
safe_model for safe_model in models if not safe_model["nsfw"]
|
||||
]
|
||||
version_id = 0 # Currently just using the first version.
|
||||
safe_models = [
|
||||
safe_model
|
||||
for safe_model in safe_models
|
||||
if safe_model["modelVersions"][version_id]["files"][0]["metadata"][
|
||||
"format"
|
||||
]
|
||||
== "SafeTensor"
|
||||
]
|
||||
first_version_models = []
|
||||
for model_iter in safe_models:
|
||||
# The modelVersion would only keep the version name.
|
||||
if (
|
||||
model_iter["modelVersions"][version_id]["images"][0]["nsfw"]
|
||||
!= "None"
|
||||
):
|
||||
continue
|
||||
model_iter["modelVersions"][version_id]["modelName"] = model_iter[
|
||||
"name"
|
||||
]
|
||||
model_iter["modelVersions"][version_id]["rating"] = model_iter[
|
||||
"stats"
|
||||
]["rating"]
|
||||
model_iter["modelVersions"][version_id]["favoriteCount"] = model_iter[
|
||||
"stats"
|
||||
]["favoriteCount"]
|
||||
model_iter["modelVersions"][version_id]["downloadCount"] = model_iter[
|
||||
"stats"
|
||||
]["downloadCount"]
|
||||
first_version_models.append(model_iter["modelVersions"][version_id])
|
||||
return first_version_models
|
||||
|
||||
|
||||
def get_image_from_model(model_json):
|
||||
model_id = model_json["modelId"]
|
||||
image = None
|
||||
for img_info in model_json["images"]:
|
||||
if img_info["nsfw"] == "None":
|
||||
image_url = model_json["images"][0]["url"]
|
||||
response = requests.get(image_url)
|
||||
image = BytesIO(response.content)
|
||||
break
|
||||
return image
|
||||
|
||||
|
||||
with gr.Blocks() as model_web:
|
||||
with gr.Row():
|
||||
model_source = gr.Radio(
|
||||
value=None,
|
||||
choices=["Hugging Face", "Civitai"],
|
||||
type="value",
|
||||
label="Model Source",
|
||||
)
|
||||
model_numebr = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=10,
|
||||
step=1,
|
||||
label="Number of models",
|
||||
interactive=True,
|
||||
)
|
||||
# TODO: add more filters
|
||||
get_model_btn = gr.Button(value="Get Models")
|
||||
|
||||
hf_models = gr.Dropdown(
|
||||
label="Hugging Face Model List",
|
||||
choices=None,
|
||||
value=None,
|
||||
visible=False,
|
||||
)
|
||||
# TODO: select and SendTo
|
||||
civit_models = gr.Gallery(
|
||||
label="Civitai Model Gallery",
|
||||
value=None,
|
||||
interactive=True,
|
||||
visible=False,
|
||||
)
|
||||
|
||||
with gr.Row(visible=False) as sendto_btns:
|
||||
modelmanager_sendto_txt2img = gr.Button(value="SendTo Txt2Img")
|
||||
modelmanager_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
modelmanager_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
modelmanager_sendto_outpaint = gr.Button(value="SendTo Outpaint")
|
||||
modelmanager_sendto_upscaler = gr.Button(value="SendTo Upscaler")
|
||||
|
||||
def get_model_list(model_source, model_numebr):
|
||||
if model_source == "Hugging Face":
|
||||
hf_model_list = get_hf_list(model_numebr)
|
||||
models = []
|
||||
for model in hf_model_list:
|
||||
# TODO: add model info
|
||||
models.append(f'{model["modelId"]}')
|
||||
return (
|
||||
gr.Dropdown.update(choices=models, visible=True),
|
||||
gr.Gallery.update(value=None, visible=False),
|
||||
gr.Row.update(visible=True),
|
||||
)
|
||||
elif model_source == "Civitai":
|
||||
civit_model_list = get_civit_list(model_numebr)
|
||||
models = []
|
||||
for model in civit_model_list:
|
||||
image = get_image_from_model(model)
|
||||
if image is None:
|
||||
continue
|
||||
# TODO: add model info
|
||||
models.append(
|
||||
(Image.open(image), f'{model["files"][0]["downloadUrl"]}')
|
||||
)
|
||||
return (
|
||||
gr.Dropdown.update(value=None, choices=None, visible=False),
|
||||
gr.Gallery.update(value=models, visible=True),
|
||||
gr.Row.update(visible=False),
|
||||
)
|
||||
else:
|
||||
return (
|
||||
gr.Dropdown.update(value=None, choices=None, visible=False),
|
||||
gr.Gallery.update(value=None, visible=False),
|
||||
gr.Row.update(visible=False),
|
||||
)
|
||||
|
||||
get_model_btn.click(
|
||||
fn=get_model_list,
|
||||
inputs=[model_source, model_numebr],
|
||||
outputs=[
|
||||
hf_models,
|
||||
civit_models,
|
||||
sendto_btns,
|
||||
],
|
||||
)
|
||||
@@ -1,9 +1,11 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.scripts import outpaint_inf
|
||||
from apps.stable_diffusion.src import args
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from fastapi.exceptions import HTTPException
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -13,6 +15,294 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
predefined_paint_models,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
OutpaintPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generation_text_info,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_use_tuned = args.use_tuned
|
||||
init_import_mlir = args.import_mlir
|
||||
|
||||
|
||||
# Exposed to UI.
|
||||
def outpaint_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
init_image,
|
||||
pixels: int,
|
||||
mask_blur: int,
|
||||
directions: list,
|
||||
noise_q: float,
|
||||
color_variation: float,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
args.steps = steps
|
||||
args.scheduler = scheduler
|
||||
args.img_path = "not none"
|
||||
args.ondemand = ondemand
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
args.custom_vae = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
new_config_obj = Config(
|
||||
"outpaint",
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
precision,
|
||||
batch_size,
|
||||
max_length,
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
not global_obj.get_sd_obj()
|
||||
or global_obj.get_cfg_obj() != new_config_obj
|
||||
):
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.precision = precision
|
||||
args.batch_count = batch_count
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
args.width = width
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
|
||||
args.use_tuned = init_use_tuned
|
||||
args.import_mlir = init_import_mlir
|
||||
set_init_device_flags()
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-inpainting"
|
||||
)
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||
global_obj.set_sd_obj(
|
||||
OutpaintPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
use_lora=args.use_lora,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
|
||||
global_obj.set_sd_scheduler(scheduler)
|
||||
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
|
||||
left = True if "left" in directions else False
|
||||
right = True if "right" in directions else False
|
||||
top = True if "up" in directions else False
|
||||
bottom = True if "down" in directions else False
|
||||
|
||||
text_output = ""
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
init_image,
|
||||
pixels,
|
||||
mask_blur,
|
||||
left,
|
||||
right,
|
||||
top,
|
||||
bottom,
|
||||
noise_q,
|
||||
color_variation,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
seeds.append(img_seed)
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Outpaint", i + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
|
||||
try:
|
||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
return image
|
||||
except Exception as err:
|
||||
print(err)
|
||||
raise HTTPException(status_code=500, detail="Invalid encoded image")
|
||||
|
||||
|
||||
def encode_pil_to_base64(images):
|
||||
encoded_imgs = []
|
||||
for image in images:
|
||||
with BytesIO() as output_bytes:
|
||||
if args.output_img_format.lower() == "png":
|
||||
image.save(output_bytes, format="PNG")
|
||||
|
||||
elif args.output_img_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Invalid image format"
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
encoded_imgs.append(base64.b64encode(bytes_data))
|
||||
return encoded_imgs
|
||||
|
||||
|
||||
# Inpaint Rest API.
|
||||
def outpaint_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
|
||||
)
|
||||
init_image = decode_base64_to_image(InputData["init_images"][0])
|
||||
res = outpaint_inf(
|
||||
InputData["prompt"],
|
||||
InputData["negative_prompt"],
|
||||
init_image,
|
||||
InputData["pixels"],
|
||||
InputData["mask_blur"],
|
||||
InputData["directions"],
|
||||
InputData["noise_q"],
|
||||
InputData["color_variation"],
|
||||
InputData["height"],
|
||||
InputData["width"],
|
||||
InputData["steps"],
|
||||
InputData["cfg_scale"],
|
||||
InputData["seed"],
|
||||
batch_count=1,
|
||||
batch_size=1,
|
||||
scheduler="EulerDiscrete",
|
||||
custom_model="None",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
max_length=64,
|
||||
save_metadata_to_json=False,
|
||||
save_metadata_to_png=False,
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
)
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
"info": res[1],
|
||||
}
|
||||
|
||||
|
||||
with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
@@ -30,32 +320,32 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
custom_model = gr.Dropdown(
|
||||
outpaint_custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "None",
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ get_custom_model_files(
|
||||
custom_checkpoint_type="inpainting"
|
||||
)
|
||||
+ predefined_paint_models,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
outpaint_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
label="HuggingFace Model ID or Civitai model download URL",
|
||||
lines=3,
|
||||
)
|
||||
custom_vae = gr.Dropdown(
|
||||
label=f"Custom Vae Models",
|
||||
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.custom_vae)
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_paint_models,
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
@@ -222,10 +512,10 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => -1",
|
||||
queue=False,
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
@@ -238,17 +528,12 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
std_output = gr.Textbox(
|
||||
value="Nothing to show.",
|
||||
value=f"Images will be saved at {get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
output_dir = args.output_dir if args.output_dir else Path.cwd()
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
output_loc = gr.Textbox(
|
||||
label="Saving Images at",
|
||||
value=output_dir,
|
||||
interactive=False,
|
||||
)
|
||||
outpaint_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
outpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
outpaint_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
@@ -275,8 +560,8 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
outpaint_custom_model,
|
||||
outpaint_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
@@ -287,13 +572,20 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
],
|
||||
outputs=[outpaint_gallery, std_output],
|
||||
outputs=[outpaint_gallery, std_output, outpaint_status],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Outpaint", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=outpaint_status,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
|
||||
415
apps/stable_diffusion/web/ui/outputgallery_ui.py
Normal file
415
apps/stable_diffusion/web/ui/outputgallery_ui.py
Normal file
@@ -0,0 +1,415 @@
|
||||
import glob
|
||||
import gradio as gr
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
from apps.stable_diffusion.src import args
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generated_imgs_todays_subdir,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import nodlogo_loc
|
||||
from apps.stable_diffusion.web.utils.metadata import displayable_metadata
|
||||
|
||||
# -- Functions for file, directory and image info querying
|
||||
|
||||
output_dir = get_generated_imgs_path()
|
||||
|
||||
|
||||
def outputgallery_filenames(subdir) -> list[str]:
|
||||
new_dir_path = os.path.join(output_dir, subdir)
|
||||
if os.path.exists(new_dir_path):
|
||||
filenames = [
|
||||
glob.glob(new_dir_path + "/" + ext)
|
||||
for ext in ("*.png", "*.jpg", "*.jpeg")
|
||||
]
|
||||
|
||||
return sorted(sum(filenames, []), key=os.path.getmtime, reverse=True)
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def output_subdirs() -> list[str]:
|
||||
# Gets a list of subdirectories of output_dir and below, as relative paths.
|
||||
relative_paths = [
|
||||
os.path.relpath(entry[0], output_dir)
|
||||
for entry in os.walk(
|
||||
output_dir, followlinks=args.output_gallery_followlinks
|
||||
)
|
||||
]
|
||||
|
||||
# It is less confusing to always including the subdir that will take any images generated
|
||||
# today even if it doesn't exist yet
|
||||
if get_generated_imgs_todays_subdir() not in relative_paths:
|
||||
relative_paths.append(get_generated_imgs_todays_subdir())
|
||||
|
||||
# sort subdirectories so that that the date named ones we probably created in this or
|
||||
# previous sessions come first, sorted with the most recent first. Other subdirs are listed
|
||||
# after.
|
||||
generated_paths = sorted(
|
||||
[path for path in relative_paths if path.isnumeric()], reverse=True
|
||||
)
|
||||
result_paths = generated_paths + sorted(
|
||||
[
|
||||
path
|
||||
for path in relative_paths
|
||||
if (not path.isnumeric()) and path != "."
|
||||
]
|
||||
)
|
||||
|
||||
return result_paths
|
||||
|
||||
|
||||
# --- Define UI layout for Gradio
|
||||
|
||||
with gr.Blocks() as outputgallery_web:
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
|
||||
with gr.Row(elem_id="outputgallery_gallery"):
|
||||
# needed to workaround gradio issue: https://github.com/gradio-app/gradio/issues/2907
|
||||
dev_null = gr.Textbox("", visible=False)
|
||||
|
||||
gallery_files = gr.State(value=[])
|
||||
subdirectory_paths = gr.State(value=[])
|
||||
|
||||
with gr.Column(scale=6):
|
||||
logo = gr.Image(
|
||||
label="Getting subdirectories...",
|
||||
value=nod_logo,
|
||||
interactive=False,
|
||||
visible=True,
|
||||
show_label=True,
|
||||
elem_id="top_logo",
|
||||
elem_classes="logo_centered",
|
||||
)
|
||||
|
||||
gallery = gr.Gallery(
|
||||
label="",
|
||||
value=gallery_files.value,
|
||||
visible=False,
|
||||
show_label=True,
|
||||
).style(columns=4)
|
||||
|
||||
with gr.Column(scale=4):
|
||||
with gr.Box():
|
||||
with gr.Row():
|
||||
with gr.Column(scale=16, min_width=160):
|
||||
subdirectories = gr.Dropdown(
|
||||
label=f"Subdirectories of {output_dir}",
|
||||
type="value",
|
||||
choices=subdirectory_paths.value,
|
||||
value="",
|
||||
interactive=True,
|
||||
).style(container=False)
|
||||
with gr.Column(
|
||||
scale=1, min_width=32, elem_id="output_refresh_button"
|
||||
):
|
||||
refresh = gr.Button(
|
||||
variant="secondary",
|
||||
value="\u21BB", # unicode clockwise arrow circle
|
||||
).style(size="sm")
|
||||
|
||||
image_columns = gr.Slider(
|
||||
label="Columns shown", value=4, minimum=1, maximum=16, step=1
|
||||
)
|
||||
outputgallery_filename = gr.Textbox(
|
||||
label="Filename", value="None", interactive=False
|
||||
).style(show_copy_button=True)
|
||||
|
||||
with gr.Accordion(
|
||||
label="Parameter Information", open=False
|
||||
) as parameters_accordian:
|
||||
image_parameters = gr.DataFrame(
|
||||
headers=["Parameter", "Value"],
|
||||
col_count=2,
|
||||
wrap=True,
|
||||
elem_classes="output_parameters_dataframe",
|
||||
value=[["Status", "No image selected"]],
|
||||
)
|
||||
|
||||
with gr.Accordion(label="Send To", open=True):
|
||||
with gr.Row():
|
||||
outputgallery_sendto_txt2img = gr.Button(
|
||||
value="Txt2Img",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
).style(size="sm")
|
||||
|
||||
outputgallery_sendto_img2img = gr.Button(
|
||||
value="Img2Img",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
).style(size="sm")
|
||||
|
||||
outputgallery_sendto_inpaint = gr.Button(
|
||||
value="Inpaint",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
).style(size="sm")
|
||||
|
||||
outputgallery_sendto_outpaint = gr.Button(
|
||||
value="Outpaint",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
).style(size="sm")
|
||||
|
||||
outputgallery_sendto_upscaler = gr.Button(
|
||||
value="Upscaler",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
).style(size="sm")
|
||||
|
||||
# --- Event handlers
|
||||
|
||||
def on_clear_gallery():
|
||||
return [
|
||||
gr.Gallery.update(
|
||||
value=[],
|
||||
visible=False,
|
||||
),
|
||||
gr.Image.update(
|
||||
visible=True,
|
||||
),
|
||||
]
|
||||
|
||||
def on_select_subdir(subdir) -> list:
|
||||
# evt.value is the subdirectory name
|
||||
new_images = outputgallery_filenames(subdir)
|
||||
new_label = (
|
||||
f"{len(new_images)} images in {os.path.join(output_dir, subdir)}"
|
||||
)
|
||||
return [
|
||||
new_images,
|
||||
gr.Gallery.update(
|
||||
value=new_images,
|
||||
label=new_label,
|
||||
visible=len(new_images) > 0,
|
||||
),
|
||||
gr.Image.update(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
]
|
||||
|
||||
def on_refresh(current_subdir: str) -> list:
|
||||
# get an up to date subdirectory list
|
||||
refreshed_subdirs = output_subdirs()
|
||||
# get the images using either the current subdirectory or the most recent valid one
|
||||
new_subdir = (
|
||||
current_subdir
|
||||
if current_subdir in refreshed_subdirs
|
||||
else refreshed_subdirs[0]
|
||||
)
|
||||
new_images = outputgallery_filenames(new_subdir)
|
||||
new_label = f"{len(new_images)} images in {os.path.join(output_dir, new_subdir)}"
|
||||
|
||||
return [
|
||||
gr.Dropdown.update(
|
||||
choices=refreshed_subdirs,
|
||||
value=new_subdir,
|
||||
),
|
||||
refreshed_subdirs,
|
||||
new_images,
|
||||
gr.Gallery.update(
|
||||
value=new_images, label=new_label, visible=len(new_images) > 0
|
||||
),
|
||||
gr.Image.update(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
]
|
||||
|
||||
def on_new_image(subdir, subdir_paths, status) -> list:
|
||||
# prevent error triggered when an image generates before the tab has even been selected
|
||||
subdir_paths = (
|
||||
subdir_paths
|
||||
if len(subdir_paths) > 0
|
||||
else [get_generated_imgs_todays_subdir()]
|
||||
)
|
||||
|
||||
# only update if the current subdir is the most recent one as new images only go there
|
||||
if subdir_paths[0] == subdir:
|
||||
new_images = outputgallery_filenames(subdir)
|
||||
new_label = f"{len(new_images)} images in {os.path.join(output_dir, subdir)} - {status}"
|
||||
|
||||
return [
|
||||
new_images,
|
||||
gr.Gallery.update(
|
||||
value=new_images,
|
||||
label=new_label,
|
||||
visible=len(new_images) > 0,
|
||||
),
|
||||
gr.Image.update(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
]
|
||||
else:
|
||||
# otherwise change nothing, (only untyped gradio gr.update() does this)
|
||||
return [gr.update(), gr.update(), gr.update()]
|
||||
|
||||
def on_select_image(images: list[str], evt: gr.SelectData) -> list:
|
||||
# evt.index is an index into the full list of filenames for the current subdirectory
|
||||
filename = images[evt.index]
|
||||
params = displayable_metadata(filename)
|
||||
|
||||
if params:
|
||||
return [
|
||||
filename,
|
||||
list(map(list, params["parameters"].items())),
|
||||
]
|
||||
|
||||
return [
|
||||
filename,
|
||||
[["Status", "No parameters found"]],
|
||||
]
|
||||
|
||||
def on_outputgallery_filename_change(filename: str) -> list:
|
||||
exists = filename != "None" and os.path.exists(filename)
|
||||
return [
|
||||
# disable or enable each of the sendto button based on whether an image is selected
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
]
|
||||
|
||||
# The time first our tab is selected we need to do an initial refresh to populate
|
||||
# the subdirectory select box and the images from the most recent subdirectory.
|
||||
#
|
||||
# We do it at this point rather than setting this up in the controls' definitions
|
||||
# as when you refresh the browser you always get what was *initially* set, which
|
||||
# won't include any new subdirectories or images that might have created since
|
||||
# the application was started. Doing it this way means a browser refresh/reload
|
||||
# always gets the most up to date data.
|
||||
def on_select_tab(subdir_paths):
|
||||
if len(subdir_paths) == 0:
|
||||
return on_refresh("")
|
||||
else:
|
||||
return (
|
||||
# Change nothing, (only untyped gr.update() does this)
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
)
|
||||
|
||||
# Unfortunately as of gradio 3.22.0 gr.update against Galleries doesn't support
|
||||
# things set with .style, nor the elem_classes kwarg so we have to directly set
|
||||
# things up via JavaScript if we want the client to take notice of any of our
|
||||
# changes to the number of columns after it decides to put them back to the
|
||||
# original number when we change something
|
||||
def js_set_columns_in_browser(timeout_length):
|
||||
return f"""
|
||||
(new_cols) => {{
|
||||
setTimeout(() => {{
|
||||
required_style = "auto ".repeat(new_cols).trim();
|
||||
gallery = document.querySelector('#outputgallery_gallery .grid-container');
|
||||
if (gallery) {{
|
||||
gallery.style.gridTemplateColumns = required_style
|
||||
}}
|
||||
}}, {timeout_length});
|
||||
return []; // prevents console error from gradio
|
||||
}}
|
||||
"""
|
||||
|
||||
# --- Wire handlers up to the actions
|
||||
|
||||
# - Many actions reset the number of columns shown in the gallery on the browser end,
|
||||
# so we have to set them back to what we think they should be after the initial
|
||||
# action.
|
||||
# - None of the actions on this tab trigger inference, and we want the user to be able
|
||||
# to do them whilst other tabs have ongoing inference running. Waiting in the queue
|
||||
# behind inference jobs would mean the UI can't fully respond until the inference tasks
|
||||
# complete, hence queue=False on all of these.
|
||||
set_gallery_columns_immediate = dict(
|
||||
fn=None,
|
||||
inputs=[image_columns],
|
||||
# gradio blanks the UI on Chrome on Linux on gallery select if I don't put an output here
|
||||
outputs=[dev_null],
|
||||
_js=js_set_columns_in_browser(0),
|
||||
queue=False,
|
||||
)
|
||||
|
||||
# setting columns after selecting a gallery item needs a real timeout length for the
|
||||
# number of columns to actually be applied. Not really sure why, maybe something has
|
||||
# to finish animating?
|
||||
set_gallery_columns_delayed = dict(
|
||||
set_gallery_columns_immediate, _js=js_set_columns_in_browser(250)
|
||||
)
|
||||
|
||||
# clearing images when we need to completely change what's in the gallery avoids current
|
||||
# images being shown replacing piecemeal and prevents weirdness and errors if the user
|
||||
# selects an image during the replacement phase.
|
||||
clear_gallery = dict(
|
||||
fn=on_clear_gallery,
|
||||
inputs=None,
|
||||
outputs=[gallery, logo],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
image_columns.change(**set_gallery_columns_immediate)
|
||||
|
||||
subdirectories.select(**clear_gallery).then(
|
||||
on_select_subdir,
|
||||
[subdirectories],
|
||||
[gallery_files, gallery, logo],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
|
||||
refresh.click(**clear_gallery).then(
|
||||
on_refresh,
|
||||
[subdirectories],
|
||||
[subdirectories, subdirectory_paths, gallery_files, gallery, logo],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
|
||||
gallery.select(
|
||||
on_select_image,
|
||||
[gallery_files],
|
||||
[outputgallery_filename, image_parameters],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_delayed)
|
||||
|
||||
outputgallery_filename.change(
|
||||
on_outputgallery_filename_change,
|
||||
[outputgallery_filename],
|
||||
[
|
||||
outputgallery_sendto_txt2img,
|
||||
outputgallery_sendto_img2img,
|
||||
outputgallery_sendto_inpaint,
|
||||
outputgallery_sendto_outpaint,
|
||||
outputgallery_sendto_upscaler,
|
||||
],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
# We should have been given the .select function for our tab, so set it up
|
||||
def outputgallery_tab_select(select):
|
||||
select(
|
||||
fn=on_select_tab,
|
||||
inputs=[subdirectory_paths],
|
||||
outputs=[
|
||||
subdirectories,
|
||||
subdirectory_paths,
|
||||
gallery_files,
|
||||
gallery,
|
||||
logo,
|
||||
],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
|
||||
# We should have been passed a list of components on other tabs that update
|
||||
# when a new image has generated on that tab, so set things up so the user
|
||||
# will see that new image if they are looking at today's subdirectory
|
||||
def outputgallery_watch(components: gr.Textbox):
|
||||
for component in components:
|
||||
component.change(
|
||||
on_new_image,
|
||||
inputs=[subdirectories, subdirectory_paths, component],
|
||||
outputs=[gallery_files, gallery, logo],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
185
apps/stable_diffusion/web/ui/stablelm_ui.py
Normal file
185
apps/stable_diffusion/web/ui/stablelm_ui.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import gradio as gr
|
||||
import torch
|
||||
import os
|
||||
from pathlib import Path
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
|
||||
start_message = """<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
||||
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
||||
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
|
||||
- StableLM will refuse to participate in anything that could harm a human.
|
||||
"""
|
||||
|
||||
|
||||
def user(message, history):
|
||||
# Append the user's message to the conversation history
|
||||
return "", history + [[message, ""]]
|
||||
|
||||
|
||||
sharkModel = 0
|
||||
sharded_model = 0
|
||||
vicuna_model = 0
|
||||
|
||||
|
||||
start_message_vicuna = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
past_key_values = None
|
||||
|
||||
|
||||
def chat(curr_system_message, history, model, device, precision):
|
||||
print(f"In chat for {model}")
|
||||
global sharded_model
|
||||
global past_key_values
|
||||
global vicuna_model
|
||||
if "vicuna" in model:
|
||||
from apps.language_models.src.pipelines.vicuna_pipeline import (
|
||||
Vicuna,
|
||||
)
|
||||
|
||||
curr_system_message = start_message_vicuna
|
||||
if vicuna_model == 0:
|
||||
first_vic_vmfb_path = Path("first_vicuna.vmfb")
|
||||
second_vic_vmfb_path = Path("second_vicuna.vmfb")
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
vicuna_model = Vicuna(
|
||||
"vicuna",
|
||||
hf_model_path=model,
|
||||
device=device,
|
||||
precision=precision,
|
||||
first_vicuna_vmfb_path=first_vic_vmfb_path,
|
||||
second_vicuna_vmfb_path=second_vic_vmfb_path,
|
||||
)
|
||||
messages = curr_system_message + "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
prompt = messages.strip()
|
||||
print("prompt = ", prompt)
|
||||
sentence = vicuna_model.generate(prompt)
|
||||
|
||||
partial_text = ""
|
||||
for new_text in sentence.split(" "):
|
||||
# print(new_text)
|
||||
partial_text += new_text + " "
|
||||
history[-1][1] = partial_text
|
||||
# Yield an empty string to cleanup the message textbox and the updated conversation history
|
||||
yield history
|
||||
history[-1][1] = sentence
|
||||
return history
|
||||
|
||||
# else Model is StableLM
|
||||
global sharkModel
|
||||
from apps.language_models.src.pipelines.stablelm_pipeline import (
|
||||
SharkStableLM,
|
||||
)
|
||||
|
||||
if sharkModel == 0:
|
||||
# max_new_tokens=512
|
||||
shark_slm = SharkStableLM(
|
||||
"StableLM"
|
||||
) # pass elements from UI as required
|
||||
|
||||
# Construct the input message string for the model by concatenating the current system message and conversation history
|
||||
if len(curr_system_message.split()) > 160:
|
||||
print("clearing context")
|
||||
curr_system_message = start_message
|
||||
messages = curr_system_message + "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
|
||||
generate_kwargs = dict(prompt=messages)
|
||||
|
||||
words_list = shark_slm.generate(**generate_kwargs)
|
||||
|
||||
partial_text = ""
|
||||
for new_text in words_list:
|
||||
# print(new_text)
|
||||
partial_text += new_text
|
||||
history[-1][1] = partial_text
|
||||
# Yield an empty string to cleanup the message textbox and the updated conversation history
|
||||
yield history
|
||||
return words_list
|
||||
|
||||
|
||||
with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
with gr.Row():
|
||||
model = gr.Dropdown(
|
||||
label="Select Model",
|
||||
value="TheBloke/vicuna-7B-1.1-HF",
|
||||
choices=[
|
||||
"stabilityai/stablelm-tuned-alpha-3b",
|
||||
"TheBloke/vicuna-7B-1.1-HF",
|
||||
],
|
||||
)
|
||||
supported_devices = [
|
||||
device for device in available_devices if "cuda" in device
|
||||
]
|
||||
enabled = len(supported_devices) > 0
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=supported_devices[0]
|
||||
if enabled
|
||||
else "Only CUDA Supported for now",
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="fp32",
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
chatbot = gr.Chatbot().style(height=500)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
msg = gr.Textbox(
|
||||
label="Chat Message Box",
|
||||
placeholder="Chat Message Box",
|
||||
show_label=False,
|
||||
interactive=enabled,
|
||||
).style(container=False)
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
submit = gr.Button("Submit", interactive=enabled)
|
||||
stop = gr.Button("Stop", interactive=enabled)
|
||||
clear = gr.Button("Clear", interactive=enabled)
|
||||
system_msg = gr.Textbox(
|
||||
start_message, label="System Message", interactive=False, visible=False
|
||||
)
|
||||
|
||||
submit_event = msg.submit(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model, device, precision],
|
||||
outputs=[chatbot],
|
||||
queue=True,
|
||||
)
|
||||
submit_click_event = submit.click(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model, device, precision],
|
||||
outputs=[chatbot],
|
||||
queue=True,
|
||||
)
|
||||
stop.click(
|
||||
fn=None,
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
cancels=[submit_event, submit_click_event],
|
||||
queue=False,
|
||||
)
|
||||
clear.click(lambda: None, None, [chatbot], queue=False)
|
||||
@@ -1,9 +1,12 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from fastapi.exceptions import HTTPException
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -13,7 +16,8 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
predefined_models,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.png_metadata import import_png_metadata
|
||||
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Text2ImagePipeline,
|
||||
@@ -23,7 +27,10 @@ from apps.stable_diffusion.src import (
|
||||
save_output_img,
|
||||
prompt_examples,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generation_text_info,
|
||||
)
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
@@ -74,18 +81,23 @@ def txt2img_inf(
|
||||
# set ckpt_loc and hf_model_id.
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
args.custom_vae = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
args.custom_vae = custom_vae
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
@@ -100,6 +112,7 @@ def txt2img_inf(
|
||||
"txt2img",
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
precision,
|
||||
batch_size,
|
||||
max_length,
|
||||
@@ -180,6 +193,7 @@ def txt2img_inf(
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
seeds.append(img_seed)
|
||||
total_time = time.time() - start_time
|
||||
@@ -192,9 +206,68 @@ def txt2img_inf(
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Text-to-Image", i + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
return generated_imgs, text_output
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def encode_pil_to_base64(images):
|
||||
encoded_imgs = []
|
||||
for image in images:
|
||||
with BytesIO() as output_bytes:
|
||||
if args.output_img_format.lower() == "png":
|
||||
image.save(output_bytes, format="PNG")
|
||||
|
||||
elif args.output_img_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Invalid image format"
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
encoded_imgs.append(base64.b64encode(bytes_data))
|
||||
return encoded_imgs
|
||||
|
||||
|
||||
# Text2Img Rest API.
|
||||
def txt2img_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
|
||||
)
|
||||
res = txt2img_inf(
|
||||
InputData["prompt"],
|
||||
InputData["negative_prompt"],
|
||||
InputData["height"],
|
||||
InputData["width"],
|
||||
InputData["steps"],
|
||||
InputData["cfg_scale"],
|
||||
InputData["seed"],
|
||||
batch_count=1,
|
||||
batch_size=1,
|
||||
scheduler="EulerDiscrete",
|
||||
custom_model="None",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
max_length=64,
|
||||
save_metadata_to_json=False,
|
||||
save_metadata_to_png=False,
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
)
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
"info": res[1],
|
||||
}
|
||||
|
||||
|
||||
with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
@@ -214,35 +287,34 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=10):
|
||||
with gr.Row():
|
||||
custom_model = gr.Dropdown(
|
||||
txt2img_custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "None",
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
txt2img_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
label="HuggingFace Model ID or Civitai model download URL",
|
||||
lines=3,
|
||||
)
|
||||
custom_vae = gr.Dropdown(
|
||||
label=f"Custom Vae Models",
|
||||
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.custom_vae)
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
+ get_custom_model_files("vae"),
|
||||
)
|
||||
with gr.Column(scale=1, min_width=170):
|
||||
png_info_img = gr.Image(
|
||||
txt2img_png_info_img = gr.Image(
|
||||
label="Import PNG info",
|
||||
elem_id="txt2img_prompt_image",
|
||||
type="pil",
|
||||
@@ -380,10 +452,10 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => -1",
|
||||
queue=False,
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
@@ -404,17 +476,12 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
std_output = gr.Textbox(
|
||||
value="Nothing to show.",
|
||||
value=f"Images will be saved at {get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
output_dir = args.output_dir if args.output_dir else Path.cwd()
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
output_loc = gr.Textbox(
|
||||
label="Saving Images at",
|
||||
value=output_dir,
|
||||
interactive=False,
|
||||
)
|
||||
txt2img_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
txt2img_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
txt2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
@@ -438,8 +505,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
@@ -450,22 +517,30 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
],
|
||||
outputs=[txt2img_gallery, std_output],
|
||||
outputs=[txt2img_gallery, std_output, txt2img_status],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Text-to-Image", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=txt2img_status,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
png_info_img.change(
|
||||
txt2img_png_info_img.change(
|
||||
fn=import_png_metadata,
|
||||
inputs=[
|
||||
png_info_img,
|
||||
txt2img_png_info_img,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
@@ -474,11 +549,11 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
seed,
|
||||
width,
|
||||
height,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
],
|
||||
outputs=[
|
||||
png_info_img,
|
||||
txt2img_png_info_img,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
@@ -487,7 +562,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
seed,
|
||||
width,
|
||||
height,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.scripts import upscaler_inf
|
||||
from apps.stable_diffusion.src import args
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from fastapi.exceptions import HTTPException
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -11,7 +13,300 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_files,
|
||||
scheduler_list_cpu_only,
|
||||
predefined_upscaler_models,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
UpscalerPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generated_imgs_path
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_use_tuned = args.use_tuned
|
||||
init_import_mlir = args.import_mlir
|
||||
|
||||
|
||||
# Exposed to UI.
|
||||
def upscaler_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
init_image,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
noise_level: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
args.seed = seed
|
||||
args.steps = steps
|
||||
args.scheduler = scheduler
|
||||
args.ondemand = ondemand
|
||||
|
||||
if init_image is None:
|
||||
return None, "An Initial Image is required"
|
||||
image = init_image.convert("RGB").resize((height, width))
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
args.custom_vae = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
args.height = 128
|
||||
args.width = 128
|
||||
new_config_obj = Config(
|
||||
"upscaler",
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
precision,
|
||||
batch_size,
|
||||
max_length,
|
||||
args.height,
|
||||
args.width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
not global_obj.get_sd_obj()
|
||||
or global_obj.get_cfg_obj() != new_config_obj
|
||||
):
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
|
||||
args.use_tuned = init_use_tuned
|
||||
args.import_mlir = init_import_mlir
|
||||
set_init_device_flags()
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-1-base"
|
||||
)
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||
global_obj.set_sd_obj(
|
||||
UpscalerPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
use_lora=args.use_lora,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
|
||||
global_obj.set_sd_scheduler(scheduler)
|
||||
global_obj.get_sd_obj().low_res_scheduler = global_obj.get_scheduler(
|
||||
"DDPM"
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
extra_info = {"NOISE LEVEL": noise_level}
|
||||
for current_batch in range(batch_count):
|
||||
if current_batch > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
low_res_img = image
|
||||
high_res_img = Image.new("RGB", (height * 4, width * 4))
|
||||
|
||||
for i in range(0, width, 128):
|
||||
for j in range(0, height, 128):
|
||||
box = (j, i, j + 128, i + 128)
|
||||
upscaled_image = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
low_res_img.crop(box),
|
||||
batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
steps,
|
||||
noise_level,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
high_res_img.paste(upscaled_image[0], (j * 4, i * 4))
|
||||
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(high_res_img, img_seed, extra_info)
|
||||
generated_imgs.append(high_res_img)
|
||||
seeds.append(img_seed)
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, global_obj.get_sd_obj().log, status_label(
|
||||
"Upscaler", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={steps}, noise_level={noise_level}, guidance_scale={guidance_scale}, seed={seeds}"
|
||||
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
|
||||
text_output += global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
yield generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
|
||||
try:
|
||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
return image
|
||||
except Exception as err:
|
||||
print(err)
|
||||
raise HTTPException(status_code=500, detail="Invalid encoded image")
|
||||
|
||||
|
||||
def encode_pil_to_base64(images):
|
||||
encoded_imgs = []
|
||||
for image in images:
|
||||
with BytesIO() as output_bytes:
|
||||
if args.output_img_format.lower() == "png":
|
||||
image.save(output_bytes, format="PNG")
|
||||
|
||||
elif args.output_img_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Invalid image format"
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
encoded_imgs.append(base64.b64encode(bytes_data))
|
||||
return encoded_imgs
|
||||
|
||||
|
||||
# Upscaler Rest API.
|
||||
def upscaler_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
|
||||
)
|
||||
init_image = decode_base64_to_image(InputData["init_images"][0])
|
||||
res = upscaler_inf(
|
||||
InputData["prompt"],
|
||||
InputData["negative_prompt"],
|
||||
init_image,
|
||||
InputData["height"],
|
||||
InputData["width"],
|
||||
InputData["steps"],
|
||||
InputData["noise_level"],
|
||||
InputData["cfg_scale"],
|
||||
InputData["seed"],
|
||||
batch_count=1,
|
||||
batch_size=1,
|
||||
scheduler="EulerDiscrete",
|
||||
custom_model="None",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
max_length=64,
|
||||
save_metadata_to_json=False,
|
||||
save_metadata_to_png=False,
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
)
|
||||
# Converts generator type to subscriptable
|
||||
res = list(res)[0]
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
"info": res[1],
|
||||
}
|
||||
|
||||
|
||||
with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
@@ -29,32 +324,32 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
custom_model = gr.Dropdown(
|
||||
upscaler_custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "None",
|
||||
else "stabilityai/stable-diffusion-x4-upscaler",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ get_custom_model_files(
|
||||
custom_checkpoint_type="upscaler"
|
||||
)
|
||||
+ predefined_upscaler_models,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
upscaler_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
label="HuggingFace Model ID or Civitai model download URL",
|
||||
lines=3,
|
||||
)
|
||||
custom_vae = gr.Dropdown(
|
||||
label=f"Custom Vae Models",
|
||||
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.custom_vae)
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_upscaler_models,
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
@@ -200,10 +495,10 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => -1",
|
||||
queue=False,
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
@@ -216,17 +511,13 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
std_output = gr.Textbox(
|
||||
value="Nothing to show.",
|
||||
value=f"Images will be saved at {get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
output_dir = args.output_dir if args.output_dir else Path.cwd()
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
output_loc = gr.Textbox(
|
||||
label="Saving Images at",
|
||||
value=output_dir,
|
||||
interactive=False,
|
||||
)
|
||||
upscaler_status = gr.Textbox(visible=False)
|
||||
|
||||
with gr.Row():
|
||||
upscaler_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
upscaler_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
@@ -249,8 +540,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
upscaler_custom_model,
|
||||
upscaler_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
@@ -261,13 +552,21 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
],
|
||||
outputs=[upscaler_gallery, std_output],
|
||||
outputs=[upscaler_gallery, std_output, upscaler_status],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Upscaler", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=upscaler_status,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
@@ -16,6 +16,7 @@ class Config:
|
||||
mode: str
|
||||
model_id: str
|
||||
ckpt_loc: str
|
||||
custom_vae: str
|
||||
precision: str
|
||||
batch_size: int
|
||||
max_length: int
|
||||
@@ -71,30 +72,36 @@ def resource_path(relative_path):
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
def create_custom_models_folders():
|
||||
dir = ["vae", "lora"]
|
||||
if not args.ckpt_dir:
|
||||
dir.insert(0, "models")
|
||||
else:
|
||||
if not os.path.isdir(args.ckpt_dir):
|
||||
sys.exit(
|
||||
f"Invalid --ckpt_dir argument, {args.ckpt_dir} folder does not exists."
|
||||
)
|
||||
for root in dir:
|
||||
get_custom_model_path(root).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def get_custom_model_path(model="models"):
|
||||
# If `--ckpt_dir` is provided it'd override the heirarchical folder
|
||||
# structure in WebUI :-
|
||||
# model
|
||||
# models or args.ckpt_dir
|
||||
# |___lora
|
||||
# |___vae
|
||||
sub_folder = "" if model == "models" else model
|
||||
if args.ckpt_dir:
|
||||
return Path(args.ckpt_dir)
|
||||
match model:
|
||||
case "models":
|
||||
return Path(Path.cwd(), "models")
|
||||
case "vae":
|
||||
return Path(Path.cwd(), "models/vae")
|
||||
case "lora":
|
||||
return Path(Path.cwd(), "models/lora")
|
||||
case _:
|
||||
return ""
|
||||
return Path(Path(args.ckpt_dir), sub_folder)
|
||||
else:
|
||||
return Path(Path.cwd(), "models/" + sub_folder)
|
||||
|
||||
|
||||
def get_custom_model_pathfile(custom_model_name, model="models"):
|
||||
return os.path.join(get_custom_model_path(model), custom_model_name)
|
||||
|
||||
|
||||
def get_custom_model_files(model="models"):
|
||||
def get_custom_model_files(model="models", custom_checkpoint_type=""):
|
||||
ckpt_files = []
|
||||
file_types = custom_model_filetypes
|
||||
if model == "lora":
|
||||
@@ -106,6 +113,28 @@ def get_custom_model_files(model="models"):
|
||||
os.path.join(get_custom_model_path(model), extn)
|
||||
)
|
||||
]
|
||||
match custom_checkpoint_type:
|
||||
case "inpainting":
|
||||
files = [
|
||||
val
|
||||
for val in files
|
||||
if val.endswith("inpainting" + extn.removeprefix("*"))
|
||||
]
|
||||
case "upscaler":
|
||||
files = [
|
||||
val
|
||||
for val in files
|
||||
if val.endswith("upscaler" + extn.removeprefix("*"))
|
||||
]
|
||||
case _:
|
||||
files = [
|
||||
val
|
||||
for val in files
|
||||
if not (
|
||||
val.endswith("inpainting" + extn.removeprefix("*"))
|
||||
or val.endswith("upscaler" + extn.removeprefix("*"))
|
||||
)
|
||||
]
|
||||
ckpt_files.extend(files)
|
||||
return sorted(ckpt_files, key=str.casefold)
|
||||
|
||||
|
||||
9
apps/stable_diffusion/web/utils/common_label_calc.py
Normal file
9
apps/stable_diffusion/web/utils/common_label_calc.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# functions for generating labels used in common by tabs across the UI
|
||||
|
||||
|
||||
def status_label(tab_name, batch_index=0, batch_count=1, batch_size=1):
|
||||
if batch_index < batch_count:
|
||||
bs = f"x{batch_size}" if batch_size > 1 else ""
|
||||
return f"{tab_name} generating {batch_index+1}/{batch_count}{bs}"
|
||||
else:
|
||||
return f"{tab_name} complete"
|
||||
@@ -1,31 +1,54 @@
|
||||
import os
|
||||
import tempfile
|
||||
import gradio
|
||||
from os import listdir
|
||||
import shutil
|
||||
from time import time
|
||||
|
||||
gradio_tmp_imgs_folder = os.path.join(os.getcwd(), "shark_tmp/")
|
||||
shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")
|
||||
|
||||
|
||||
# Clear all gradio tmp images
|
||||
def clear_gradio_tmp_imgs_folder():
|
||||
if not os.path.exists(gradio_tmp_imgs_folder):
|
||||
return
|
||||
for fileName in listdir(gradio_tmp_imgs_folder):
|
||||
# Delete tmp png files
|
||||
if fileName.startswith("tmp") and fileName.endswith(".png"):
|
||||
os.remove(gradio_tmp_imgs_folder + fileName)
|
||||
def config_gradio_tmp_imgs_folder():
|
||||
# create shark_tmp if it does not exist
|
||||
if not os.path.exists(shark_tmp):
|
||||
os.mkdir(shark_tmp)
|
||||
|
||||
# tell gradio to use a directory under shark_tmp for its temporary
|
||||
# image files unless somewhere else has been set
|
||||
if "GRADIO_TEMP_DIR" not in os.environ:
|
||||
os.environ["GRADIO_TEMP_DIR"] = os.path.join(shark_tmp, "gradio")
|
||||
|
||||
# Overwrite save_pil_to_file from gradio to save tmp images generated by gradio into our own tmp folder
|
||||
def save_pil_to_file(pil_image, dir=None):
|
||||
if not os.path.exists(gradio_tmp_imgs_folder):
|
||||
os.mkdir(gradio_tmp_imgs_folder)
|
||||
file_obj = tempfile.NamedTemporaryFile(
|
||||
delete=False, suffix=".png", dir=gradio_tmp_imgs_folder
|
||||
print(
|
||||
f"gradio temporary image cache located at {os.environ['GRADIO_TEMP_DIR']}. "
|
||||
+ "You may change this by setting the GRADIO_TEMP_DIR environment variable."
|
||||
)
|
||||
pil_image.save(file_obj)
|
||||
return file_obj
|
||||
|
||||
# Clear all gradio tmp images from the last session
|
||||
if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
|
||||
cleanup_start = time()
|
||||
print(
|
||||
"Clearing gradio UI temporary image files from a prior run. This may take some time..."
|
||||
)
|
||||
shutil.rmtree(os.environ["GRADIO_TEMP_DIR"], ignore_errors=True)
|
||||
print(
|
||||
f"Clearing gradio UI temporary image files took {time() - cleanup_start:.4f} seconds."
|
||||
)
|
||||
|
||||
# Register save_pil_to_file override
|
||||
gradio.processing_utils.save_pil_to_file = save_pil_to_file
|
||||
# older SHARK versions had to workaround gradio bugs and stored things differently
|
||||
else:
|
||||
image_files = [
|
||||
filename
|
||||
for filename in os.listdir(shark_tmp)
|
||||
if os.path.isfile(os.path.join(shark_tmp, filename))
|
||||
and filename.startswith("tmp")
|
||||
and filename.endswith(".png")
|
||||
]
|
||||
if len(image_files) > 0:
|
||||
print(
|
||||
"Clearing temporary image files of a prior run of a previous SHARK version. This may take some time..."
|
||||
)
|
||||
cleanup_start = time()
|
||||
for filename in image_files:
|
||||
os.remove(shark_tmp + filename)
|
||||
print(
|
||||
f"Clearing temporary image files took {time() - cleanup_start:.4f} seconds."
|
||||
)
|
||||
else:
|
||||
print("No temporary images files to clear.")
|
||||
|
||||
6
apps/stable_diffusion/web/utils/metadata/__init__.py
Normal file
6
apps/stable_diffusion/web/utils/metadata/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .png_metadata import (
|
||||
import_png_metadata,
|
||||
)
|
||||
from .display import (
|
||||
displayable_metadata,
|
||||
)
|
||||
31
apps/stable_diffusion/web/utils/metadata/csv_metadata.py
Normal file
31
apps/stable_diffusion/web/utils/metadata/csv_metadata.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import csv
|
||||
import os
|
||||
from .format import humanize, humanizable
|
||||
|
||||
|
||||
def csv_path(image_filename: str):
|
||||
return os.path.join(os.path.dirname(image_filename), "imgs_details.csv")
|
||||
|
||||
|
||||
def has_csv(image_filename: str) -> bool:
|
||||
return os.path.exists(csv_path(image_filename))
|
||||
|
||||
|
||||
def parse_csv(image_filename: str):
|
||||
# We use a reader instead of a DictReader here for images_details.csv files due to the lack of
|
||||
# headers, and then match up the return list for each row with our guess at which column format
|
||||
# the file is using.
|
||||
|
||||
# we assume the final column of the csv has the original filename with full path and match that
|
||||
# against the image_filename. We then exclude the filename from the output, hence the -1's.
|
||||
csv_filename = csv_path(image_filename)
|
||||
|
||||
matches = [
|
||||
humanize(row)
|
||||
for row in csv.reader(open(csv_filename, "r", newline=""))
|
||||
if row
|
||||
and humanizable(row)
|
||||
and os.path.basename(image_filename) in row[-1]
|
||||
]
|
||||
|
||||
return matches[0] if matches else {}
|
||||
50
apps/stable_diffusion/web/utils/metadata/display.py
Normal file
50
apps/stable_diffusion/web/utils/metadata/display.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import json
|
||||
import os
|
||||
from PIL import Image
|
||||
from .png_metadata import parse_generation_parameters
|
||||
from .exif_metadata import has_exif, parse_exif
|
||||
from .csv_metadata import has_csv, parse_csv
|
||||
from .format import compact, humanize
|
||||
|
||||
|
||||
def displayable_metadata(image_filename: str) -> dict:
|
||||
pil_image = Image.open(image_filename)
|
||||
|
||||
# we have PNG generation parameters (preferred, as it's what the txt2img dropzone reads,
|
||||
# and we go via that for SendTo, and is directly tied to the image)
|
||||
if "parameters" in pil_image.info:
|
||||
return {
|
||||
"source": "png",
|
||||
"parameters": compact(
|
||||
parse_generation_parameters(pil_image.info["parameters"])
|
||||
),
|
||||
}
|
||||
|
||||
# we have a matching json file (next most likely to be accurate when it's there)
|
||||
json_path = os.path.splitext(image_filename)[0] + ".json"
|
||||
if os.path.isfile(json_path):
|
||||
with open(json_path) as params_file:
|
||||
return {
|
||||
"source": "json",
|
||||
"parameters": compact(
|
||||
humanize(json.load(params_file), includes_filename=False)
|
||||
),
|
||||
}
|
||||
|
||||
# we have a CSV file so try that (can be different shapes, and it usually has no
|
||||
# headers/param names so of the things we we *know* have parameters, it's the
|
||||
# last resort)
|
||||
if has_csv(image_filename):
|
||||
params = parse_csv(image_filename)
|
||||
if params: # we might not have found the filename in the csv
|
||||
return {
|
||||
"source": "csv",
|
||||
"parameters": compact(params), # already humanized
|
||||
}
|
||||
|
||||
# EXIF data, probably a .jpeg, may well not include parameters, but at least it's *something*
|
||||
if has_exif(image_filename):
|
||||
return {"source": "exif", "parameters": parse_exif(pil_image)}
|
||||
|
||||
# we've got nothing
|
||||
return None
|
||||
52
apps/stable_diffusion/web/utils/metadata/exif_metadata.py
Normal file
52
apps/stable_diffusion/web/utils/metadata/exif_metadata.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from PIL import Image
|
||||
from PIL.ExifTags import Base as EXIFKeys, TAGS, IFD, GPSTAGS
|
||||
|
||||
|
||||
def has_exif(image_filename: str) -> bool:
|
||||
return True if Image.open(image_filename).getexif() else False
|
||||
|
||||
|
||||
def parse_exif(pil_image: Image) -> dict:
|
||||
img_exif = pil_image.getexif()
|
||||
|
||||
# See this stackoverflow answer for where most this comes from: https://stackoverflow.com/a/75357594
|
||||
# I did try to use the exif library but it broke just as much as my initial attempt at this (albeit I
|
||||
# I was probably using it wrong) so I reverted back to using PIL with more filtering and saved a
|
||||
# dependency
|
||||
exif_tags = {
|
||||
TAGS.get(key, key): str(val)
|
||||
for (key, val) in img_exif.items()
|
||||
if key in TAGS
|
||||
and key not in (EXIFKeys.ExifOffset, EXIFKeys.GPSInfo)
|
||||
and val
|
||||
and (not isinstance(val, bytes))
|
||||
and (not str(val).isspace())
|
||||
}
|
||||
|
||||
def try_get_ifd(ifd_id):
|
||||
try:
|
||||
return img_exif.get_ifd(ifd_id).items()
|
||||
except KeyError:
|
||||
return {}
|
||||
|
||||
ifd_tags = {
|
||||
TAGS.get(key, key): str(val)
|
||||
for ifd_id in IFD
|
||||
for (key, val) in try_get_ifd(ifd_id)
|
||||
if ifd_id != IFD.GPSInfo
|
||||
and key in TAGS
|
||||
and val
|
||||
and (not isinstance(val, bytes))
|
||||
and (not str(val).isspace())
|
||||
}
|
||||
|
||||
gps_tags = {
|
||||
GPSTAGS.get(key, key): str(val)
|
||||
for (key, val) in try_get_ifd(IFD.GPSInfo)
|
||||
if key in GPSTAGS
|
||||
and val
|
||||
and (not isinstance(val, bytes))
|
||||
and (not str(val).isspace())
|
||||
}
|
||||
|
||||
return {**exif_tags, **ifd_tags, **gps_tags}
|
||||
115
apps/stable_diffusion/web/utils/metadata/format.py
Normal file
115
apps/stable_diffusion/web/utils/metadata/format.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# As SHARK has evolved more columns have been added to images_details.csv. However, since
|
||||
# no version of the CSV has any headers (yet) we don't actually have anything within the
|
||||
# file that tells us which parameter each column is for. So this is a list of known patterns
|
||||
# indexed by length which is what we're going to have to use to guess which columns are the
|
||||
# right ones for the file we're looking at.
|
||||
|
||||
# The same ordering is used for JSON, but these do have key names, however they are not very
|
||||
# human friendly, nor do they match up with the what is written to the .png headers
|
||||
|
||||
# So these are functions to try and get something consistent out the raw input from all
|
||||
# these sources
|
||||
|
||||
PARAMS_FORMATS = {
|
||||
9: {
|
||||
"VARIANT": "Model",
|
||||
"SCHEDULER": "Sampler",
|
||||
"PROMPT": "Prompt",
|
||||
"NEG_PROMPT": "Negative prompt",
|
||||
"SEED": "Seed",
|
||||
"CFG_SCALE": "CFG scale",
|
||||
"PRECISION": "Precision",
|
||||
"STEPS": "Steps",
|
||||
"OUTPUT": "Filename",
|
||||
},
|
||||
10: {
|
||||
"MODEL": "Model",
|
||||
"VARIANT": "Variant",
|
||||
"SCHEDULER": "Sampler",
|
||||
"PROMPT": "Prompt",
|
||||
"NEG_PROMPT": "Negative prompt",
|
||||
"SEED": "Seed",
|
||||
"CFG_SCALE": "CFG scale",
|
||||
"PRECISION": "Precision",
|
||||
"STEPS": "Steps",
|
||||
"OUTPUT": "Filename",
|
||||
},
|
||||
12: {
|
||||
"VARIANT": "Model",
|
||||
"SCHEDULER": "Sampler",
|
||||
"PROMPT": "Prompt",
|
||||
"NEG_PROMPT": "Negative prompt",
|
||||
"SEED": "Seed",
|
||||
"CFG_SCALE": "CFG scale",
|
||||
"PRECISION": "Precision",
|
||||
"STEPS": "Steps",
|
||||
"HEIGHT": "Height",
|
||||
"WIDTH": "Width",
|
||||
"MAX_LENGTH": "Max Length",
|
||||
"OUTPUT": "Filename",
|
||||
},
|
||||
}
|
||||
|
||||
PARAMS_FORMAT_LONGEST = PARAMS_FORMATS[max(PARAMS_FORMATS.keys())]
|
||||
|
||||
|
||||
def compact(metadata: dict) -> dict:
|
||||
# we don't want to alter the original dictionary
|
||||
result = dict(metadata)
|
||||
|
||||
# discard the filename because we should already have it
|
||||
if result.keys() & {"Filename"}:
|
||||
result.pop("Filename")
|
||||
|
||||
# make showing the sizes more compact by using only one line each
|
||||
if result.keys() & {"Size-1", "Size-2"}:
|
||||
result["Size"] = f"{result.pop('Size-1')}x{result.pop('Size-2')}"
|
||||
elif result.keys() & {"Height", "Width"}:
|
||||
result["Size"] = f"{result.pop('Height')}x{result.pop('Width')}"
|
||||
|
||||
if result.keys() & {"Hires resize-1", "Hires resize-1"}:
|
||||
hires_y = result.pop("Hires resize-1")
|
||||
hires_x = result.pop("Hires resize-2")
|
||||
|
||||
if hires_x == 0 and hires_y == 0:
|
||||
result["Hires resize"] = "None"
|
||||
else:
|
||||
result["Hires resize"] = f"{hires_y}x{hires_x}"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def humanizable(metadata: dict | list[str], includes_filename=True) -> dict:
|
||||
lookup_key = len(metadata) + (0 if includes_filename else 1)
|
||||
return lookup_key in PARAMS_FORMATS.keys()
|
||||
|
||||
|
||||
def humanize(metadata: dict | list[str], includes_filename=True) -> dict:
|
||||
lookup_key = len(metadata) + (0 if includes_filename else 1)
|
||||
|
||||
# For lists we can only work based on the length, we have no other information
|
||||
if isinstance(metadata, list):
|
||||
if humanizable(metadata, includes_filename):
|
||||
return dict(zip(PARAMS_FORMATS[lookup_key].values(), metadata))
|
||||
else:
|
||||
raise KeyError(
|
||||
f"Humanize could not find the format for a parameter list of length {len(metadata)}"
|
||||
)
|
||||
|
||||
# For dictionaries we try to use the matching length parameter format if
|
||||
# available, otherwise we use the longest. Then we swap keys in the
|
||||
# metadata that match keys in the format for the friendlier name that we
|
||||
# have set in the format value
|
||||
if isinstance(metadata, dict):
|
||||
if humanizable(metadata, includes_filename):
|
||||
format = PARAMS_FORMATS[lookup_key]
|
||||
else:
|
||||
format = PARAMS_FORMAT_LONGEST
|
||||
|
||||
return {
|
||||
format[key]: value
|
||||
for (key, value) in metadata.items()
|
||||
if key in format.keys()
|
||||
}
|
||||
|
||||
raise TypeError("Can only humanize parameter lists or dictionaries")
|
||||
@@ -40,7 +40,7 @@ cmake --build build/
|
||||
*Prepare the model*
|
||||
```bash
|
||||
wget https://storage.googleapis.com/shark_tank/latest/resnet50_tf/resnet50_tf.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
|
||||
```
|
||||
*Prepare the input*
|
||||
|
||||
@@ -65,18 +65,18 @@ A tool for benchmarking other models is built and can be invoked with a command
|
||||
see `./build/vulkan_gui/iree-vulkan-gui --help` for an explanation on the function input. For example, stable diffusion unet can be tested with the following commands:
|
||||
```bash
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/stable_diff_tf.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=2x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32
|
||||
```
|
||||
VAE and Autoencoder are also available
|
||||
```bash
|
||||
# VAE
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/vae_tf/vae.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x4x64x64xf32
|
||||
|
||||
# CLIP Autoencoder
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/clip_tf/clip_autoencoder.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x77xi32 --function_input=1x77xi32
|
||||
```
|
||||
|
||||
@@ -21,7 +21,7 @@ endif()
|
||||
# Compile mnist.mlir to mnist.vmfb.
|
||||
set(_COMPILE_TOOL_EXECUTABLE $<TARGET_FILE:iree-compile>)
|
||||
set(_COMPILE_ARGS)
|
||||
list(APPEND _COMPILE_ARGS "--iree-input-type=mhlo")
|
||||
list(APPEND _COMPILE_ARGS "--iree-input-type=auto")
|
||||
list(APPEND _COMPILE_ARGS "--iree-hal-target-backends=llvm-cpu")
|
||||
list(APPEND _COMPILE_ARGS "${IREE_SOURCE_DIR}/samples/models/mnist.mlir")
|
||||
list(APPEND _COMPILE_ARGS "-o")
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# SHARK Annotator
|
||||
gradio==3.15.0
|
||||
gradio==3.34.0
|
||||
jsonlines
|
||||
|
||||
@@ -16,16 +16,19 @@ parameterized
|
||||
|
||||
# Add transformers, diffusers and scipy since it most commonly used
|
||||
transformers
|
||||
diffusers @ git+https://github.com/huggingface/diffusers@e47459c80f6f6a5a1c19d32c3fd74edf94f47aa2
|
||||
diffusers
|
||||
scipy
|
||||
ftfy
|
||||
gradio
|
||||
gradio==3.34.0
|
||||
altair
|
||||
omegaconf
|
||||
safetensors
|
||||
opencv-python
|
||||
scikit-image
|
||||
pytorch_lightning # for runwayml models
|
||||
tk
|
||||
pywebview
|
||||
sentencepiece
|
||||
|
||||
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
|
||||
pefile
|
||||
|
||||
109
rest_api_tests/api_test.py
Normal file
109
rest_api_tests/api_test.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import requests
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
def upscaler_test():
|
||||
# Define values here
|
||||
prompt = ""
|
||||
negative_prompt = ""
|
||||
seed = 2121991605
|
||||
height = 512
|
||||
width = 512
|
||||
steps = 50
|
||||
noise_level = 10
|
||||
cfg_scale = 7
|
||||
image_path = r"./rest_api_tests/dog.png"
|
||||
|
||||
# Converting Image to base64
|
||||
img_file = open(image_path, "rb")
|
||||
init_images = [
|
||||
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
|
||||
]
|
||||
|
||||
url = "http://127.0.0.1:8080/sdapi/v1/upscaler"
|
||||
|
||||
headers = {
|
||||
"User-Agent": "PythonTest",
|
||||
"Accept": "*/*",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
}
|
||||
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"seed": seed,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"steps": steps,
|
||||
"noise_level": noise_level,
|
||||
"cfg_scale": cfg_scale,
|
||||
"init_images": init_images,
|
||||
}
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"response from server was : {res.status_code}")
|
||||
|
||||
|
||||
def img2img_test():
|
||||
# Define values here
|
||||
prompt = "Paint a rabbit riding on the dog"
|
||||
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
|
||||
seed = 2121991605
|
||||
height = 512
|
||||
width = 512
|
||||
steps = 50
|
||||
denoising_strength = 0.75
|
||||
cfg_scale = 7
|
||||
image_path = r"./rest_api_tests/dog.png"
|
||||
|
||||
# Converting Image to Base64
|
||||
img_file = open(image_path, "rb")
|
||||
init_images = [
|
||||
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
|
||||
]
|
||||
|
||||
url = "http://127.0.0.1:8080/sdapi/v1/img2img"
|
||||
|
||||
headers = {
|
||||
"User-Agent": "PythonTest",
|
||||
"Accept": "*/*",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
}
|
||||
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"init_images": init_images,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"steps": steps,
|
||||
"denoising_strength": denoising_strength,
|
||||
"cfg_scale": cfg_scale,
|
||||
"seed": seed,
|
||||
}
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"response from server was : {res.status_code}")
|
||||
|
||||
print("Extracting response object")
|
||||
|
||||
# Uncomment below to save the picture
|
||||
|
||||
response_obj = res.json()
|
||||
img_b64 = response_obj.get("images", [False])[0] or response_obj.get(
|
||||
"image"
|
||||
)
|
||||
img_b2 = base64.b64decode(img_b64.replace("data:image/png;base64,", ""))
|
||||
im_file = BytesIO(img_b2)
|
||||
response_img = Image.open(im_file)
|
||||
print("Saving Response Image to: response_img")
|
||||
response_img.save(r"rest_api_tests/response_img.png")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
img2img_test()
|
||||
upscaler_test()
|
||||
BIN
rest_api_tests/dog.png
Normal file
BIN
rest_api_tests/dog.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 4.5 KiB |
@@ -2,9 +2,10 @@
|
||||
# Sets up a venv suitable for running samples.
|
||||
# e.g:
|
||||
# ./setup_venv.sh #setup a default $PYTHON3 shark.venv
|
||||
# Environment Variables by the script.
|
||||
# Environment variables used by the script.
|
||||
# PYTHON=$PYTHON3.10 ./setup_venv.sh #pass a version of $PYTHON to use
|
||||
# VENV_DIR=myshark.venv #create a venv called myshark.venv
|
||||
# SKIP_VENV=1 #Don't create and activate a Python venv. Use the current environment.
|
||||
# USE_IREE=1 #use stock IREE instead of Nod.ai's SHARK build
|
||||
# IMPORTER=1 #Install importer deps
|
||||
# BENCHMARK=1 #Install benchmark deps
|
||||
@@ -26,15 +27,17 @@ PYTHON_VERSION_X_Y=`${PYTHON} -c 'import sys; version=sys.version_info[:2]; prin
|
||||
echo "Python: $PYTHON"
|
||||
echo "Python version: $PYTHON_VERSION_X_Y"
|
||||
|
||||
if [[ -z "${CONDA_PREFIX}" ]]; then
|
||||
# Not a conda env. So create a new VENV dir
|
||||
VENV_DIR=${VENV_DIR:-shark.venv}
|
||||
echo "Using pip venv.. Setting up venv dir: $VENV_DIR"
|
||||
$PYTHON -m venv "$VENV_DIR" || die "Could not create venv."
|
||||
source "$VENV_DIR/bin/activate" || die "Could not activate venv"
|
||||
PYTHON="$(which python3)"
|
||||
else
|
||||
echo "Found conda env $CONDA_DEFAULT_ENV. Running pip install inside the conda env"
|
||||
if [[ "$SKIP_VENV" != "1" ]]; then
|
||||
if [[ -z "${CONDA_PREFIX}" ]]; then
|
||||
# Not a conda env. So create a new VENV dir
|
||||
VENV_DIR=${VENV_DIR:-shark.venv}
|
||||
echo "Using pip venv.. Setting up venv dir: $VENV_DIR"
|
||||
$PYTHON -m venv "$VENV_DIR" || die "Could not create venv."
|
||||
source "$VENV_DIR/bin/activate" || die "Could not activate venv"
|
||||
PYTHON="$(which python3)"
|
||||
else
|
||||
echo "Found conda env $CONDA_DEFAULT_ENV. Running pip install inside the conda env"
|
||||
fi
|
||||
fi
|
||||
|
||||
Red=`tput setaf 1`
|
||||
@@ -147,8 +150,7 @@ if [[ ! -z "${ONNX}" ]]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ -z "${CONDA_PREFIX}" ]]; then
|
||||
if [[ -z "${CONDA_PREFIX}" && "$SKIP_VENV" != "1" ]]; then
|
||||
echo "${Green}Before running examples activate venv with:"
|
||||
echo " ${Green}source $VENV_DIR/bin/activate"
|
||||
fi
|
||||
|
||||
|
||||
76
shark/examples/shark_inference/mega_test.py
Normal file
76
shark/examples/shark_inference/mega_test.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import torch
|
||||
import torch_mlir
|
||||
from shark.shark_inference import SharkInference
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
compile_through_fx,
|
||||
args,
|
||||
)
|
||||
from MEGABYTE_pytorch import MEGABYTE
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class MegaModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = MEGABYTE(
|
||||
num_tokens=16000, # number of tokens
|
||||
dim=(
|
||||
512,
|
||||
256,
|
||||
), # transformer model dimension (512 for coarsest, 256 for fine in this example)
|
||||
max_seq_len=(
|
||||
1024,
|
||||
4,
|
||||
), # sequence length for global and then local. this can be more than 2
|
||||
depth=(
|
||||
6,
|
||||
4,
|
||||
), # number of layers for global and then local. this can be more than 2, but length must match the max_seq_len's
|
||||
dim_head=64, # dimension per head
|
||||
heads=8, # number of attention heads
|
||||
flash_attn=True, # use flash attention
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
return self.model(input)
|
||||
|
||||
|
||||
megaModel = MegaModel()
|
||||
input = [torch.randint(0, 16000, (1, 1024, 4))]
|
||||
|
||||
# CURRENTLY IT BAILS OUT HERE BECAUSE OF MISSING OP LOWERINGS :-
|
||||
# 1. aten.alias
|
||||
shark_module, _ = compile_through_fx(
|
||||
megaModel,
|
||||
inputs=input,
|
||||
extended_model_name="mega_shark",
|
||||
debug=False,
|
||||
generate_vmfb=True,
|
||||
save_dir=os.getcwd(),
|
||||
extra_args=[],
|
||||
base_model_id=None,
|
||||
model_name="mega_shark",
|
||||
precision=None,
|
||||
return_mlir=True,
|
||||
device="cuda",
|
||||
)
|
||||
# logits = model(x)
|
||||
|
||||
|
||||
def print_output_info(output, msg):
|
||||
print("\n", msg)
|
||||
print("\n\t", output.shape)
|
||||
|
||||
|
||||
ans = shark_module("forward", input)
|
||||
print_output_info(torch.from_numpy(ans), "SHARK's output")
|
||||
|
||||
ans = megaModel.forward(*input)
|
||||
print_output_info(ans, "ORIGINAL Model's output")
|
||||
|
||||
# and sample from the logits accordingly
|
||||
# or you can use the generate function
|
||||
|
||||
# NEED TO LOOK AT THIS LATER IF REQUIRED IN SHARK.
|
||||
# sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4)
|
||||
73
shark/examples/shark_inference/minilm_jax.py
Normal file
73
shark/examples/shark_inference/minilm_jax.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from transformers import AutoTokenizer, FlaxAutoModel
|
||||
import torch
|
||||
import jax
|
||||
from typing import Union, Dict, List, Any
|
||||
import numpy as np
|
||||
from shark.shark_inference import SharkInference
|
||||
import io
|
||||
|
||||
NumpyTree = Union[np.ndarray, Dict[str, np.ndarray], List[np.ndarray]]
|
||||
|
||||
|
||||
def convert_torch_tensor_tree_to_numpy(
|
||||
tree: Union[torch.tensor, Dict[str, torch.tensor], List[torch.tensor]]
|
||||
) -> NumpyTree:
|
||||
return jax.tree_util.tree_map(
|
||||
lambda torch_tensor: torch_tensor.cpu().detach().numpy(), tree
|
||||
)
|
||||
|
||||
|
||||
def convert_int64_to_int32(tree: NumpyTree) -> NumpyTree:
|
||||
return jax.tree_util.tree_map(
|
||||
lambda tensor: np.array(tensor, dtype=np.int32)
|
||||
if tensor.dtype == np.int64
|
||||
else tensor,
|
||||
tree,
|
||||
)
|
||||
|
||||
|
||||
def get_sample_input():
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased"
|
||||
)
|
||||
inputs_torch = tokenizer("Hello, World!", return_tensors="pt")
|
||||
return convert_int64_to_int32(
|
||||
convert_torch_tensor_tree_to_numpy(inputs_torch.data)
|
||||
)
|
||||
|
||||
|
||||
def get_jax_model():
|
||||
return FlaxAutoModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
|
||||
|
||||
|
||||
def export_jax_to_mlir(jax_model: Any, sample_input: NumpyTree):
|
||||
model_mlir = jax.jit(jax_model).lower(**sample_input).compiler_ir()
|
||||
byte_stream = io.BytesIO()
|
||||
model_mlir.operation.write_bytecode(file=byte_stream)
|
||||
return byte_stream.getvalue()
|
||||
|
||||
|
||||
def assert_array_list_allclose(x, y, *args, **kwargs):
|
||||
assert len(x) == len(y)
|
||||
for a, b in zip(x, y):
|
||||
np.testing.assert_allclose(
|
||||
np.asarray(a), np.asarray(b), *args, **kwargs
|
||||
)
|
||||
|
||||
|
||||
sample_input = get_sample_input()
|
||||
jax_model = get_jax_model()
|
||||
mlir = export_jax_to_mlir(jax_model, sample_input)
|
||||
|
||||
# Compile and load module.
|
||||
shark_inference = SharkInference(mlir_module=mlir, mlir_dialect="mhlo")
|
||||
shark_inference.compile()
|
||||
|
||||
# Run main function.
|
||||
result = shark_inference("main", jax.tree_util.tree_flatten(sample_input)[0])
|
||||
|
||||
# Run JAX model.
|
||||
reference_result = jax.tree_util.tree_flatten(jax_model(**sample_input))[0]
|
||||
|
||||
# Verify result.
|
||||
assert_array_list_allclose(result, reference_result, atol=1e-5)
|
||||
@@ -0,0 +1,6 @@
|
||||
flax
|
||||
jax[cpu]
|
||||
nodai-SHARK
|
||||
orbax
|
||||
transformers
|
||||
torch
|
||||
@@ -70,11 +70,11 @@ mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"resnet50", frontend="torch"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(mlir_model, func_name, mlir_dialect="linalg")
|
||||
shark_module = SharkInference(mlir_model, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
path = shark_module.save_module()
|
||||
shark_module.load_module(path)
|
||||
result = shark_module.forward((img.detach().numpy(),))
|
||||
result = shark_module("forward", (img.detach().numpy(),))
|
||||
|
||||
print("The top 3 results obtained via shark_runner is:")
|
||||
print(top3_possibilities(torch.from_numpy(result)))
|
||||
|
||||
@@ -45,10 +45,15 @@ def run_cmd(cmd, debug=False):
|
||||
|
||||
def iree_device_map(device):
|
||||
uri_parts = device.split("://", 2)
|
||||
iree_driver = (
|
||||
_IREE_DEVICE_MAP[uri_parts[0]]
|
||||
if uri_parts[0] in _IREE_DEVICE_MAP
|
||||
else uri_parts[0]
|
||||
)
|
||||
if len(uri_parts) == 1:
|
||||
return _IREE_DEVICE_MAP[uri_parts[0]]
|
||||
return iree_driver
|
||||
else:
|
||||
return f"{_IREE_DEVICE_MAP[uri_parts[0]]}://{uri_parts[1]}"
|
||||
return f"{iree_driver}://{uri_parts[1]}"
|
||||
|
||||
|
||||
def get_supported_device_list():
|
||||
@@ -57,6 +62,8 @@ def get_supported_device_list():
|
||||
|
||||
_IREE_DEVICE_MAP = {
|
||||
"cpu": "local-task",
|
||||
"cpu-task": "local-task",
|
||||
"cpu-sync": "local-sync",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
"metal": "vulkan",
|
||||
@@ -68,11 +75,13 @@ _IREE_DEVICE_MAP = {
|
||||
def iree_target_map(device):
|
||||
if "://" in device:
|
||||
device = device.split("://")[0]
|
||||
return _IREE_TARGET_MAP[device]
|
||||
return _IREE_TARGET_MAP[device] if device in _IREE_TARGET_MAP else device
|
||||
|
||||
|
||||
_IREE_TARGET_MAP = {
|
||||
"cpu": "llvm-cpu",
|
||||
"cpu-task": "llvm-cpu",
|
||||
"cpu-sync": "llvm-cpu",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
"metal": "vulkan",
|
||||
@@ -110,10 +119,8 @@ def check_device_drivers(device):
|
||||
subprocess.check_output("rocminfo")
|
||||
except Exception:
|
||||
return True
|
||||
# Unknown device.
|
||||
else:
|
||||
return True
|
||||
|
||||
# Unknown device. We assume drivers are installed.
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ import re
|
||||
|
||||
# Get the iree-compile arguments given device.
|
||||
def get_iree_device_args(device, extra_args=[]):
|
||||
print("Configuring for device:" + device)
|
||||
device_uri = device.split("://")
|
||||
if len(device_uri) > 1:
|
||||
if device_uri[0] not in ["vulkan"]:
|
||||
@@ -30,6 +31,9 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
f"Specific device selection only supported for vulkan now."
|
||||
f"Proceeding with {device} as device."
|
||||
)
|
||||
device_num = device_uri[1]
|
||||
else:
|
||||
device_num = 0
|
||||
|
||||
if device_uri[0] == "cpu":
|
||||
from shark.iree_utils.cpu_utils import get_iree_cpu_args
|
||||
@@ -42,7 +46,9 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
if device_uri[0] in ["metal", "vulkan"]:
|
||||
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
|
||||
|
||||
return get_iree_vulkan_args(extra_args=extra_args)
|
||||
return get_iree_vulkan_args(
|
||||
device_num=device_num, extra_args=extra_args
|
||||
)
|
||||
if device_uri[0] == "rocm":
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
@@ -54,10 +60,9 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
def get_iree_frontend_args(frontend):
|
||||
if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]:
|
||||
return ["--iree-llvmcpu-target-cpu-features=host"]
|
||||
elif frontend in ["tensorflow", "tf", "mhlo"]:
|
||||
elif frontend in ["tensorflow", "tf", "mhlo", "stablehlo"]:
|
||||
return [
|
||||
"--iree-llvmcpu-target-cpu-features=host",
|
||||
"--iree-mhlo-demote-i64-to-i32=false",
|
||||
"--iree-flow-demote-i64-to-i32",
|
||||
]
|
||||
else:
|
||||
@@ -259,8 +264,8 @@ def compile_module_to_flatbuffer(
|
||||
args += extra_args
|
||||
|
||||
if frontend in ["tensorflow", "tf"]:
|
||||
input_type = "mhlo"
|
||||
elif frontend in ["mhlo", "tosa"]:
|
||||
input_type = "auto"
|
||||
elif frontend in ["stablehlo", "tosa"]:
|
||||
input_type = frontend
|
||||
elif frontend in ["tflite", "tflite-tosa"]:
|
||||
input_type = "tosa"
|
||||
@@ -307,7 +312,7 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
|
||||
)
|
||||
ctx = ireert.SystemContext(config=config)
|
||||
ctx.add_vm_module(vm_module)
|
||||
ModuleCompiled = ctx.modules.module
|
||||
ModuleCompiled = getattr(ctx.modules, vm_module.name)
|
||||
return ModuleCompiled, config
|
||||
|
||||
|
||||
@@ -361,7 +366,7 @@ def export_iree_module_to_vmfb(
|
||||
def export_module_to_mlir_file(module, frontend, directory: str):
|
||||
# TODO: write proper documentation.
|
||||
mlir_str = module
|
||||
if frontend in ["tensorflow", "tf", "mhlo", "tflite"]:
|
||||
if frontend in ["tensorflow", "tf", "mhlo", "stablehlo", "tflite"]:
|
||||
mlir_str = module.decode("utf-8")
|
||||
elif frontend in ["pytorch", "torch"]:
|
||||
mlir_str = module.operation.get_asm()
|
||||
|
||||
@@ -117,7 +117,8 @@ def get_extensions(triple):
|
||||
|
||||
if get_vendor(triple) == "NVIDIA" or arch == "rdna3":
|
||||
ext.append("VK_NV_cooperative_matrix")
|
||||
|
||||
if get_vendor(triple) == ["NVIDIA", "AMD", "Intel"]:
|
||||
ext.append("VK_KHR_shader_integer_dot_product")
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
|
||||
@@ -133,9 +134,9 @@ def get_vendor(triple):
|
||||
return "Apple"
|
||||
if arch in ["arc", "UHD"]:
|
||||
return "Intel"
|
||||
if arch in ["turing", "ampere"]:
|
||||
if arch in ["turing", "ampere", "pascal"]:
|
||||
return "NVIDIA"
|
||||
if arch == "ardeno":
|
||||
if arch == "adreno":
|
||||
return "Qualcomm"
|
||||
if arch == "cpu":
|
||||
if product == "swiftshader":
|
||||
@@ -151,7 +152,7 @@ def get_device_type(triple):
|
||||
return "Unknown"
|
||||
if arch == "cpu":
|
||||
return "CPU"
|
||||
if arch in ["turing", "ampere", "arc"]:
|
||||
if arch in ["turing", "ampere", "arc", "pascal"]:
|
||||
return "DiscreteGPU"
|
||||
if arch in ["rdna1", "rdna2", "rdna3", "rgcn3", "rgcn5"]:
|
||||
if product == "ivega10":
|
||||
@@ -228,6 +229,7 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
@@ -236,12 +238,12 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
if arch == "rdna3":
|
||||
# TODO: Get scope value
|
||||
cap["coopmatCases"] = [
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>"
|
||||
]
|
||||
|
||||
if product == "rx5700xt":
|
||||
cap["storagePushConstant16"] = False
|
||||
cap["storagePushConstant8"] = False
|
||||
@@ -274,7 +276,7 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storagePushConstant16"] = False
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
@@ -305,6 +307,7 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["shaderIntegerDotProduct"] = False
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
@@ -367,6 +370,7 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = False
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
@@ -389,6 +393,40 @@ def get_vulkan_target_capabilities(triple):
|
||||
"ShuffleRelative",
|
||||
]
|
||||
|
||||
elif arch in ["pascal"]:
|
||||
cap["maxComputeSharedMemorySize"] = 49152
|
||||
cap["maxComputeWorkGroupInvocations"] = 1536
|
||||
cap["maxComputeWorkGroupSize"] = [1536, 1024, 64]
|
||||
|
||||
cap["subgroupSize"] = 32
|
||||
cap["minSubgroupSize"] = 32
|
||||
cap["maxSubgroupSize"] = 32
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Clustered",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
cap["shaderFloat16"] = False
|
||||
cap["shaderFloat64"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = True
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch in ["ampere", "turing"]:
|
||||
cap["maxComputeSharedMemorySize"] = 49152
|
||||
cap["maxComputeWorkGroupInvocations"] = 1024
|
||||
@@ -413,6 +451,7 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
|
||||
@@ -21,7 +21,7 @@ from sys import platform
|
||||
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
|
||||
|
||||
|
||||
def get_vulkan_device_name():
|
||||
def get_vulkan_device_name(device_num=0):
|
||||
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
|
||||
vulkaninfo_dump = vulkaninfo_dump.split(linesep)
|
||||
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
|
||||
@@ -31,8 +31,8 @@ def get_vulkan_device_name():
|
||||
print("Following devices found:")
|
||||
for i, dname in enumerate(vulkaninfo_list):
|
||||
print(f"{i}. {dname}")
|
||||
print(f"Choosing first one: {vulkaninfo_list[0]}")
|
||||
return vulkaninfo_list[0]
|
||||
print(f"Choosing device: {vulkaninfo_list[device_num]}")
|
||||
return vulkaninfo_list[device_num]
|
||||
|
||||
|
||||
def get_os_name():
|
||||
@@ -114,19 +114,24 @@ def get_vulkan_target_triple(device_name):
|
||||
# Intel Targets
|
||||
elif any(x in device_name for x in ("A770", "A750")):
|
||||
triple = f"arc-770-{system_os}"
|
||||
|
||||
# Adreno Targets
|
||||
elif all(x in device_name for x in ("Adreno", "740")):
|
||||
triple = f"adreno-a740-{system_os}"
|
||||
|
||||
else:
|
||||
triple = None
|
||||
return triple
|
||||
|
||||
|
||||
def get_vulkan_triple_flag(device_name="", extra_args=[]):
|
||||
def get_vulkan_triple_flag(device_name="", device_num=0, extra_args=[]):
|
||||
for flag in extra_args:
|
||||
if "-iree-vulkan-target-triple=" in flag:
|
||||
print(f"Using target triple {flag.split('=')[1]}")
|
||||
return None
|
||||
|
||||
if device_name == "" or device_name == [] or device_name is None:
|
||||
vulkan_device = get_vulkan_device_name()
|
||||
vulkan_device = get_vulkan_device_name(device_num=device_num)
|
||||
else:
|
||||
vulkan_device = device_name
|
||||
triple = get_vulkan_target_triple(vulkan_device)
|
||||
@@ -144,7 +149,7 @@ def get_vulkan_triple_flag(device_name="", extra_args=[]):
|
||||
return None
|
||||
|
||||
|
||||
def get_iree_vulkan_args(extra_args=[]):
|
||||
def get_iree_vulkan_args(device_num=0, extra_args=[]):
|
||||
# res_vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
|
||||
|
||||
res_vulkan_flag = []
|
||||
@@ -156,7 +161,9 @@ def get_iree_vulkan_args(extra_args=[]):
|
||||
break
|
||||
|
||||
if vulkan_triple_flag is None:
|
||||
vulkan_triple_flag = get_vulkan_triple_flag(extra_args=extra_args)
|
||||
vulkan_triple_flag = get_vulkan_triple_flag(
|
||||
device_num=device_num, extra_args=extra_args
|
||||
)
|
||||
|
||||
if vulkan_triple_flag is not None:
|
||||
vulkan_target_env = get_vulkan_target_env_flag(vulkan_triple_flag)
|
||||
|
||||
@@ -30,8 +30,8 @@ import os
|
||||
import sys
|
||||
from typing import Dict, List
|
||||
|
||||
import iree.compiler._mlir_libs
|
||||
from iree.compiler import ir
|
||||
from iree.compiler.transforms import ireec as ireec_trans
|
||||
|
||||
|
||||
def model_annotation(
|
||||
@@ -311,11 +311,18 @@ def add_attributes(op: ir.Operation, config: List[Dict]):
|
||||
split_k = config["split_k"]
|
||||
elif "SPIRV" in config["pipeline"]:
|
||||
pipeline = config["pipeline"]
|
||||
tile_sizes = [
|
||||
config["work_group_tile_sizes"],
|
||||
config["parallel_tile_sizes"],
|
||||
config["reduction_tile_sizes"],
|
||||
]
|
||||
if pipeline == "SPIRVMatmulPromoteVectorize":
|
||||
tile_sizes = [
|
||||
config["work_group_tile_sizes"]
|
||||
+ [config["reduction_tile_sizes"][-1]],
|
||||
]
|
||||
else:
|
||||
tile_sizes = [
|
||||
config["work_group_tile_sizes"],
|
||||
config["parallel_tile_sizes"],
|
||||
config["reduction_tile_sizes"],
|
||||
]
|
||||
|
||||
workgroup_size = config["work_group_sizes"]
|
||||
if "vector_tile_sizes" in config.keys():
|
||||
tile_sizes += [config["vector_tile_sizes"]]
|
||||
@@ -409,7 +416,6 @@ def shape_list_to_string(input):
|
||||
|
||||
def create_context() -> ir.Context:
|
||||
context = ir.Context()
|
||||
ireec_trans.register_all_dialects(context)
|
||||
context.allow_unregistered_dialects = True
|
||||
return context
|
||||
|
||||
|
||||
@@ -61,6 +61,8 @@ def download_public_file(
|
||||
continue
|
||||
|
||||
destination_filename = os.path.join(destination_folder_name, blob_name)
|
||||
if os.path.isdir(destination_filename):
|
||||
continue
|
||||
with open(destination_filename, "wb") as f:
|
||||
with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
|
||||
storage_client.download_blob_to_file(blob, file_obj)
|
||||
@@ -196,7 +198,7 @@ def download_model(
|
||||
tank_url=None,
|
||||
frontend=None,
|
||||
tuned=None,
|
||||
import_args=None,
|
||||
import_args={"batch_size": 1},
|
||||
):
|
||||
model_name = model_name.replace("/", "_")
|
||||
dyn_str = "_dynamic" if dynamic else ""
|
||||
@@ -210,6 +212,9 @@ def download_model(
|
||||
+ "_BS"
|
||||
+ str(import_args["batch_size"])
|
||||
)
|
||||
elif any(model in model_name for model in ["clip", "unet", "vae"]):
|
||||
# TODO(Ean Garvey): rework extended naming such that device is only included in model_name after .vmfb compilation.
|
||||
model_dir_name = model_name
|
||||
else:
|
||||
model_dir_name = model_name + "_" + frontend
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
@@ -270,6 +275,9 @@ def download_model(
|
||||
tuned_str = "" if tuned is None else "_" + tuned
|
||||
suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
|
||||
filename = os.path.join(model_dir, model_name + suffix)
|
||||
print(
|
||||
f"Verifying that model artifacts were downloaded successfully to {filename}..."
|
||||
)
|
||||
if not os.path.exists(filename):
|
||||
from tank.generate_sharktank import gen_shark_files
|
||||
|
||||
|
||||
38
shark/shark_generate_model_config.py
Normal file
38
shark/shark_generate_model_config.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class GenerateConfigFile:
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
num_sharding_stages: int,
|
||||
sharding_stages_id: list[str] = None,
|
||||
):
|
||||
self.model = model
|
||||
self.num_sharding_stages = num_sharding_stages
|
||||
self.sharding_stages_id = sharding_stages_id
|
||||
assert self.num_sharding_stages == len(
|
||||
self.sharding_stages_id
|
||||
), "Number of sharding stages should be equal to the list of their ID"
|
||||
|
||||
def generate_json(self):
|
||||
model_dictionary = dict()
|
||||
|
||||
for name, m in self.model.named_modules():
|
||||
if name == "":
|
||||
continue
|
||||
|
||||
# Remove non-leaf nodes from the config as they aren't an operation
|
||||
substring_before_final_period = name.split(".")[:-1]
|
||||
substring_before_final_period = ".".join(
|
||||
substring_before_final_period
|
||||
)
|
||||
if substring_before_final_period in model_dictionary:
|
||||
del model_dictionary[substring_before_final_period]
|
||||
|
||||
layer_dict = {n: "None" for n in self.sharding_stages_id}
|
||||
model_dictionary[name] = layer_dict
|
||||
|
||||
with open("model_config.json", "w") as outfile:
|
||||
json.dump(model_dictionary, outfile)
|
||||
@@ -81,7 +81,7 @@ class SharkImporter:
|
||||
|
||||
# NOTE: The default function for torch is "forward" and tf-lite is "main".
|
||||
|
||||
def _torch_mlir(self, is_dynamic, tracing_required):
|
||||
def _torch_mlir(self, is_dynamic, tracing_required, mlir_type):
|
||||
from shark.torch_mlir_utils import get_torch_mlir_module
|
||||
|
||||
return get_torch_mlir_module(
|
||||
@@ -90,6 +90,7 @@ class SharkImporter:
|
||||
is_dynamic,
|
||||
tracing_required,
|
||||
self.return_str,
|
||||
mlir_type,
|
||||
)
|
||||
|
||||
def _tf_mlir(self, func_name, save_dir="."):
|
||||
@@ -120,6 +121,7 @@ class SharkImporter:
|
||||
tracing_required=False,
|
||||
func_name="forward",
|
||||
save_dir="./shark_tmp/",
|
||||
mlir_type="linalg",
|
||||
):
|
||||
if self.frontend in ["torch", "pytorch"]:
|
||||
if self.inputs == None:
|
||||
@@ -127,7 +129,10 @@ class SharkImporter:
|
||||
"Please pass in the inputs, the inputs are required to determine the shape of the mlir_module"
|
||||
)
|
||||
sys.exit(1)
|
||||
return self._torch_mlir(is_dynamic, tracing_required), func_name
|
||||
return (
|
||||
self._torch_mlir(is_dynamic, tracing_required, mlir_type),
|
||||
func_name,
|
||||
)
|
||||
if self.frontend in ["tf", "tensorflow"]:
|
||||
return self._tf_mlir(func_name, save_dir), func_name
|
||||
if self.frontend in ["tflite", "tf-lite"]:
|
||||
@@ -143,14 +148,23 @@ class SharkImporter:
|
||||
|
||||
# Saves `function_name.npy`, `inputs.npz`, `golden_out.npz` and `model_name.mlir` in the directory `dir`.
|
||||
def save_data(
|
||||
self, dir, model_name, mlir_data, func_name, inputs, outputs
|
||||
self,
|
||||
dir,
|
||||
model_name,
|
||||
mlir_data,
|
||||
func_name,
|
||||
inputs,
|
||||
outputs,
|
||||
mlir_type="linalg",
|
||||
):
|
||||
import numpy as np
|
||||
|
||||
inputs_name = "inputs.npz"
|
||||
outputs_name = "golden_out.npz"
|
||||
func_file_name = "function_name"
|
||||
model_name_mlir = model_name + "_" + self.frontend + ".mlir"
|
||||
model_name_mlir = (
|
||||
model_name + "_" + self.frontend + "_" + mlir_type + ".mlir"
|
||||
)
|
||||
print(f"saving {model_name_mlir} to {dir}")
|
||||
try:
|
||||
inputs = [x.cpu().detach() for x in inputs]
|
||||
@@ -186,19 +200,23 @@ class SharkImporter:
|
||||
dir=tempfile.gettempdir(),
|
||||
model_name="model",
|
||||
golden_values=None,
|
||||
mlir_type="linalg",
|
||||
):
|
||||
if self.inputs == None:
|
||||
print(
|
||||
f"There is no input provided: {self.inputs}, please provide inputs or simply run import_mlir."
|
||||
)
|
||||
sys.exit(1)
|
||||
model_name_mlir = model_name + "_" + self.frontend + ".mlir"
|
||||
model_name_mlir = (
|
||||
model_name + "_" + self.frontend + "_" + mlir_type + ".mlir"
|
||||
)
|
||||
artifact_path = os.path.join(dir, model_name_mlir)
|
||||
imported_mlir = self.import_mlir(
|
||||
is_dynamic,
|
||||
tracing_required,
|
||||
func_name,
|
||||
save_dir=artifact_path,
|
||||
mlir_type=mlir_type,
|
||||
)
|
||||
# TODO: Make sure that any generic function name is accepted. Currently takes in the default function names.
|
||||
# TODO: Check for multiple outputs.
|
||||
@@ -224,6 +242,7 @@ class SharkImporter:
|
||||
imported_mlir[1],
|
||||
self.inputs,
|
||||
golden_out,
|
||||
mlir_type,
|
||||
)
|
||||
return (
|
||||
imported_mlir,
|
||||
@@ -293,6 +312,47 @@ def get_f16_inputs(inputs, is_f16, f16_input_mask):
|
||||
return tuple(f16_masked_inputs)
|
||||
|
||||
|
||||
# Upcasts the block/list of ops.
|
||||
def add_upcast(fx_g):
|
||||
import torch
|
||||
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.target in [torch.ops.aten.mul]:
|
||||
# This is a very strict check.
|
||||
if hasattr(node.args[1], "target"):
|
||||
if (
|
||||
node.args[1].target in [torch.ops.aten.rsqrt]
|
||||
and node.args[1].args[0].target in [torch.ops.aten.add]
|
||||
and node.args[1].args[0].args[0].target
|
||||
in [torch.ops.aten.mean]
|
||||
and node.args[1].args[0].args[0].args[0].target
|
||||
in [torch.ops.aten.pow]
|
||||
):
|
||||
print("found an upcasting block let's upcast it.")
|
||||
pow_node = node.args[1].args[0].args[0].args[0]
|
||||
mul_node = node
|
||||
with fx_g.graph.inserting_before(pow_node):
|
||||
lhs = pow_node.args[0]
|
||||
upcast_lhs = fx_g.graph.call_function(
|
||||
torch.ops.aten._to_copy,
|
||||
args=(lhs,),
|
||||
kwargs={"dtype": torch.float32},
|
||||
)
|
||||
pow_node.args = (upcast_lhs, pow_node.args[1])
|
||||
with fx_g.graph.inserting_before(mul_node):
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten._to_copy,
|
||||
args=(mul_node,),
|
||||
kwargs={"dtype": torch.float16},
|
||||
)
|
||||
mul_node.append(new_node)
|
||||
mul_node.replace_all_uses_with(new_node)
|
||||
new_node.args = (mul_node,)
|
||||
new_node.kwargs = {"dtype": torch.float16}
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
|
||||
def transform_fx(fx_g):
|
||||
import torch
|
||||
|
||||
@@ -301,14 +361,49 @@ def transform_fx(fx_g):
|
||||
"device": torch.device(type="cpu"),
|
||||
"pin_memory": False,
|
||||
}
|
||||
kwargs_dict1 = {
|
||||
"dtype": torch.float16,
|
||||
}
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
torch.ops.aten.arange,
|
||||
torch.ops.aten.empty,
|
||||
torch.ops.aten.zeros,
|
||||
torch.ops.aten.zeros_like,
|
||||
]:
|
||||
node.kwargs = kwargs_dict
|
||||
if node.kwargs.get("dtype") == torch.float32:
|
||||
node.kwargs = kwargs_dict
|
||||
|
||||
# Vicuna
|
||||
if node.target in [
|
||||
torch.ops.aten._to_copy,
|
||||
]:
|
||||
if node.kwargs.get("dtype") == torch.float32:
|
||||
node.kwargs = kwargs_dict1
|
||||
|
||||
if node.target in [
|
||||
torch.ops.aten.masked_fill,
|
||||
]:
|
||||
if node.args[2] > torch.finfo(torch.half).max:
|
||||
max_val = torch.finfo(torch.half).max
|
||||
node.args = (node.args[0], node.args[1], max_val)
|
||||
elif node.args[2] < torch.finfo(torch.half).min:
|
||||
min_val = torch.finfo(torch.half).min
|
||||
node.args = (node.args[0], node.args[1], min_val)
|
||||
|
||||
if node.target in [
|
||||
torch.ops.aten.full,
|
||||
]:
|
||||
if node.args[1] > torch.finfo(torch.half).max:
|
||||
max_val = torch.finfo(torch.half).max
|
||||
node.args = (node.args[0], max_val)
|
||||
node.kwargs = kwargs_dict
|
||||
elif node.args[1] < torch.finfo(torch.half).min:
|
||||
min_val = torch.finfo(torch.half).min
|
||||
node.args = (node.args[0], min_val)
|
||||
node.kwargs = kwargs_dict
|
||||
|
||||
# Inputs and outputs of aten.var.mean should be upcasted to fp32.
|
||||
if node.target in [torch.ops.aten.var_mean]:
|
||||
with fx_g.graph.inserting_before(node):
|
||||
@@ -318,6 +413,7 @@ def transform_fx(fx_g):
|
||||
kwargs={},
|
||||
)
|
||||
node.args = (new_node, node.args[1])
|
||||
|
||||
if node.name.startswith("getitem"):
|
||||
with fx_g.graph.inserting_before(node):
|
||||
if node.args[0].target in [torch.ops.aten.var_mean]:
|
||||
@@ -330,6 +426,7 @@ def transform_fx(fx_g):
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
new_node.kwargs = {"dtype": torch.float16}
|
||||
|
||||
# aten.empty should be filled with zeros.
|
||||
if node.target in [torch.ops.aten.empty]:
|
||||
with fx_g.graph.inserting_after(node):
|
||||
@@ -341,6 +438,14 @@ def transform_fx(fx_g):
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
|
||||
# Required for cuda debugging.
|
||||
# for node in fx_g.graph.nodes:
|
||||
# if node.op == "call_function":
|
||||
# if node.kwargs.get("device") == torch.device(type="cpu"):
|
||||
# new_kwargs = node.kwargs.copy()
|
||||
# new_kwargs["device"] = torch.device(type="cuda")
|
||||
# node.kwargs = new_kwargs
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
|
||||
@@ -392,6 +497,9 @@ def import_with_fx(
|
||||
return_str=False,
|
||||
save_dir=tempfile.gettempdir(),
|
||||
model_name="model",
|
||||
mlir_type="linalg",
|
||||
is_dynamic=False,
|
||||
tracing_required=False,
|
||||
):
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
@@ -418,6 +526,8 @@ def import_with_fx(
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
torch.ops.aten.native_layer_norm,
|
||||
torch.ops.aten.masked_fill.Tensor,
|
||||
torch.ops.aten.masked_fill.Scalar,
|
||||
]
|
||||
),
|
||||
)(*inputs)
|
||||
@@ -441,6 +551,8 @@ def import_with_fx(
|
||||
if is_f16:
|
||||
fx_g = fx_g.half()
|
||||
transform_fx(fx_g)
|
||||
# TODO: Have to make it more generic.
|
||||
add_upcast(fx_g)
|
||||
fx_g.recompile()
|
||||
|
||||
if training:
|
||||
@@ -448,6 +560,9 @@ def import_with_fx(
|
||||
inputs = flatten_training_input(inputs)
|
||||
|
||||
ts_graph = torch.jit.script(fx_g)
|
||||
if mlir_type == "torchscript":
|
||||
return ts_graph
|
||||
|
||||
inputs = get_f16_inputs(inputs, is_f16, f16_input_mask)
|
||||
mlir_importer = SharkImporter(
|
||||
ts_graph,
|
||||
@@ -458,7 +573,12 @@ def import_with_fx(
|
||||
|
||||
if debug: # and not is_f16:
|
||||
(mlir_module, func_name), _, _ = mlir_importer.import_debug(
|
||||
dir=save_dir, model_name=model_name, golden_values=golden_values
|
||||
dir=save_dir,
|
||||
model_name=model_name,
|
||||
golden_values=golden_values,
|
||||
mlir_type=mlir_type,
|
||||
is_dynamic=is_dynamic,
|
||||
tracing_required=tracing_required,
|
||||
)
|
||||
return mlir_module, func_name
|
||||
|
||||
|
||||
@@ -25,7 +25,14 @@ import sys
|
||||
|
||||
|
||||
# supported dialects by the shark-runtime.
|
||||
supported_dialects = {"linalg", "mhlo", "tosa", "tf-lite", "tm_tensor"}
|
||||
supported_dialects = {
|
||||
"linalg",
|
||||
"auto",
|
||||
"stablehlo",
|
||||
"tosa",
|
||||
"tf-lite",
|
||||
"tm_tensor",
|
||||
}
|
||||
|
||||
|
||||
class SharkRunner:
|
||||
|
||||
@@ -59,6 +59,7 @@ class SharkTrainer:
|
||||
"torch",
|
||||
"tensorflow",
|
||||
"tf",
|
||||
"stablehlo",
|
||||
"mhlo",
|
||||
"linalg",
|
||||
"tosa",
|
||||
@@ -84,7 +85,7 @@ class SharkTrainer:
|
||||
"tm_tensor",
|
||||
extra_args=extra_args,
|
||||
)
|
||||
elif self.frontend in ["tensorflow", "tf", "mhlo"]:
|
||||
elif self.frontend in ["tensorflow", "tf", "mhlo", "stablehlo"]:
|
||||
self.shark_runner = SharkRunner(
|
||||
self.model,
|
||||
self.input,
|
||||
|
||||
@@ -19,6 +19,12 @@ import tempfile
|
||||
from shark.parser import shark_args
|
||||
import io
|
||||
|
||||
mlir_type_mapping_dict = {
|
||||
"linalg": torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
"stablehlo": torch_mlir.OutputType.STABLEHLO,
|
||||
"tosa": torch_mlir.OutputType.TOSA,
|
||||
}
|
||||
|
||||
|
||||
def get_module_name_for_asm_dump(module):
|
||||
"""Gets a name suitable for an assembly dump.
|
||||
@@ -57,6 +63,7 @@ def get_torch_mlir_module(
|
||||
dynamic: bool,
|
||||
jit_trace: bool,
|
||||
return_str: bool = False,
|
||||
mlir_type: str = "linalg",
|
||||
):
|
||||
"""Get the MLIR's linalg-on-tensors module from the torchscipt module."""
|
||||
ignore_traced_shapes = False
|
||||
@@ -70,10 +77,11 @@ def get_torch_mlir_module(
|
||||
mlir_module = torch_mlir.compile(
|
||||
module,
|
||||
input,
|
||||
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
output_type=mlir_type_mapping_dict[mlir_type],
|
||||
use_tracing=jit_trace,
|
||||
ignore_traced_shapes=ignore_traced_shapes,
|
||||
)
|
||||
|
||||
if return_str:
|
||||
return mlir_module.operation.get_asm()
|
||||
bytecode_stream = io.BytesIO()
|
||||
|
||||
@@ -1,46 +1,47 @@
|
||||
resnet50,mhlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
albert-base-v2,mhlo,tf,1e-2,1e-2,default,None,False,False,False,"",""
|
||||
roberta-base,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
|
||||
bert-base-uncased,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"","enabled_windows"
|
||||
camembert-base,mhlo,tf,1e-2,1e-3,default,None,True,True,True,"",""
|
||||
dbmdz/convbert-base-turkish-cased,mhlo,tf,1e-2,1e-3,default,nhcw-nhwc,True,True,False,"https://github.com/iree-org/iree/issues/9971",""
|
||||
distilbert-base-uncased,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
facebook/convnext-tiny-224,mhlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,True,True,False,"https://github.com/nod-ai/SHARK/issues/311 & https://github.com/nod-ai/SHARK/issues/342","macos"
|
||||
funnel-transformer/small,mhlo,tf,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/201",""
|
||||
google/electra-small-discriminator,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
google/mobilebert-uncased,mhlo,tf,1e-2,1e-3,default,None,True,False,False,"Fails during iree-compile",""
|
||||
google/vit-base-patch16-224,mhlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,False,False,False,"",""
|
||||
microsoft/MiniLM-L12-H384-uncased,mhlo,tf,1e-2,1e-3,tf_hf,None,True,False,False,"Fails during iree-compile.",""
|
||||
microsoft/layoutlm-base-uncased,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
microsoft/mpnet-base,mhlo,tf,1e-2,1e-2,default,None,True,True,True,"",""
|
||||
resnet50,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
albert-base-v2,stablehlo,tf,1e-2,1e-2,default,None,False,False,False,"",""
|
||||
roberta-base,stablehlo,tf,1e-02,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
|
||||
bert-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","enabled_windows"
|
||||
camembert-base,stablehlo,tf,1e-2,1e-3,default,None,True,True,True,"",""
|
||||
dbmdz/convbert-base-turkish-cased,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,True,True,False,"https://github.com/iree-org/iree/issues/9971",""
|
||||
distilbert-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
facebook/convnext-tiny-224,stablehlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,True,True,False,"https://github.com/nod-ai/SHARK/issues/311 & https://github.com/nod-ai/SHARK/issues/342","macos"
|
||||
funnel-transformer/small,stablehlo,tf,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/201",""
|
||||
google/electra-small-discriminator,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
google/mobilebert-uncased,stablehlo,tf,1e-2,1e-3,default,None,True,False,False,"Fails during iree-compile","macos"
|
||||
google/vit-base-patch16-224,stablehlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,False,False,False,"",""
|
||||
microsoft/MiniLM-L12-H384-uncased,stablehlo,tf,1e-2,1e-3,tf_hf,None,True,False,False,"Fails during iree-compile.",""
|
||||
microsoft/layoutlm-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
microsoft/mpnet-base,stablehlo,tf,1e-2,1e-2,default,None,True,True,True,"",""
|
||||
albert-base-v2,linalg,torch,1e-2,1e-3,default,None,True,True,True,"issue with aten.tanh in torch-mlir",""
|
||||
alexnet,linalg,torch,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/879",""
|
||||
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,False,True,"",""
|
||||
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
bert-large-uncased,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,True,True,"",""
|
||||
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
bert-large-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails during iree-compile.",""
|
||||
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311",""
|
||||
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
|
||||
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"https://github.com/nod-ai/SHARK/issues/344",""
|
||||
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/388","macos"
|
||||
nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/343","macos"
|
||||
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"https://github.com/nod-ai/SHARK/issues/344","macos"
|
||||
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,True,True,"https://github.com/nod-ai/SHARK/issues/388, https://github.com/nod-ai/SHARK/issues/1487","macos"
|
||||
nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,True,"https://github.com/nod-ai/SHARK/issues/343,https://github.com/nod-ai/SHARK/issues/1487","macos"
|
||||
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
|
||||
resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,False,"","macos"
|
||||
resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,False,True,"",""
|
||||
squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
|
||||
efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
efficientnet-v2-s,stablehlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
|
||||
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,False,"https://github.com/nod-ai/SHARK/issues/1243",""
|
||||
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
|
||||
efficientnet_b0,mhlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
|
||||
efficientnet_b7,mhlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
|
||||
gpt2,mhlo,tf,1e-2,1e-3,default,None,True,False,False,"",""
|
||||
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.",""
|
||||
t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported",""
|
||||
t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
|
||||
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
|
||||
efficientnet_b0,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
|
||||
efficientnet_b7,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
|
||||
gpt2,stablehlo,tf,1e-2,1e-3,default,None,True,False,False,"","macos"
|
||||
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.","macos"
|
||||
t5-base,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"
|
||||
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported","macos"
|
||||
t5-large,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"
|
||||
stabilityai/stable-diffusion-2-1-base,linalg,torch,1e-3,1e-3,default,None,True,False,False,"","macos"
|
||||
|
||||
|
@@ -75,7 +75,7 @@ if __name__ == "__main__":
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
extra_args=args,
|
||||
input_type="mhlo",
|
||||
input_type="auto",
|
||||
)
|
||||
# flatbuffer_blob = compile_str(compiler_module, target_backends=["dylib-llvm-aot"])
|
||||
|
||||
|
||||
@@ -153,7 +153,7 @@ if __name__ == "__main__":
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
extra_args=args,
|
||||
input_type="mhlo",
|
||||
input_type="auto",
|
||||
)
|
||||
|
||||
# Save module as MLIR file in a directory
|
||||
|
||||
@@ -96,7 +96,7 @@ if __name__ == "__main__":
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
extra_args=args,
|
||||
input_type="mhlo",
|
||||
input_type="auto",
|
||||
)
|
||||
# flatbuffer_blob = compile_str(compiler_module, target_backends=["dylib-llvm-aot"])
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ if __name__ == "__main__":
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
extra_args=args,
|
||||
input_type="mhlo",
|
||||
input_type="auto",
|
||||
)
|
||||
# flatbuffer_blob = compile_str(compiler_module, target_backends=["dylib-llvm-aot"])
|
||||
|
||||
|
||||
143
tank/examples/opt/opt_causallm.py
Normal file
143
tank/examples/opt/opt_causallm.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from shark_opt_wrapper import OPTForCausalLMModel
|
||||
from shark.iree_utils._common import (
|
||||
check_device_drivers,
|
||||
device_driver_info,
|
||||
)
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from transformers import AutoTokenizer, OPTForCausalLM
|
||||
|
||||
OPT_MODEL = "opt-1.3b"
|
||||
OPT_FS_NAME = "opt-1_3b"
|
||||
MAX_SEQUENCE_LENGTH = 30
|
||||
MAX_NEW_TOKENS = 20
|
||||
|
||||
|
||||
def create_module(model_name, tokenizer, device):
|
||||
opt_base_model = OPTForCausalLM.from_pretrained("facebook/" + model_name)
|
||||
opt_base_model.eval()
|
||||
opt_model = OPTForCausalLMModel(opt_base_model)
|
||||
encoded_inputs = tokenizer(
|
||||
"What is the meaning of life?",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = (
|
||||
encoded_inputs["input_ids"],
|
||||
encoded_inputs["attention_mask"],
|
||||
)
|
||||
# np.save("model_inputs_0.npy", inputs[0])
|
||||
# np.save("model_inputs_1.npy", inputs[1])
|
||||
|
||||
mlir_path = f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
|
||||
if os.path.isfile(mlir_path):
|
||||
with open(mlir_path, "r") as f:
|
||||
model_mlir = f.read()
|
||||
print(f"Loaded .mlir from {mlir_path}")
|
||||
else:
|
||||
(model_mlir, func_name) = import_with_fx(
|
||||
model=opt_model,
|
||||
inputs=inputs,
|
||||
is_f16=False,
|
||||
model_name=OPT_FS_NAME,
|
||||
return_str=True,
|
||||
)
|
||||
with open(mlir_path, "w") as f:
|
||||
f.write(model_mlir)
|
||||
print(f"Saved mlir at {mlir_path}")
|
||||
|
||||
shark_module = SharkInference(
|
||||
model_mlir,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
is_benchmark=False,
|
||||
)
|
||||
|
||||
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
|
||||
shark_module.save_module(module_name=vmfb_name)
|
||||
vmfb_path = vmfb_name + ".vmfb"
|
||||
return vmfb_path
|
||||
|
||||
|
||||
def shouldStop(tokens):
|
||||
stop_ids = [50278, 50279, 50277, 0]
|
||||
for stop_id in stop_ids:
|
||||
if tokens[0][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def generate_new_token(shark_model, tokenizer, new_text):
|
||||
model_inputs = tokenizer(
|
||||
new_text,
|
||||
padding="max_length",
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = (
|
||||
model_inputs["input_ids"],
|
||||
model_inputs["attention_mask"],
|
||||
)
|
||||
sum_attentionmask = torch.sum(model_inputs.attention_mask)
|
||||
output = shark_model("forward", inputs)
|
||||
output = torch.FloatTensor(output[0])
|
||||
next_toks = torch.topk(output, 1)
|
||||
stop_generation = False
|
||||
if shouldStop(next_toks.indices):
|
||||
stop_generation = True
|
||||
new_token = next_toks.indices[int(sum_attentionmask) - 1]
|
||||
detok = tokenizer.decode(
|
||||
new_token,
|
||||
skip_special_tokens=False,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)
|
||||
ret_dict = {
|
||||
"new_token": new_token,
|
||||
"detok": detok,
|
||||
"stop_generation": stop_generation,
|
||||
}
|
||||
return ret_dict
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"facebook/" + OPT_MODEL, use_fast=False
|
||||
)
|
||||
vmfb_path = (
|
||||
f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_cpu-sync.vmfb"
|
||||
)
|
||||
opt_shark_module = SharkInference(mlir_module=None, device="cpu-sync")
|
||||
if os.path.isfile(vmfb_path):
|
||||
opt_shark_module.load_module(vmfb_path)
|
||||
else:
|
||||
vmfb_path = create_module(OPT_MODEL, tokenizer, "cpu-sync")
|
||||
opt_shark_module.load_module(vmfb_path)
|
||||
while True:
|
||||
try:
|
||||
new_text = input("Give me a sentence to complete:")
|
||||
new_text_init = new_text
|
||||
words_list = []
|
||||
|
||||
for i in range(MAX_NEW_TOKENS):
|
||||
generated_token_op = generate_new_token(
|
||||
opt_shark_module, tokenizer, new_text
|
||||
)
|
||||
detok = generated_token_op["detok"]
|
||||
stop_generation = generated_token_op["stop_generation"]
|
||||
if stop_generation:
|
||||
break
|
||||
print(detok, end="", flush=True)
|
||||
words_list.append(detok)
|
||||
if detok == "":
|
||||
break
|
||||
new_text = new_text + detok
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Exiting program.")
|
||||
break
|
||||
200
tank/examples/opt/opt_causallm_torch_test.py
Normal file
200
tank/examples/opt/opt_causallm_torch_test.py
Normal file
@@ -0,0 +1,200 @@
|
||||
import unittest
|
||||
import os
|
||||
import pytest
|
||||
import torch
|
||||
import numpy as np
|
||||
from shark_opt_wrapper import OPTForCausalLMModel
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from transformers import AutoTokenizer, OPTForCausalLM
|
||||
|
||||
OPT_MODEL = "facebook/opt-1.3b"
|
||||
OPT_FS_NAME = "opt-1_3b"
|
||||
OPT_MODEL_66B = "facebook/opt-66b"
|
||||
|
||||
|
||||
class OPTModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
benchmark=False,
|
||||
):
|
||||
self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device, model_name):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
||||
opt_model = OPTForCausalLM.from_pretrained(
|
||||
model_name, return_dict=False
|
||||
)
|
||||
opt_model.eval()
|
||||
|
||||
model_inputs = tokenizer(
|
||||
"The meaning of life is",
|
||||
padding="max_length",
|
||||
max_length=30,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = (
|
||||
model_inputs.data["input_ids"],
|
||||
model_inputs.data["attention_mask"],
|
||||
)
|
||||
act_out = opt_model(
|
||||
inputs[0], attention_mask=inputs[1], return_dict=False
|
||||
)[0]
|
||||
(
|
||||
mlir_module,
|
||||
func_name,
|
||||
) = import_with_fx(
|
||||
model=opt_model,
|
||||
inputs=inputs,
|
||||
is_f16=False,
|
||||
model_name=OPT_FS_NAME,
|
||||
)
|
||||
del opt_model
|
||||
opt_filename = f"./{OPT_FS_NAME}_causallm_30_torch_{device}"
|
||||
mlir_path = os.path.join(opt_filename, ".mlir")
|
||||
with open(mlir_path, "w") as f:
|
||||
f.write(mlir_module)
|
||||
print(f"Saved mlir at {mlir_path}")
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
is_benchmark=self.benchmark,
|
||||
)
|
||||
|
||||
shark_module.compile()
|
||||
results = shark_module("forward", inputs)
|
||||
print(
|
||||
"SHARK logits have shape: ",
|
||||
str(results[0].shape) + " : " + str(results[0]),
|
||||
)
|
||||
print(
|
||||
"PyTorch logits have shape: "
|
||||
+ str(act_out[0].shape)
|
||||
+ " : "
|
||||
+ str(act_out[0])
|
||||
)
|
||||
# exp_out = tokenizer.decode(act_out[0][0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
# shark_out = tokenizer.decode(results[0][0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
np.testing.assert_allclose(act_out[0].detach(), results[0])
|
||||
|
||||
if self.benchmark:
|
||||
shark_module.shark_runner.benchmark_all_csv(
|
||||
inputs,
|
||||
"opt",
|
||||
dynamic,
|
||||
device,
|
||||
"torch",
|
||||
)
|
||||
|
||||
|
||||
class OPTModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = OPTModuleTester(self)
|
||||
self.module_tester.save_mlir = False
|
||||
self.module_tester.save_vmfb = False
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
def test_1_3b_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)
|
||||
|
||||
def test_1_3b_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("cuda"), reason=device_driver_info("cuda")
|
||||
)
|
||||
def test_1_3b_static_cuda(self):
|
||||
dynamic = False
|
||||
device = "cuda"
|
||||
self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("cuda"), reason=device_driver_info("cuda")
|
||||
)
|
||||
def test_1_3b_dynamic_cuda(self):
|
||||
dynamic = True
|
||||
device = "cuda"
|
||||
self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_1_3b_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_1_3b_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device, OPT_MODEL)
|
||||
|
||||
# def test_66b_static_cpu(self):
|
||||
# dynamic = False
|
||||
# device = "cpu"
|
||||
# self.module_tester.create_and_check_module(
|
||||
# dynamic, device, OPT_MODEL_66B
|
||||
# )
|
||||
|
||||
# def test_66b_dynamic_cpu(self):
|
||||
# dynamic = True
|
||||
# device = "cpu"
|
||||
# self.module_tester.create_and_check_module(
|
||||
# dynamic, device, OPT_MODEL_66B
|
||||
# )
|
||||
|
||||
# @pytest.mark.skipif(
|
||||
# check_device_drivers("cuda"), reason=device_driver_info("cuda")
|
||||
# )
|
||||
# def test_66b_static_cuda(self):
|
||||
# dynamic = False
|
||||
# device = "cuda"
|
||||
# self.module_tester.create_and_check_module(
|
||||
# dynamic, device, OPT_MODEL_66B
|
||||
# )
|
||||
|
||||
# @pytest.mark.skipif(
|
||||
# check_device_drivers("cuda"), reason=device_driver_info("cuda")
|
||||
# )
|
||||
# def test_66b_dynamic_cuda(self):
|
||||
# dynamic = True
|
||||
# device = "cuda"
|
||||
# self.module_tester.create_and_check_module(
|
||||
# dynamic, device, OPT_MODEL_66B
|
||||
# )
|
||||
|
||||
# @pytest.mark.skipif(
|
||||
# check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
# )
|
||||
# def test_66b_static_vulkan(self):
|
||||
# dynamic = False
|
||||
# device = "vulkan"
|
||||
# self.module_tester.create_and_check_module(
|
||||
# dynamic, device, OPT_MODEL_66B
|
||||
# )
|
||||
|
||||
# @pytest.mark.skipif(
|
||||
# check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
# )
|
||||
# def test_66b_dynamic_vulkan(self):
|
||||
# dynamic = True
|
||||
# device = "vulkan"
|
||||
# self.module_tester.create_and_check_module(
|
||||
# dynamic, device, OPT_MODEL_66B
|
||||
# )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -2,11 +2,11 @@ import unittest
|
||||
|
||||
import pytest
|
||||
import torch_mlir
|
||||
from hacked_hf_opt import OPTModel
|
||||
from shark_hf_opt import OPTModel
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from tank.model_utils import compare_tensors
|
||||
from transformers import GPT2Tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
OPT_MODEL = "facebook/opt-350m"
|
||||
OPT_MODEL_66B = "facebook/opt-66b"
|
||||
@@ -20,7 +20,7 @@ class OPTModuleTester:
|
||||
self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device, model_name):
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
||||
# config = OPTConfig()
|
||||
# opt_model = OPTModel(config)
|
||||
opt_model = OPTModel.from_pretrained(model_name)
|
||||
@@ -56,13 +56,12 @@ class OPTModuleTester:
|
||||
|
||||
shark_module = SharkInference(
|
||||
model_mlir,
|
||||
func_name,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
is_benchmark=self.benchmark,
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input_ids, attention_mask))
|
||||
results = shark_module("forward", (input_ids, attention_mask))
|
||||
assert compare_tensors(act_out, results)
|
||||
|
||||
if self.benchmark:
|
||||
|
||||
47
tank/examples/opt/shark_hf_base_opt.py
Normal file
47
tank/examples/opt/shark_hf_base_opt.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import os
|
||||
import torch
|
||||
from transformers import AutoTokenizer, OPTForCausalLM
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark_opt_wrapper import OPTForCausalLMModel
|
||||
|
||||
model_name = "facebook/opt-1.3b"
|
||||
base_model = OPTForCausalLM.from_pretrained(model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
||||
|
||||
model = OPTForCausalLMModel(base_model)
|
||||
|
||||
prompt = "What is the meaning of life?"
|
||||
model_inputs = tokenizer(prompt, return_tensors="pt")
|
||||
inputs = (
|
||||
model_inputs["input_ids"],
|
||||
model_inputs["attention_mask"],
|
||||
)
|
||||
|
||||
(
|
||||
mlir_module,
|
||||
func_name,
|
||||
) = import_with_fx(
|
||||
model=model,
|
||||
inputs=inputs,
|
||||
is_f16=False,
|
||||
debug=True,
|
||||
model_name=model_name.split("/")[1],
|
||||
save_dir=".",
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device="cpu-sync",
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
shark_module.compile()
|
||||
# Generated logits.
|
||||
logits = shark_module("forward", inputs=inputs)
|
||||
print("SHARK module returns logits:")
|
||||
print(logits[0])
|
||||
|
||||
hf_logits = base_model.forward(inputs[0], inputs[1], return_dict=False)[0]
|
||||
|
||||
print("PyTorch baseline returns logits:")
|
||||
print(hf_logits)
|
||||
15
tank/examples/opt/shark_opt_wrapper.py
Normal file
15
tank/examples/opt/shark_opt_wrapper.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import torch
|
||||
|
||||
|
||||
class OPTForCausalLMModel(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
combine_input_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
output = self.model(**combine_input_dict)
|
||||
return output.logits
|
||||
@@ -279,7 +279,6 @@ class OPTAttention(nn.Module):
|
||||
)
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (
|
||||
bsz * self.num_heads,
|
||||
tgt_len,
|
||||
@@ -314,6 +313,7 @@ class OPTDecoderLayer(nn.Module):
|
||||
num_heads=config.num_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
is_decoder=True,
|
||||
bias=config.enable_bias,
|
||||
)
|
||||
self.do_layer_norm_before = config.do_layer_norm_before
|
||||
self.dropout = config.dropout
|
||||
@@ -321,10 +321,16 @@ class OPTDecoderLayer(nn.Module):
|
||||
|
||||
self.activation_dropout = config.activation_dropout
|
||||
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.self_attn_layer_norm = nn.LayerNorm(
|
||||
self.embed_dim,
|
||||
elementwise_affine=config.layer_norm_elementwise_affine,
|
||||
)
|
||||
self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim)
|
||||
self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim)
|
||||
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.final_layer_norm = nn.LayerNorm(
|
||||
self.embed_dim,
|
||||
elementwise_affine=config.layer_norm_elementwise_affine,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -450,7 +456,14 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
else:
|
||||
self.project_in = None
|
||||
|
||||
self.layer_norm = None
|
||||
if config.do_layer_norm_before and not config._remove_final_layer_norm:
|
||||
self.final_layer_norm = nn.LayerNorm(
|
||||
config.hidden_size,
|
||||
elementwise_affine=config.layer_norm_elementwise_affine,
|
||||
)
|
||||
else:
|
||||
self.final_layer_norm = None
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]
|
||||
)
|
||||
@@ -647,6 +660,9 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if self.final_layer_norm is not None:
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
|
||||
if self.project_out is not None:
|
||||
hidden_states = self.project_out(hidden_states)
|
||||
|
||||
@@ -832,7 +848,10 @@ class OPTForCausalLM(OPTPreTrainedModel):
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
if isinstance(outputs[1:], tuple):
|
||||
output = (logits,) + outputs[1:]
|
||||
else:
|
||||
output = (logits, outputs[1:])
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
88
tank/examples/opt/tune_opt.py
Normal file
88
tank/examples/opt/tune_opt.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from shark_tuner.codegen_tuner import SharkCodegenTuner
|
||||
from shark_tuner.iree_utils import (
|
||||
dump_dispatches,
|
||||
create_context,
|
||||
export_module_to_mlir_file,
|
||||
)
|
||||
from shark_tuner.model_annotation import model_annotation
|
||||
from shark_opt_wrapper import OPTForCausalLMModel
|
||||
from transformers import AutoTokenizer, OPTForCausalLM
|
||||
from shark.shark_importer import import_with_fx
|
||||
|
||||
NUM_ITERS = 400
|
||||
MODEL_NAME = "facebook/opt-1.3b"
|
||||
MODEL_FNAME = "opt-1_3b-causallm"
|
||||
|
||||
def load_mlir_module():
|
||||
hf_model = OPTForCausalLM.from_pretrained(MODEL_NAME)
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
|
||||
|
||||
opt_model = OPTForCausalLMModel(hf_model)
|
||||
|
||||
prompt = "What is the meaning of life?"
|
||||
model_inputs = tokenizer(prompt, return_tensors="pt")
|
||||
inputs = (
|
||||
model_inputs["input_ids"],
|
||||
model_inputs["attention_mask"],
|
||||
)
|
||||
|
||||
(
|
||||
mlir_module,
|
||||
func_name,
|
||||
) = import_with_fx(
|
||||
model=opt_model,
|
||||
inputs=inputs,
|
||||
is_f16=False,
|
||||
model_name=MODEL_NAME.split("/")[1],
|
||||
)
|
||||
return mlir_module, model_name
|
||||
|
||||
|
||||
def main():
|
||||
#mlir_module, model_name = load_mlir_module()
|
||||
|
||||
# Get device and device specific arguments
|
||||
device = "cpu"
|
||||
|
||||
# Dump model dispatches
|
||||
model_name = MODEL_NAME
|
||||
#generates_dir = "."
|
||||
#if not os.path.exists(generates_dir):
|
||||
# os.makedirs(generates_dir)
|
||||
#dump_mlir = generates_dir / "temp.mlir"
|
||||
dispatch_dir = f"./{MODEL_FNAME}_{device}_dispatches"
|
||||
#export_module_to_mlir_file(mlir_module, dump_mlir)
|
||||
#dump_dispatches(
|
||||
# dump_mlir,
|
||||
# device,
|
||||
# dispatch_dir,
|
||||
#)
|
||||
|
||||
# Tune each dispatch
|
||||
dtype = "f32"
|
||||
config_filename = f"{MODEL_FNAME}_{device}_configs.json"
|
||||
for f_path in os.listdir(dispatch_dir):
|
||||
if not f_path.endswith(".mlir"):
|
||||
continue
|
||||
|
||||
model_dir = os.path.join(dispatch_dir, f_path)
|
||||
|
||||
tuner = SharkCodegenTuner(
|
||||
model_dir,
|
||||
device,
|
||||
"random",
|
||||
NUM_ITERS,
|
||||
".",
|
||||
dtype,
|
||||
search_op="all",
|
||||
batch_size=1,
|
||||
config_filename=config_filename,
|
||||
use_dispatch=True,
|
||||
)
|
||||
tuner.tune()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -37,10 +37,12 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
|
||||
from tank.model_utils import (
|
||||
get_hf_model,
|
||||
get_hf_seq2seq_model,
|
||||
get_hf_causallm_model,
|
||||
get_vision_model,
|
||||
get_hf_img_cls_model,
|
||||
get_fp16_model,
|
||||
)
|
||||
from shark.shark_importer import import_with_fx
|
||||
|
||||
with open(torch_model_list) as csvfile:
|
||||
torch_reader = csv.reader(csvfile, delimiter=",")
|
||||
@@ -50,6 +52,8 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
|
||||
tracing_required = row[1]
|
||||
model_type = row[2]
|
||||
is_dynamic = row[3]
|
||||
mlir_type = row[4]
|
||||
is_decompose = row[5]
|
||||
|
||||
tracing_required = False if tracing_required == "False" else True
|
||||
is_dynamic = False if is_dynamic == "False" else True
|
||||
@@ -91,6 +95,10 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
|
||||
model, input, _ = get_hf_seq2seq_model(
|
||||
torch_model_name, import_args
|
||||
)
|
||||
elif model_type == "hf_causallm":
|
||||
model, input, _ = get_hf_causallm_model(
|
||||
torch_model_name, import_args
|
||||
)
|
||||
elif model_type == "hf_img_cls":
|
||||
model, input, _ = get_hf_img_cls_model(
|
||||
torch_model_name, import_args
|
||||
@@ -111,25 +119,45 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
|
||||
)
|
||||
os.makedirs(torch_model_dir, exist_ok=True)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
mlir_importer.import_debug(
|
||||
is_dynamic=False,
|
||||
tracing_required=tracing_required,
|
||||
dir=torch_model_dir,
|
||||
model_name=torch_model_name,
|
||||
)
|
||||
# Generate torch dynamic models.
|
||||
if is_dynamic:
|
||||
if is_decompose:
|
||||
# Add decomposition to some torch ops
|
||||
# TODO add op whitelist/blacklist
|
||||
import_with_fx(
|
||||
model,
|
||||
(input,),
|
||||
is_f16=False,
|
||||
f16_input_mask=None,
|
||||
debug=True,
|
||||
training=False,
|
||||
return_str=False,
|
||||
save_dir=torch_model_dir,
|
||||
model_name=torch_model_name,
|
||||
mlir_type=mlir_type,
|
||||
is_dynamic=False,
|
||||
tracing_required=tracing_required,
|
||||
)
|
||||
else:
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
mlir_importer.import_debug(
|
||||
is_dynamic=True,
|
||||
is_dynamic=False,
|
||||
tracing_required=tracing_required,
|
||||
dir=torch_model_dir,
|
||||
model_name=torch_model_name + "_dynamic",
|
||||
model_name=torch_model_name,
|
||||
mlir_type=mlir_type,
|
||||
)
|
||||
# Generate torch dynamic models.
|
||||
if is_dynamic:
|
||||
mlir_importer.import_debug(
|
||||
is_dynamic=True,
|
||||
tracing_required=tracing_required,
|
||||
dir=torch_model_dir,
|
||||
model_name=torch_model_name + "_dynamic",
|
||||
mlir_type=mlir_type,
|
||||
)
|
||||
|
||||
|
||||
def save_tf_model(tf_model_list, local_tank_cache, import_args):
|
||||
|
||||
@@ -176,6 +176,43 @@ def get_hf_seq2seq_model(name, import_args):
|
||||
return m, test_input, actual_out
|
||||
|
||||
|
||||
##################### Hugging Face CausalLM Models ###################################
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
|
||||
def prepare_sentence_tokens(hf_model: str, sentence: str):
|
||||
tokenizer = AutoTokenizer.from_pretrained(hf_model)
|
||||
return torch.tensor([tokenizer.encode(sentence)])
|
||||
|
||||
|
||||
class HFCausalLM(torch.nn.Module):
|
||||
def __init__(self, model_name: str):
|
||||
super().__init__()
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, # The pretrained model name.
|
||||
# The number of output labels--2 for binary classification.
|
||||
num_labels=2,
|
||||
# Whether the model returns attentions weights.
|
||||
output_attentions=False,
|
||||
# Whether the model returns all hidden-states.
|
||||
output_hidden_states=False,
|
||||
torchscript=True,
|
||||
)
|
||||
self.model.eval()
|
||||
|
||||
def forward(self, tokens):
|
||||
return self.model.forward(tokens)[0]
|
||||
|
||||
|
||||
def get_hf_causallm_model(name, import_args):
|
||||
m = HFCausalLM(name)
|
||||
test_input = prepare_sentence_tokens(
|
||||
name, "this project is very interesting"
|
||||
)
|
||||
actual_out = m.forward(*test_input)
|
||||
return m, test_input, actual_out
|
||||
|
||||
|
||||
################################################################################
|
||||
|
||||
##################### Torch Vision Models ###################################
|
||||
|
||||
@@ -64,6 +64,7 @@ def get_valid_test_params():
|
||||
device
|
||||
for device in get_supported_device_list()
|
||||
if not check_device_drivers(device)
|
||||
and device not in ["cpu-sync", "cpu-task"]
|
||||
]
|
||||
dynamic_list = (True, False)
|
||||
# TODO: This is soooo ugly, but for some reason creating the dict at runtime
|
||||
@@ -92,6 +93,8 @@ def get_valid_test_params():
|
||||
def is_valid_case(test_params):
|
||||
if test_params[0] == True and test_params[2]["framework"] == "tf":
|
||||
return False
|
||||
if test_params[2]["framework"] == "tf":
|
||||
return False
|
||||
elif "fp16" in test_params[2]["model_name"] and test_params[1] != "cuda":
|
||||
return False
|
||||
else:
|
||||
@@ -348,7 +351,11 @@ class SharkModuleTest(unittest.TestCase):
|
||||
self.pytestconfig.getoption("dispatch_benchmarks_dir")
|
||||
)
|
||||
|
||||
if config["xfail_cpu"] == "True" and device == "cpu":
|
||||
if config["xfail_cpu"] == "True" and device in [
|
||||
"cpu",
|
||||
"cpu-sync",
|
||||
"cpu-task",
|
||||
]:
|
||||
pytest.xfail(reason=config["xfail_reason"])
|
||||
|
||||
if config["xfail_cuda"] == "True" and device == "cuda":
|
||||
|
||||
@@ -1,23 +1,28 @@
|
||||
model_name, use_tracing, model_type, dynamic, param_count, tags, notes
|
||||
efficientnet_b0,True,vision,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
|
||||
efficientnet_b7,True,vision,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
|
||||
microsoft/MiniLM-L12-H384-uncased,True,hf,True,66M,"nlp;bert-variant;transformer-encoder","Large version has 12 layers; 384 hidden size; Smaller than BERTbase (66M params vs 109M params)"
|
||||
bert-base-uncased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
bert-base-cased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
google/mobilebert-uncased,True,hf,True,25M,"nlp,bert-variant,transformer-encoder,mobile","24 layers, 512 hidden size, 128 embedding"
|
||||
alexnet,False,vision,True,61M,"cnn,parallel-layers","The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod."
|
||||
resnet18,False,vision,True,11M,"cnn,image-classification,residuals,resnet-variant","1 7x7 conv2d and the rest are 3x3 conv2d"
|
||||
resnet50,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
resnet101,False,vision,True,29M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
squeezenet1_0,False,vision,True,1.25M,"cnn,image-classification,mobile,parallel-layers","Parallel conv2d (1x1 conv to compress -> (3x3 expand | 1x1 expand) -> concat)"
|
||||
wide_resnet50_2,False,vision,True,69M,"cnn,image-classification,residuals,resnet-variant","Resnet variant where model depth is decreased and width is increased."
|
||||
mobilenet_v3_small,False,vision,True,2.5M,"image-classification,cnn,mobile",N/A
|
||||
google/vit-base-patch16-224,True,hf_img_cls,False,86M,"image-classification,vision-transformer,transformer-encoder",N/A
|
||||
microsoft/resnet-50,True,hf_img_cls,False,23M,"image-classification,cnn,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
facebook/deit-small-distilled-patch16-224,True,hf_img_cls,False,22M,"image-classification,vision-transformer,cnn",N/A
|
||||
microsoft/beit-base-patch16-224-pt22k-ft22k,True,hf_img_cls,False,86M,"image-classification,transformer-encoder,bert-variant,vision-transformer",N/A
|
||||
nvidia/mit-b0,True,hf_img_cls,False,3.7M,"image-classification,transformer-encoder",SegFormer
|
||||
mnasnet1_0,False,vision,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
|
||||
resnet50_fp16,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
bert-base-uncased_fp16,True,fp16,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
|
||||
model_name, use_tracing, model_type, dynamic, mlir_type, decompose, param_count, tags, notes
|
||||
efficientnet_b0,True,vision,False,linalg,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
|
||||
efficientnet_b7,True,vision,False,linalg,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
|
||||
microsoft/MiniLM-L12-H384-uncased,True,hf,True,linalg,False,66M,"nlp;bert-variant;transformer-encoder","Large version has 12 layers; 384 hidden size; Smaller than BERTbase (66M params vs 109M params)"
|
||||
bert-base-uncased,True,hf,True,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
bert-base-cased,True,hf,True,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
google/mobilebert-uncased,True,hf,True,linalg,False,25M,"nlp,bert-variant,transformer-encoder,mobile","24 layers, 512 hidden size, 128 embedding"
|
||||
alexnet,False,vision,True,linalg,False,61M,"cnn,parallel-layers","The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod."
|
||||
resnet18,False,vision,True,linalg,False,11M,"cnn,image-classification,residuals,resnet-variant","1 7x7 conv2d and the rest are 3x3 conv2d"
|
||||
resnet50,False,vision,True,linalg,False,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
resnet101,False,vision,True,linalg,False,29M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
squeezenet1_0,False,vision,True,linalg,False,1.25M,"cnn,image-classification,mobile,parallel-layers","Parallel conv2d (1x1 conv to compress -> (3x3 expand | 1x1 expand) -> concat)"
|
||||
wide_resnet50_2,False,vision,True,linalg,False,69M,"cnn,image-classification,residuals,resnet-variant","Resnet variant where model depth is decreased and width is increased."
|
||||
mobilenet_v3_small,False,vision,True,linalg,False,2.5M,"image-classification,cnn,mobile",N/A
|
||||
google/vit-base-patch16-224,True,hf_img_cls,False,linalg,False,86M,"image-classification,vision-transformer,transformer-encoder",N/A
|
||||
microsoft/resnet-50,True,hf_img_cls,False,linalg,False,23M,"image-classification,cnn,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
facebook/deit-small-distilled-patch16-224,True,hf_img_cls,False,linalg,False,22M,"image-classification,vision-transformer,cnn",N/A
|
||||
microsoft/beit-base-patch16-224-pt22k-ft22k,True,hf_img_cls,False,linalg,False,86M,"image-classification,transformer-encoder,bert-variant,vision-transformer",N/A
|
||||
nvidia/mit-b0,True,hf_img_cls,False,linalg,False,3.7M,"image-classification,transformer-encoder",SegFormer
|
||||
mnasnet1_0,False,vision,True,linalg,False,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
|
||||
resnet50_fp16,False,vision,True,linalg,False,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
bert-base-uncased_fp16,True,fp16,False,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
bert-large-uncased,True,hf,True,linalg,False,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
|
||||
bert-base-uncased,True,hf,False,stablehlo,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
gpt2,True,hf_causallm,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
|
||||
facebook/opt-125m,True,hf,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
|
||||
distilgpt2,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
|
||||
microsoft/deberta-v3-base,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
|
||||
|
@@ -1,3 +1,3 @@
|
||||
{
|
||||
"version": "2023-03-31_02d52bb"
|
||||
"version": "nightly"
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user