Compare commits

...

58 Commits

Author SHA1 Message Date
Vivek Khandelwal
d7092aafaa Fix multiple issues for Langchain
This commit fixes the following issues for Langchain:
1.) Web UI not able to fetch results.
2.) Model getting reloaded for each query.
3.) SHARK module not using user-provided device and precision.
4.) Create a class for the main Langchain code.
5.) Misc issues
2023-07-21 21:56:27 +05:30
Vivek Khandelwal
a415f3f70e Fix Langchain Prompt issue and add web UI support (#1682) 2023-07-21 06:36:55 -07:00
Vivek Khandelwal
c292e5c9d7 Add Langchain CPU support and update requirements 2023-07-20 18:53:34 +05:30
Vivek Khandelwal
03c4d9e171 Add support for Llama-2-70b for web and cli, and for hf_auth_token 2023-07-20 14:57:48 +05:30
jinchen62
3662224c04 Update brevitas requirement (#1677)
also cleans up unused args

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-07-19 22:03:32 -07:00
Vivek Khandelwal
db3f222933 Revert "Add Llama2 70B option in CLI and WebUI (#1673)" (#1679)
This reverts commit 41e5088908.
2023-07-19 22:02:48 -07:00
Stefan Kapusniak
68b3021325 Fixes cosmetic problems with Gradio 3.37.0 (#1676)
* Fix nod-ai logo having a white border
* Fix control labels having a black background
* Remove extra lower border below Save Prompt checkboxes in Txt2Img UI
2023-07-19 17:28:53 -07:00
AyaanShah2204
336469154d added copy-metadata for pyyaml (#1678) 2023-07-19 17:27:25 -07:00
Abhishek Varma
41e5088908 Add Llama2 70B option in CLI and WebUI (#1673) 2023-07-19 10:41:42 -07:00
PhaneeshB
0a8f7673f4 Add README for CodeGen server 2023-07-19 23:10:23 +05:30
PhaneeshB
c482ab78da fix second vic clearing for low mem device 2023-07-19 23:10:23 +05:30
Vivek Khandelwal
4be80f7158 Add support for the Llama-2 model 2023-07-19 20:57:08 +05:30
AyaanShah2204
536aba1424 unpinned torch_mlir (#1668)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-07-19 06:28:00 -07:00
Ean Garvey
dd738a0e02 small changes to opt_perf_comparison.py (#1670)
* Use longer prompts for OPT comparison script

* small tweaks
2023-07-19 06:26:50 -07:00
Daniel Garvey
8927cb0a2c set optional vmfb download (#1667) 2023-07-18 10:57:28 -07:00
Daniel Garvey
8c317e4809 fix cli for vicuna (#1666) 2023-07-18 10:03:40 -07:00
Vivek Khandelwal
b0136593df Add support for different compilation paths for DocuChat (#1665) 2023-07-18 09:49:44 -07:00
Vivek Khandelwal
11f62d7fac Minor fixes for MiniLM Training 2023-07-18 17:16:44 +05:30
powderluv
14559dd620 Update DocuChat as experimental (#1660) 2023-07-17 22:12:05 -07:00
AyaanShah2204
e503a3e8d6 fixed joblib import error (#1659) 2023-07-17 12:56:10 -07:00
AyaanShah2204
22a4254adf fixed pyinstaller path for langchain imports (#1658) 2023-07-17 12:19:21 -07:00
Vivek Khandelwal
ab01f0f048 Add Langchain model in SHARK (#1657)
* Add H2OGPT

* Add UI tab for h2ogpt

* Add source files from h2ogpt

* Add the rest of the files

* Add h2ogpt support

* Add SHARK Compilation support for langchain model for cli mode

---------

Co-authored-by: George Petterson <gpetters@protonmail.com>
2023-07-17 09:58:15 -07:00
Phaneesh Barwaria
c471d17cca codegen API (#1655) 2023-07-16 20:00:39 -07:00
Stefan Kapusniak
a2a436eb0c SD - Add repeatable (batch) seeds option (#1654)
* Generates the seeds for all batch_count batches being run up front
rather than generating the seed for a batch just before it is run.
* Adds a --repeatable_seeds argument defaulting to False
* When repeatable_seeds=True, the first seed for a set of batches will
also be used as the rng seed for the subsequent batch seeds in the run.
The rng seed is then reset.
* When repeatable_seeds=False, batch seeding works as currently.
* Update scripts under apps/scripts that support the batch_count
argument to also support the repeatable_seeds argument.
* UI/Web: Adds a checkbox element on each SD tab after batch count/size
for toggling repeatable seeds, and update _inf functions to take
this into account.
* UI/Web: Moves the Stop buttons out of the Advanced sections and next
to Generate to make things not fit quite so badly with the extra UI
elements.
* UI/Web: Fixes logging to the upscaler output text box not working
correctly when running multiple batches.
2023-07-15 16:22:41 -07:00
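The batch-seed behaviour described in the commit above can be sketched roughly as follows. This is an illustrative sketch of the idea, not the actual SHARK implementation; `generate_batch_seeds` is a hypothetical helper name:

```python
import random


def generate_batch_seeds(first_seed, batch_count, repeatable_seeds=False):
    """Generate one seed per batch up front, before any batch runs."""
    if repeatable_seeds:
        # The first seed also seeds the rng that produces the remaining
        # batch seeds, so the whole run is reproducible from one number.
        rng = random.Random(first_seed)
        return [first_seed] + [
            rng.randint(0, 2**32 - 1) for _ in range(batch_count - 1)
        ]
    # Non-repeatable: each batch gets an independent random seed.
    return [random.randint(0, 2**32 - 1) for _ in range(batch_count)]
```

With `repeatable_seeds=True`, the same first seed always yields the same seed list for every batch in the run.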
powderluv
1adb51b29d Update docker README.md 2023-07-15 14:31:56 -07:00
anush elangovan
aab2233e25 Add a dev Ubuntu 22.04 docker image 2023-07-15 16:25:37 +00:00
jinchen62
e20cd71314 Change to a separate pass to unpack quantized weights (#1652) 2023-07-15 04:54:53 -07:00
powderluv
5ec91143f5 add a HF accelerate requirement (#1651) 2023-07-14 05:56:12 -07:00
Ean Garvey
7cf19230e2 add perf comparison script for opt. (#1650) 2023-07-13 13:29:48 -05:00
powderluv
1bcf6b2c5b pin diffusers to 0.18.1 (#1648) 2023-07-13 01:02:24 -07:00
jinchen62
91027f8719 Remove done TODOs, a sup PR for #1644 (#1647) 2023-07-12 23:30:45 -07:00
powderluv
a909fc2e78 add tiktoken to spec file (#1646) 2023-07-12 16:12:02 -07:00
jinchen62
247f69cf9d Apply canonicalize for unpacking int4 (#1644)
- tested it unpacks int4 as expected
- tested it doesn't make difference on int8
2023-07-11 19:41:09 -07:00
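For context on what unpacking int4 weights involves, here is a minimal sketch of two's-complement int4 unpacking. This is illustrative only; the actual pass in the commit above operates on packed tensors inside the compiler, not on Python bytes:

```python
def unpack_int4_pair(packed_byte):
    """Unpack one byte into two signed 4-bit (two's-complement) values."""

    def sign_extend(v):
        # Values 8..15 represent -8..-1 in 4-bit two's complement.
        return v - 16 if v >= 8 else v

    low = packed_byte & 0x0F
    high = (packed_byte >> 4) & 0x0F
    return sign_extend(low), sign_extend(high)
```

int8 weights occupy a whole byte each, which is why such a pass makes no difference for int8.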
PhaneeshB
3b8f7cc231 Add codegen support in UI + lint 2023-07-11 21:58:01 +05:30
PhaneeshB
6e8dbf72bd mlir/vmfb path fixes for vic pipeline 2023-07-11 21:58:01 +05:30
PhaneeshB
38e5b62d80 adapt UI to send model details to pipeline 2023-07-11 21:58:01 +05:30
PhaneeshB
1c7eecc981 add codegen support in vic pipeline 2023-07-11 21:58:01 +05:30
PhaneeshB
be417f0bf4 fix precision for fp16 2023-07-11 21:58:01 +05:30
AyaanShah2204
a517e217b0 Added support for building ZIP distributions (#1639)
* added support for zip files

* making linter happy

* Added temporary fix for NoneType padding

* Removed zip script

* Added shared imports file

* making linter happy
2023-07-09 06:45:36 -07:00
Ranvir Singh Virk
9fcae4f808 Metal testing (#1595)
* Fixing metal_platform and device selection

* fixing for metal platform

* fixed for black lint formatting
2023-07-08 15:22:53 -07:00
Stefan Kapusniak
788d469c5b UI/Web Refix remaining gradio deprecation warning (#1638) 2023-07-08 13:48:36 -07:00
Stefan Kapusniak
8a59f7cc27 UI/Web add 'open folder' button to output gallery (#1634)
* Adds a button that opens the currently selected subdirectory using
the default OS file manager
* Improve output gallery handling of having images deleted out from
under it.
* Don't show VAE or LoRA lines in parameter info panel when their
value is 'None'
* Use a css class for small icon buttons on the output gallery
tab instead of using the same id for multiple buttons
2023-07-08 12:44:59 -07:00
Stefan Kapusniak
1c2ec3c7a2 Some Fixes for Gradio 3.36.1 (#1637)
* Clear .style deprecation warnings.
* Re-remove download button from Nod logos.
* Add a workaround to the output gallery CSS for `container=false`
not doing what it did before on dropdowns
2023-07-08 11:20:34 -07:00
powderluv
af0f715e20 Unpin gradio 2023-07-08 09:41:14 -07:00
jinchen62
47ec7275e6 Fix brevitas quantize argument (#1633) 2023-07-07 11:30:31 -07:00
powderluv
3a24cff901 change binary names 2023-07-06 23:59:14 -07:00
powderluv
1f72907886 Fix the pyinstaller for chatbots (#1631) 2023-07-06 23:30:01 -07:00
Daniel Garvey
06c8aabd01 remove local-sync from webui (#1629) 2023-07-06 13:58:59 -07:00
Phaneesh Barwaria
55a12cc0c4 cpu name in device (#1628)
* show cpu name in devices

* change device order for chatbot
2023-07-06 12:00:09 -07:00
Ean Garvey
7dcbbde523 Xfail models for data tiling flag changes (#1624) 2023-07-06 06:57:17 -07:00
Abhishek Varma
1b62dc4529 [Vicuna] Revert the formatting for Brevitas op (#1626)
-- This commit reverts the formatting for Brevitas op.
-- It also excludes vicuna.py script from `black` formatter.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-07-06 06:56:17 -07:00
Daniel Garvey
c5a47887f4 Revert revert negative prompt change (#1625)
* revert default flag changes

* revert revert negative prompt change

* revert revert negative prompt change
2023-07-05 22:09:06 -07:00
Daniel Garvey
c72d0eaf87 revert default flag changes (#1622) 2023-07-05 15:43:26 -05:00
powderluv
c41f58042a Update compile_utils.py (#1617)
* Update compile_utils.py

* Update compile_utils.py

* Update compile_utils.py
2023-07-05 10:06:48 -07:00
xzuyn
043e5a5c7a fix a mistake I made, and more formatting changes, and add ++/Karras (#1619)
* fixed missing line break in `stablelm_ui.py` `start_message`
- also more formatting changes

* fix variable spelling mistake

* revert some formatting cause black wants it different

* one less line, still less than 79

* add ++, karras, and karras++ types of dpmsolver.

* black line length 79

---------

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-07-05 09:00:16 -07:00
Abhishek Varma
a1b1ce935c int8 e2e for WebUI (#1620) 2023-07-05 07:08:36 -07:00
jinchen62
bc6fee1a0c Add int4/int8 vicuna (#1598) 2023-07-05 07:01:51 -07:00
xzuyn
91ab594744 minor fix, some changes, some additions, and cleaning up (#1618)
* - fix overflowing text (a janky fix)
- add DEISMultistep scheduler as an option
- set default scheduler to DEISMultistep
- set default CFG to 3.5
- set default steps to 16
- add `xzuyn/PhotoMerge` as a model option
- add 3 new example prompts (which work nicely with PhotoMerge)
- formatting

* Set DEISMultistep in the cpu_only list instead

* formatting

* formatting

* modify prompts

* resize window to 81% & 85% monitor resolution instead of (WxH / 1.0625).

* increase steps to 32 after some testing. somewhere in between 16 and 32 is best compromise on speed/quality for DEIS, so 32 steps to play it safe.

* black line length 79

* revert settings DEIS as default scheduler.

* add more schedulers & revert accidental DDIM change
- add DPMSolverSingleStep, KDPM2AncestralDiscrete, & HeunDiscrete.
- did not add `DPMSolverMultistepInverse` or `DDIMInverse` as they only output latent noise, there are a few I did not try adding yet.
- accidentally set `upscaler_ui.py` to EulerDiscrete by default last commit while reverting DEIS changes.
- add `xzuyn/PhotoMerge-inpainting` as an in or out painting model.

* black line length 79

* add help section stuff and some other changes
- list the rest of the schedulers in argument help section.
- replace mutable default arguments.
- increased default window height to 91% to remove any scrolling for the main txt2img page (tested on a 1920x1080 monitor). width is the same as its just enough to have the image output on the side instead of the bottom.
- cleanup
2023-07-04 18:51:23 -07:00
89 changed files with 23126 additions and 1996 deletions


@@ -2,4 +2,4 @@
count = 1
show-source = 1
select = E9,F63,F7,F82
exclude = lit.cfg.py
exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py


@@ -54,8 +54,8 @@ jobs:
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
mv ./dist/shark_sd.exe ./dist/nodai_shark_sd_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_sd_${{ env.package_version_ }}.exe
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
- name: Upload Release Assets
id: upload-release-assets

.gitignore

@@ -189,3 +189,7 @@ apps/stable_diffusion/web/models/
# Stencil annotators.
stencil_annotator/
# For DocuChat
apps/language_models/langchain/user_path/
db_dir_UserData


@@ -0,0 +1,16 @@
## CodeGen Setup using SHARK-server
### Setup Server
- Clone SHARK and set up the venv
- Host the server using `python apps/stable_diffusion/web/index.py --api --server_port=<PORT>`
- The default server address is `http://0.0.0.0:8080`
### Setup Client
1. fauxpilot-vscode (VSCode extension):
- Code for the extension can be found [here](https://github.com/Venthe/vscode-fauxpilot)
- Prerequisites: [`nodejs` and `npm`](https://nodejs.org/en/download) to compile and run the extension
- Compile and run the extension in VSCode (press F5); this opens a new VSCode window with the extension running
- Open VSCode settings, search for fauxpilot, and set `server : http://<IP>:<PORT>`, `Model : codegen`, `Max Lines : 30`
2. Others (REST API via curl, OpenAI Python bindings) as shown [here](https://github.com/fauxpilot/fauxpilot/blob/main/documentation/client.md)
- Using the GitHub Copilot VSCode extension with SHARK-server needs more work to be functional.


@@ -0,0 +1,17 @@
# Langchain
## How to run the model
1.) Install all the dependencies by running:
```shell
pip install -r apps/language_models/langchain/langchain_requirements.txt
```
2.) Create a folder named `user_path` in the `apps/language_models/langchain/` directory.
You are now ready to use the model.
3.) To run the model, run the following command:
```shell
python apps/language_models/langchain/gen.py --cli=True
```


@@ -0,0 +1,186 @@
import copy
import torch
from evaluate_params import eval_func_param_names
from gen import Langchain
from prompter import non_hf_types
from utils import clear_torch_cache, NullContext, get_kwargs
def run_cli( # for local function:
base_model=None,
lora_weights=None,
inference_server=None,
debug=None,
chat_context=None,
examples=None,
memory_restriction_level=None,
# for get_model:
score_model=None,
load_8bit=None,
load_4bit=None,
load_half=None,
load_gptq=None,
use_safetensors=None,
infer_devices=None,
tokenizer_base_model=None,
gpu_id=None,
local_files_only=None,
resume_download=None,
use_auth_token=None,
trust_remote_code=None,
offload_folder=None,
compile_model=None,
# for some evaluate args
stream_output=None,
prompt_type=None,
prompt_dict=None,
temperature=None,
top_p=None,
top_k=None,
num_beams=None,
max_new_tokens=None,
min_new_tokens=None,
early_stopping=None,
max_time=None,
repetition_penalty=None,
num_return_sequences=None,
do_sample=None,
chat=None,
langchain_mode=None,
langchain_action=None,
document_choice=None,
top_k_docs=None,
chunk=None,
chunk_size=None,
# for evaluate kwargs
src_lang=None,
tgt_lang=None,
concurrency_count=None,
save_dir=None,
sanitize_bot_response=None,
model_state0=None,
max_max_new_tokens=None,
is_public=None,
max_max_time=None,
raise_generate_gpu_exceptions=None,
load_db_if_exists=None,
dbs=None,
user_path=None,
detect_user_path_changes_every_query=None,
use_openai_embedding=None,
use_openai_model=None,
hf_embedding_model=None,
db_type=None,
n_jobs=None,
first_para=None,
text_limit=None,
verbose=None,
cli=None,
reverse_docs=None,
use_cache=None,
auto_reduce_chunks=None,
max_chunks=None,
model_lock=None,
force_langchain_evaluate=None,
model_state_none=None,
# unique to this function:
cli_loop=None,
):
Langchain.check_locals(**locals())
score_model = "" # FIXME: For now, so user doesn't have to pass
n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
device = "cpu" if n_gpus == 0 else "cuda"
context_class = NullContext if n_gpus > 1 or n_gpus == 0 else torch.device
with context_class(device):
from functools import partial
# get score model
smodel, stokenizer, sdevice = Langchain.get_score_model(
reward_type=True,
**get_kwargs(
Langchain.get_score_model,
exclude_names=["reward_type"],
**locals()
)
)
model, tokenizer, device = Langchain.get_model(
reward_type=False,
**get_kwargs(
Langchain.get_model, exclude_names=["reward_type"], **locals()
)
)
model_dict = dict(
base_model=base_model,
tokenizer_base_model=tokenizer_base_model,
lora_weights=lora_weights,
inference_server=inference_server,
prompt_type=prompt_type,
prompt_dict=prompt_dict,
)
model_state = dict(model=model, tokenizer=tokenizer, device=device)
model_state.update(model_dict)
my_db_state = [None]
fun = partial(
Langchain.evaluate,
model_state,
my_db_state,
**get_kwargs(
Langchain.evaluate,
exclude_names=["model_state", "my_db_state"]
+ eval_func_param_names,
**locals()
)
)
example1 = examples[-1] # pick reference example
all_generations = []
while True:
clear_torch_cache()
instruction = input("\nEnter an instruction: ")
if instruction == "exit":
break
eval_vars = copy.deepcopy(example1)
eval_vars[eval_func_param_names.index("instruction")] = eval_vars[
eval_func_param_names.index("instruction_nochat")
] = instruction
eval_vars[eval_func_param_names.index("iinput")] = eval_vars[
eval_func_param_names.index("iinput_nochat")
] = "" # no input yet
eval_vars[
eval_func_param_names.index("context")
] = "" # no context yet
# grab other parameters, like langchain_mode
for k in eval_func_param_names:
if k in locals():
eval_vars[eval_func_param_names.index(k)] = locals()[k]
gener = fun(*tuple(eval_vars))
outr = ""
res_old = ""
for gen_output in gener:
res = gen_output["response"]
extra = gen_output["sources"]
if base_model not in non_hf_types or base_model in ["llama"]:
if not stream_output:
print(res)
else:
# gradio-style streaming yields the full output each generation, so show only the new chars here
diff = res[len(res_old) :]
print(diff, end="", flush=True)
res_old = res
outr = res # don't accumulate
else:
outr += res # just is one thing
if extra:
# show sources at the end, after the model has streamed the rest of the response
print(extra, flush=True)
all_generations.append(outr + "\n")
if not cli_loop:
break
return all_generations
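`run_cli` packs the evaluate arguments into a positional list and addresses fields via `eval_func_param_names.index(...)`. A minimal sketch of that convention, using a shortened, hypothetical subset of the name list and a hypothetical `set_param` helper:

```python
# Shortened, hypothetical subset of eval_func_param_names.
eval_func_param_names = [
    "instruction",
    "iinput",
    "context",
    "instruction_nochat",
    "iinput_nochat",
]


def set_param(eval_vars, name, value):
    """Write a named parameter into the positional argument list."""
    eval_vars[eval_func_param_names.index(name)] = value
    return eval_vars


eval_vars = [""] * len(eval_func_param_names)
set_param(eval_vars, "instruction", "Summarize the document.")
set_param(eval_vars, "instruction_nochat", "Summarize the document.")
```

The positional list can then be splatted into the evaluate function as `fun(*tuple(eval_vars))`, exactly as the loop above does.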

File diff suppressed because it is too large.


@@ -0,0 +1,103 @@
from enum import Enum
class PromptType(Enum):
custom = -1
plain = 0
instruct = 1
quality = 2
human_bot = 3
dai_faq = 4
summarize = 5
simple_instruct = 6
instruct_vicuna = 7
instruct_with_end = 8
human_bot_orig = 9
prompt_answer = 10
open_assistant = 11
wizard_lm = 12
wizard_mega = 13
instruct_vicuna2 = 14
instruct_vicuna3 = 15
wizard2 = 16
wizard3 = 17
instruct_simple = 18
wizard_vicuna = 19
openai = 20
openai_chat = 21
gptj = 22
prompt_answer_openllama = 23
vicuna11 = 24
mptinstruct = 25
mptchat = 26
falcon = 27
class DocumentChoices(Enum):
All_Relevant = 0
All_Relevant_Only_Sources = 1
Only_All_Sources = 2
Just_LLM = 3
non_query_commands = [
DocumentChoices.All_Relevant_Only_Sources.name,
DocumentChoices.Only_All_Sources.name,
]
class LangChainMode(Enum):
"""LangChain mode"""
DISABLED = "Disabled"
CHAT_LLM = "ChatLLM"
LLM = "LLM"
ALL = "All"
WIKI = "wiki"
WIKI_FULL = "wiki_full"
USER_DATA = "UserData"
MY_DATA = "MyData"
GITHUB_H2OGPT = "github h2oGPT"
H2O_DAI_DOCS = "DriverlessAI docs"
class LangChainAction(Enum):
"""LangChain action"""
QUERY = "Query"
# WIP:
# SUMMARIZE_MAP = "Summarize_map_reduce"
SUMMARIZE_MAP = "Summarize"
SUMMARIZE_ALL = "Summarize_all"
SUMMARIZE_REFINE = "Summarize_refine"
no_server_str = no_lora_str = no_model_str = "[None/Remove]"
# from site-packages/langchain/llms/openai.py
# but needed since ChatOpenAI doesn't have this information
model_token_mapping = {
"gpt-4": 8192,
"gpt-4-0314": 8192,
"gpt-4-32k": 32768,
"gpt-4-32k-0314": 32768,
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-16k": 16 * 1024,
"gpt-3.5-turbo-0301": 4096,
"text-ada-001": 2049,
"ada": 2049,
"text-babbage-001": 2040,
"babbage": 2049,
"text-curie-001": 2049,
"curie": 2049,
"davinci": 2049,
"text-davinci-003": 4097,
"text-davinci-002": 4097,
"code-davinci-002": 8001,
"code-davinci-001": 8001,
"code-cushman-002": 2048,
"code-cushman-001": 2048,
}
source_prefix = "Sources [Score | Link]:"
source_postfix = "End Sources<p>"
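The `model_token_mapping` table above can be used to bound prompt plus generation length per model. A hedged sketch with a trimmed copy of the table; `context_limit_for` is a hypothetical helper, not part of the file:

```python
# Trimmed copy of the mapping above, for illustration.
model_token_mapping = {
    "gpt-4": 8192,
    "gpt-3.5-turbo": 4096,
    "gpt-3.5-turbo-16k": 16 * 1024,
    "text-davinci-003": 4097,
}


def context_limit_for(model_name, default=2048):
    """Look up a model's context window, falling back to a safe default."""
    return model_token_mapping.get(model_name, default)
```

This is the kind of lookup the comment alludes to: `ChatOpenAI` does not expose the context size, so the caller keeps its own table.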


@@ -0,0 +1,406 @@
import inspect
import os
import traceback
import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
from evaluate_params import eval_func_param_names, eval_extra_columns
from gen import Langchain
from prompter import Prompter
from utils import clear_torch_cache, NullContext, get_kwargs
def run_eval( # for local function:
base_model=None,
lora_weights=None,
inference_server=None,
prompt_type=None,
prompt_dict=None,
debug=None,
chat=False,
chat_context=None,
stream_output=None,
eval_filename=None,
eval_prompts_only_num=None,
eval_prompts_only_seed=None,
eval_as_output=None,
examples=None,
memory_restriction_level=None,
# for get_model:
score_model=None,
load_8bit=None,
load_4bit=None,
load_half=None,
load_gptq=None,
use_safetensors=None,
infer_devices=None,
tokenizer_base_model=None,
gpu_id=None,
local_files_only=None,
resume_download=None,
use_auth_token=None,
trust_remote_code=None,
offload_folder=None,
compile_model=None,
# for evaluate args beyond what's already above, or things that are always dynamic and locally created
temperature=None,
top_p=None,
top_k=None,
num_beams=None,
max_new_tokens=None,
min_new_tokens=None,
early_stopping=None,
max_time=None,
repetition_penalty=None,
num_return_sequences=None,
do_sample=None,
langchain_mode=None,
langchain_action=None,
top_k_docs=None,
chunk=None,
chunk_size=None,
document_choice=None,
# for evaluate kwargs:
src_lang=None,
tgt_lang=None,
concurrency_count=None,
save_dir=None,
sanitize_bot_response=None,
model_state0=None,
max_max_new_tokens=None,
is_public=None,
max_max_time=None,
raise_generate_gpu_exceptions=None,
load_db_if_exists=None,
dbs=None,
user_path=None,
detect_user_path_changes_every_query=None,
use_openai_embedding=None,
use_openai_model=None,
hf_embedding_model=None,
db_type=None,
n_jobs=None,
first_para=None,
text_limit=None,
verbose=None,
cli=None,
reverse_docs=None,
use_cache=None,
auto_reduce_chunks=None,
max_chunks=None,
model_lock=None,
force_langchain_evaluate=None,
model_state_none=None,
):
Langchain.check_locals(**locals())
if eval_prompts_only_num > 0:
np.random.seed(eval_prompts_only_seed)
example1 = examples[-1] # pick reference example
examples = []
responses = []
if eval_filename is None:
# override default examples with shareGPT ones for human-level eval purposes only
eval_filename = (
"ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json"
)
if not os.path.isfile(eval_filename):
os.system(
"wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s"
% eval_filename
)
import json
data = json.load(open(eval_filename, "rt"))
# focus on data that starts with human, else likely chopped from other data
turn_start = 0 # odd in general
data = [
x
for x in data
if len(x["conversations"]) > turn_start + 1
and x["conversations"][turn_start]["from"] == "human"
and x["conversations"][turn_start + 1]["from"] == "gpt"
]
for i in sorted(
np.random.randint(0, len(data), size=eval_prompts_only_num)
):
assert data[i]["conversations"][turn_start]["from"] == "human"
instruction = data[i]["conversations"][turn_start]["value"]
assert (
data[i]["conversations"][turn_start + 1]["from"] == "gpt"
)
output = data[i]["conversations"][turn_start + 1]["value"]
examplenew = example1.copy()
assert (
not chat
), "No gradio must use chat=False, uses nochat instruct"
examplenew[
eval_func_param_names.index("instruction_nochat")
] = instruction
examplenew[
eval_func_param_names.index("iinput_nochat")
] = "" # no input
examplenew[
eval_func_param_names.index("context")
] = Langchain.get_context(chat_context, prompt_type)
examples.append(examplenew)
responses.append(output)
else:
# get data, assume in correct format: json of rows of dict of instruction and output
# only instruction is required
import json
data = json.load(open(eval_filename, "rt"))
for i in sorted(
np.random.randint(0, len(data), size=eval_prompts_only_num)
):
examplenew = example1.copy()
instruction = data[i]["instruction"]
output = data[i].get("output", "") # not required
assert (
not chat
), "No gradio must use chat=False, uses nochat instruct"
examplenew[
eval_func_param_names.index("instruction_nochat")
] = instruction
examplenew[
eval_func_param_names.index("iinput_nochat")
] = "" # no input
examplenew[
eval_func_param_names.index("context")
] = Langchain.get_context(chat_context, prompt_type)
examples.append(examplenew)
responses.append(output)
num_examples = len(examples)
scoring_path = "scoring"
os.makedirs(scoring_path, exist_ok=True)
if eval_as_output:
used_base_model = "gpt35"
used_lora_weights = ""
used_inference_server = ""
else:
used_base_model = str(base_model.split("/")[-1])
used_lora_weights = str(lora_weights.split("/")[-1])
used_inference_server = str(inference_server.split("/")[-1])
eval_out_filename = "df_scores_%s_%s_%s_%s_%s_%s_%s.parquet" % (
num_examples,
eval_prompts_only_num,
eval_prompts_only_seed,
eval_as_output,
used_base_model,
used_lora_weights,
used_inference_server,
)
eval_out_filename = os.path.join(scoring_path, eval_out_filename)
# torch.device("cuda") leads to cuda:x cuda:y mismatches for multi-GPU consistently
n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
device = "cpu" if n_gpus == 0 else "cuda"
context_class = NullContext if n_gpus > 1 or n_gpus == 0 else torch.device
with context_class(device):
# ensure was set right above before examples generated
assert (
not stream_output
), "stream_output=True does not make sense with example loop"
import time
from functools import partial
# get score model
smodel, stokenizer, sdevice = Langchain.get_score_model(
reward_type=True,
**get_kwargs(
Langchain.get_score_model,
exclude_names=["reward_type"],
**locals()
)
)
if not eval_as_output:
model, tokenizer, device = Langchain.get_model(
reward_type=False,
**get_kwargs(
Langchain.get_model,
exclude_names=["reward_type"],
**locals()
)
)
model_dict = dict(
base_model=base_model,
tokenizer_base_model=tokenizer_base_model,
lora_weights=lora_weights,
inference_server=inference_server,
prompt_type=prompt_type,
prompt_dict=prompt_dict,
)
model_state = dict(model=model, tokenizer=tokenizer, device=device)
model_state.update(model_dict)
my_db_state = [None]
fun = partial(
Langchain.evaluate,
model_state,
my_db_state,
**get_kwargs(
Langchain.evaluate,
exclude_names=["model_state", "my_db_state"]
+ eval_func_param_names,
**locals()
)
)
else:
assert eval_prompts_only_num > 0
def get_response(*args, exi=0):
# assumes same ordering of examples and responses
yield responses[exi]
fun = get_response
t0 = time.time()
score_dump = []
score_avg = 0
score_median = 0
for exi, ex in enumerate(examples):
clear_torch_cache()
instruction = ex[eval_func_param_names.index("instruction_nochat")]
iinput = ex[eval_func_param_names.index("iinput_nochat")]
context = ex[eval_func_param_names.index("context")]
clear_torch_cache()
print("")
print("START" + "=" * 100)
print(
"Question: %s %s"
% (instruction, ("input=%s" % iinput if iinput else ""))
)
print("-" * 105)
# fun yields as generator, so have to iterate over it
# Also means likely do NOT want --stream_output=True, else would show all generations
t1 = time.time()
gener = (
fun(*tuple(ex), exi=exi) if eval_as_output else fun(*tuple(ex))
)
for res_fun in gener:
res = res_fun["response"]
extra = res_fun["sources"]
print(res)
if smodel:
score_with_prompt = False
if score_with_prompt:
data_point = dict(
instruction=instruction,
input=iinput,
context=context,
)
prompter = Prompter(
prompt_type,
prompt_dict,
debug=debug,
chat=chat,
stream_output=stream_output,
)
prompt = prompter.generate_prompt(data_point)
else:
# just raw input and output
if eval_prompts_only_num > 0:
# only our own examples have this filled at moment
assert iinput in [
None,
"",
], iinput # should be no iinput
if not (chat_context and prompt_type == "human_bot"):
assert context in [
None,
"",
], context # should be no context
prompt = instruction
if memory_restriction_level > 0:
cutoff_len = (
768 if memory_restriction_level <= 2 else 512
)
else:
cutoff_len = tokenizer.model_max_length
inputs = stokenizer(
prompt,
res,
return_tensors="pt",
truncation=True,
max_length=cutoff_len,
)
try:
score = (
torch.sigmoid(smodel(**inputs).logits[0].float())
.cpu()
.detach()
.numpy()[0]
)
except torch.cuda.OutOfMemoryError as e:
print(
"GPU OOM 1: question: %s answer: %s exception: %s"
% (prompt, res, str(e)),
flush=True,
)
traceback.print_exc()
score = 0.0
clear_torch_cache()
except (Exception, RuntimeError) as e:
if (
"Expected all tensors to be on the same device"
in str(e)
or "expected scalar type Half but found Float"
in str(e)
or "probability tensor contains either" in str(e)
or "cublasLt ran into an error!" in str(e)
):
print(
"GPU error: question: %s answer: %s exception: %s"
% (prompt, res, str(e)),
flush=True,
)
traceback.print_exc()
score = 0.0
clear_torch_cache()
else:
raise
score_dump.append(ex + [prompt, res, score])
# dump every score in case abort
df_scores = pd.DataFrame(
score_dump,
columns=eval_func_param_names + eval_extra_columns,
)
df_scores.to_parquet(eval_out_filename, index=False)
# plot histogram so far
plt.figure(figsize=(10, 10))
plt.hist(df_scores["score"], bins=20)
score_avg = np.mean(df_scores["score"])
score_median = np.median(df_scores["score"])
print(
"SCORE %s: %s So far: AVG: %s MEDIAN: %s"
% (exi, score, score_avg, score_median),
flush=True,
)
plt.title(
"Score avg: %s median: %s" % (score_avg, score_median)
)
plt.savefig(eval_out_filename.replace(".parquet", ".png"))
plt.close()
print("END" + "=" * 102)
print("")
t2 = time.time()
print(
"Time taken for example: %s Time taken so far: %.4f about %.4g per example"
% (t2 - t1, t2 - t0, (t2 - t0) / (1 + exi))
)
t1 = time.time()
print(
"Total time taken: %.4f about %.4g per example"
% (t1 - t0, (t1 - t0) / num_examples)
)
print(
"Score avg: %s median: %s" % (score_avg, score_median), flush=True
)
return eval_out_filename


@@ -0,0 +1,53 @@
no_default_param_names = [
"instruction",
"iinput",
"context",
"instruction_nochat",
"iinput_nochat",
]
gen_hyper = [
"temperature",
"top_p",
"top_k",
"num_beams",
"max_new_tokens",
"min_new_tokens",
"early_stopping",
"max_time",
"repetition_penalty",
"num_return_sequences",
"do_sample",
]
eval_func_param_names = (
[
"instruction",
"iinput",
"context",
"stream_output",
"prompt_type",
"prompt_dict",
]
+ gen_hyper
+ [
"chat",
"instruction_nochat",
"iinput_nochat",
"langchain_mode",
"langchain_action",
"top_k_docs",
"chunk",
"chunk_size",
"document_choice",
]
)
# form evaluate defaults for submit_nochat_api
eval_func_param_names_defaults = eval_func_param_names.copy()
for k in no_default_param_names:
if k in eval_func_param_names_defaults:
eval_func_param_names_defaults.remove(k)
eval_extra_columns = ["prompt", "response", "score"]
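The split above means a nochat API caller supplies only the parameters that have server-side defaults, and positional value lists can be zipped back into keyword dicts. A sketch with a shortened, hypothetical name list and made-up values:

```python
# Shortened, hypothetical subsets of the lists above.
gen_hyper = ["temperature", "top_p", "top_k"]
eval_func_param_names = ["instruction", "iinput"] + gen_hyper
no_default_param_names = ["instruction", "iinput"]

# Parameters a caller can omit are exactly those with server-side defaults.
defaults = [k for k in eval_func_param_names if k not in no_default_param_names]

# Zip a positional value list back into a keyword dict.
values = ["Tell me a joke", "", 0.7, 0.9, 40]
kwargs = dict(zip(eval_func_param_names, values))
```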


@@ -0,0 +1,283 @@
import os
import json
import shutil
import subprocess
import torch
from peft import PeftModel
from transformers import PreTrainedModel
def do_export():
BASE_MODEL = "h2oai/h2ogpt-oasst1-512-12b"
LORA_WEIGHTS = "h2ogpt-oasst1-512-12b.h2oaih2ogpt-oig-oasst1-instruct-cleaned-v3.1_epochs.805b8e8eff369207340a5a6f90f3c833f9731254.2"
OUTPUT_NAME = "h2ogpt-oig-oasst1-512-12b"
BASE_MODEL = "EleutherAI/pythia-12b-deduped"
LORA_WEIGHTS = "pythia-12b-deduped.h2oaiopenassistant_oasst1_h2ogpt_graded.3_epochs.2ccf687ea3f3f3775a501838e81c1a0066430455.4"
OUTPUT_NAME = "h2ogpt-oasst1-512-12b"
BASE_MODEL = "tiiuae/falcon-40b"
LORA_WEIGHTS = "falcon-40b.h2oaiopenassistant_oasst1_h2ogpt.1_epochs.894d8450d35c180cd03222a45658d04c15b78d4b.9"
OUTPUT_NAME = "h2ogpt-oasst1-2048-falcon-40b"
# BASE_MODEL = 'decapoda-research/llama-65b-hf'
# LORA_WEIGHTS = 'llama-65b-hf.h2oaiopenassistant_oasst1_h2ogpt_graded.1_epochs.113510499324f0f007cbec9d9f1f8091441f2469.3'
# OUTPUT_NAME = "h2ogpt-research-oasst1-llama-65b"
model = os.getenv("MODEL")
# for testing
if model:
BASE_MODEL = "tiiuae/falcon-7b"
LORA_WEIGHTS = model + ".lora"
OUTPUT_NAME = model
llama_type = "llama" in BASE_MODEL
as_pytorch = False # False -> HF
from loaders import get_loaders
model_loader, tokenizer_loader = get_loaders(
model_name=BASE_MODEL, reward_type=False, llama_type=llama_type
)
tokenizer = tokenizer_loader.from_pretrained(
BASE_MODEL,
local_files_only=False,
resume_download=True,
)
tokenizer.save_pretrained(OUTPUT_NAME)
base_model = model_loader(
BASE_MODEL,
load_in_8bit=False,
trust_remote_code=True,
torch_dtype=torch.float16,
device_map={"": "cpu"},
)
print(base_model)
if llama_type:
layers = base_model.model.layers
first_weight = layers[0].self_attn.q_proj.weight
else:
if any(
[x in BASE_MODEL.lower() for x in ["pythia", "h2ogpt", "gpt-neox"]]
):
layers = base_model.gpt_neox.base_model.layers
first_weight = layers[0].attention.query_key_value.weight
elif any([x in BASE_MODEL.lower() for x in ["falcon"]]):
first_weight = base_model.transformer.h._modules[
"0"
].self_attention.query_key_value.weight
else:
layers = base_model.transformer.base_model.h
first_weight = layers[0].attn.q_proj.weight
first_weight_old = first_weight.clone()
lora_model = PeftModel.from_pretrained(
base_model,
LORA_WEIGHTS,
device_map={"": "cpu"},
torch_dtype=torch.float16,
)
assert torch.allclose(first_weight_old, first_weight)
# merge weights TODO: include all lora_target_modules, not just default ones
if llama_type:
lora_model = lora_model.merge_and_unload()
# for layer in lora_model.base_model.model.model.layers:
# layer.self_attn.q_proj.merge_weights = True
# layer.self_attn.k_proj.merge_weights = True
# layer.self_attn.v_proj.merge_weights = True
# layer.self_attn.o_proj.merge_weights = True
else:
if any(
[x in BASE_MODEL.lower() for x in ["pythia", "h2ogpt", "gpt-neox"]]
):
for layer in lora_model.base_model.gpt_neox.base_model.layers:
layer.attention.query_key_value.merge_weights = True
else:
lora_model.merge_and_unload()
# for layer in lora_model.base_model.transformer.base_model.h:
# layer.attn.q_proj.merge_weights = True
# layer.attn.v_proj.merge_weights = True
lora_model.train(False)
# did we do anything?
assert not torch.allclose(first_weight_old, first_weight)
lora_model_sd = lora_model.state_dict()
if as_pytorch:
# FIXME - might not be generic enough still
params = {
"dim": base_model.config.hidden_size,
"n_heads": base_model.config.num_attention_heads,
"n_layers": base_model.config.num_hidden_layers,
"norm_eps": base_model.config.layer_norm_eps,
"vocab_size": base_model.config.vocab_size,
}
n_layers = params["n_layers"]
n_heads = params["n_heads"]
dim = params["dim"]
dims_per_head = dim // n_heads
base = 10000.0
inv_freq = 1.0 / (
base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)
)
def permute(w):
return (
w.view(n_heads, dim // n_heads // 2, 2, dim)
.transpose(1, 2)
.reshape(dim, dim)
)
def unpermute(w):
return (
w.view(n_heads, 2, dim // n_heads // 2, dim)
.transpose(1, 2)
.reshape(dim, dim)
)
def translate_state_dict_key(k):
if "gpt-neoxt" in BASE_MODEL.lower():
k = k.replace("gpt_neox.model.", "")
else:
k = k.replace("base_model.model.", "")
if k == "model.embed_tokens.weight":
return "tok_embeddings.weight"
elif k == "model.norm.weight":
return "norm.weight"
elif k == "lm_head.weight":
return "output.weight"
elif k.startswith("model.layers."):
layer = k.split(".")[2]
if k.endswith(".self_attn.q_proj.weight"):
return f"layers.{layer}.attention.wq.weight"
elif k.endswith(".self_attn.k_proj.weight"):
return f"layers.{layer}.attention.wk.weight"
elif k.endswith(".self_attn.v_proj.weight"):
return f"layers.{layer}.attention.wv.weight"
elif k.endswith(".self_attn.o_proj.weight"):
return f"layers.{layer}.attention.wo.weight"
elif k.endswith(".mlp.gate_proj.weight"):
return f"layers.{layer}.feed_forward.w1.weight"
elif k.endswith(".mlp.down_proj.weight"):
return f"layers.{layer}.feed_forward.w2.weight"
elif k.endswith(".mlp.up_proj.weight"):
return f"layers.{layer}.feed_forward.w3.weight"
elif k.endswith(".input_layernorm.weight"):
return f"layers.{layer}.attention_norm.weight"
elif k.endswith(".post_attention_layernorm.weight"):
return f"layers.{layer}.ffn_norm.weight"
elif k.endswith("rotary_emb.inv_freq") or "lora" in k:
return None
else:
print(layer, k)
raise NotImplementedError
else:
print(k)
raise NotImplementedError
new_state_dict = {}
for k, v in lora_model_sd.items():
new_k = translate_state_dict_key(k)
if new_k is not None:
if "wq" in new_k or "wk" in new_k:
new_state_dict[new_k] = unpermute(v)
else:
new_state_dict[new_k] = v
os.makedirs("./ckpt", exist_ok=True)
torch.save(new_state_dict, "./ckpt/consolidated.00.pth")
with open("./ckpt/params.json", "w") as f:
json.dump(params, f)
else:
deloreanized_sd = {
k.replace("base_model.model.", ""): v
for k, v in lora_model_sd.items()
if "lora" not in k
}
base_model.config.custom_pipelines = {
"text-generation": {
"impl": "h2oai_pipeline.H2OTextGenerationPipeline",
"pt": "AutoModelForCausalLM",
}
}
PreTrainedModel.save_pretrained(
base_model,
OUTPUT_NAME,
state_dict=deloreanized_sd,
# max_shard_size="5GB",
)
do_copy(OUTPUT_NAME)
test_copy()
def do_copy(OUTPUT_NAME):
dest_file = os.path.join(OUTPUT_NAME, "h2oai_pipeline.py")
shutil.copyfile("src/h2oai_pipeline.py", dest_file)
os.system("""sed -i 's/from enums.*//g' %s""" % dest_file)
os.system("""sed -i 's/from stopping.*//g' %s""" % dest_file)
os.system("""sed -i 's/from prompter.*//g' %s""" % dest_file)
os.system(
"""cat %s|grep -v "from enums import PromptType" >> %s"""
% ("src/enums.py", dest_file)
)
os.system(
"""cat %s|grep -v "from enums import PromptType" >> %s"""
% ("src/prompter.py", dest_file)
)
os.system(
"""cat %s|grep -v "from enums import PromptType" >> %s"""
% ("src/stopping.py", dest_file)
)
TEST_OUTPUT_NAME = "test_output"
def test_copy():
if os.path.isdir(TEST_OUTPUT_NAME):
shutil.rmtree(TEST_OUTPUT_NAME)
os.makedirs(TEST_OUTPUT_NAME, exist_ok=False)
do_copy(TEST_OUTPUT_NAME)
shutil.copy("src/export_hf_checkpoint.py", TEST_OUTPUT_NAME)
os.environ["DO_COPY_TEST"] = "1"
os.chdir(TEST_OUTPUT_NAME)
output = subprocess.check_output(["python", "export_hf_checkpoint.py"])
print(output)
def inner_test_copy():
"""
pytest -s -v export_hf_checkpoint.py::test_copy
:return:
"""
# test imports
# below supposed to look bad in pycharm, don't fix!
from h2oai_pipeline import (
get_stopping,
get_prompt,
H2OTextGenerationPipeline,
)
assert get_stopping
assert get_prompt
assert H2OTextGenerationPipeline
if __name__ == "__main__":
if os.getenv("DO_COPY_TEST"):
inner_test_copy()
else:
do_export()
# uncomment for raw isolated test, but test is done every time for each export now
# test_copy()
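The export script above renames HF-style llama state-dict keys into the consolidated checkpoint naming (`translate_state_dict_key`). A minimal, runnable sketch of that mapping, covering only a few suffixes for illustration:

```python
# Subset of the suffix mapping used by translate_state_dict_key above.
SUFFIX_MAP = {
    ".self_attn.q_proj.weight": "attention.wq.weight",
    ".self_attn.k_proj.weight": "attention.wk.weight",
    ".mlp.gate_proj.weight": "feed_forward.w1.weight",
    ".input_layernorm.weight": "attention_norm.weight",
}

def translate_key(k):
    k = k.replace("base_model.model.", "")
    if k == "model.embed_tokens.weight":
        return "tok_embeddings.weight"
    if k.startswith("model.layers."):
        layer = k.split(".")[2]  # layer index sits at position 2
        for suffix, new_suffix in SUFFIX_MAP.items():
            if k.endswith(suffix):
                return f"layers.{layer}.{new_suffix}"
    return None  # dropped keys (e.g. rotary inv_freq, lora adapters)

print(translate_key("base_model.model.model.layers.3.self_attn.q_proj.weight"))
# layers.3.attention.wq.weight
```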

File diff suppressed because it is too large

View File

@@ -0,0 +1,380 @@
import inspect
import os
from functools import partial
from typing import Dict, Any, Optional, List
from langchain.callbacks.manager import CallbackManagerForLLMRun
from pydantic import root_validator
from langchain.llms import gpt4all
from dotenv import dotenv_values
from utils import FakeTokenizer
def get_model_tokenizer_gpt4all(base_model, **kwargs):
# defaults (some of these are generation parameters, so need to be passed in at generation time)
model_kwargs = dict(
n_threads=os.cpu_count() // 2,
temp=kwargs.get("temperature", 0.2),
top_p=kwargs.get("top_p", 0.75),
top_k=kwargs.get("top_k", 40),
n_ctx=2048 - 256,
)
env_gpt4all_file = ".env_gpt4all"
model_kwargs.update(dotenv_values(env_gpt4all_file))
# coerce numeric strings to int or float where possible, to satisfy the class's type hints
for k, v in model_kwargs.items():
try:
if float(v) == int(v):
model_kwargs[k] = int(v)
else:
model_kwargs[k] = float(v)
except (ValueError, TypeError):
pass
if base_model == "llama":
if "model_path_llama" not in model_kwargs:
raise ValueError("No model_path_llama in %s" % env_gpt4all_file)
model_path = model_kwargs.pop("model_path_llama")
# FIXME: GPT4All version of llama doesn't handle new quantization, so use llama_cpp_python
from llama_cpp import Llama
# llama sets some things at init model time, not generation time
func_names = list(inspect.signature(Llama.__init__).parameters)
model_kwargs = {
k: v for k, v in model_kwargs.items() if k in func_names
}
model_kwargs["n_ctx"] = int(model_kwargs["n_ctx"])
model = Llama(model_path=model_path, **model_kwargs)
elif base_model == "gpt4all_llama":
if (
"model_name_gpt4all_llama" not in model_kwargs
and "model_path_gpt4all_llama" not in model_kwargs
):
raise ValueError(
"No model_name_gpt4all_llama or model_path_gpt4all_llama in %s"
% env_gpt4all_file
)
model_name = model_kwargs.pop("model_name_gpt4all_llama")
model_type = "llama"
from gpt4all import GPT4All as GPT4AllModel
model = GPT4AllModel(model_name=model_name, model_type=model_type)
elif base_model == "gptj":
if (
"model_name_gptj" not in model_kwargs
and "model_path_gptj" not in model_kwargs
):
raise ValueError(
"No model_name_gptj or model_path_gptj in %s"
% env_gpt4all_file
)
model_name = model_kwargs.pop("model_name_gptj")
model_type = "gptj"
from gpt4all import GPT4All as GPT4AllModel
model = GPT4AllModel(model_name=model_name, model_type=model_type)
else:
raise ValueError("No such base_model %s" % base_model)
return model, FakeTokenizer(), "cpu"
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
class H2OStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
# streaming to std already occurs without this
# sys.stdout.write(token)
# sys.stdout.flush()
pass
def get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=[]):
# default from class
model_kwargs = {
k: v.default
for k, v in dict(inspect.signature(cls).parameters).items()
if k not in exclude_list
}
# from our defaults
model_kwargs.update(default_kwargs)
# from user defaults
model_kwargs.update(env_kwargs)
# ensure only valid keys
func_names = list(inspect.signature(cls).parameters)
model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names}
return model_kwargs
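The `get_model_kwargs` helper above layers class defaults, library defaults, and user env values, then filters to parameters the constructor actually accepts. A self-contained sketch of that pattern, using a hypothetical stand-in class (the real code also takes an `exclude_list`, omitted here):

```python
import inspect

class FakeLLM:  # hypothetical stand-in for a gpt4all/llama-cpp class
    def __init__(self, temp=0.7, top_k=40, n_ctx=2048):
        self.temp, self.top_k, self.n_ctx = temp, top_k, n_ctx

def get_model_kwargs(env_kwargs, default_kwargs, cls):
    # start from the constructor's own defaults
    model_kwargs = {
        k: v.default for k, v in inspect.signature(cls).parameters.items()
    }
    model_kwargs.update(default_kwargs)  # our defaults
    model_kwargs.update(env_kwargs)      # user overrides
    # keep only keys the constructor accepts
    valid = set(inspect.signature(cls).parameters)
    return {k: v for k, v in model_kwargs.items() if k in valid}

kwargs = get_model_kwargs({"temp": 0.1, "bogus": 1}, {"top_k": 20}, FakeLLM)
print(kwargs)  # {'temp': 0.1, 'top_k': 20, 'n_ctx': 2048}
```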
def get_llm_gpt4all(
model_name,
model=None,
max_new_tokens=256,
temperature=0.1,
repetition_penalty=1.0,
top_k=40,
top_p=0.7,
streaming=False,
callbacks=None,
prompter=None,
verbose=False,
):
assert prompter is not None
env_gpt4all_file = ".env_gpt4all"
env_kwargs = dotenv_values(env_gpt4all_file)
n_ctx = env_kwargs.pop("n_ctx", 2048 - max_new_tokens)
default_kwargs = dict(
context_erase=0.5,
n_batch=1,
n_ctx=n_ctx,
n_predict=max_new_tokens,
repeat_last_n=64 if repetition_penalty != 1.0 else 0,
repeat_penalty=repetition_penalty,
temp=temperature,
temperature=temperature,
top_k=top_k,
top_p=top_p,
use_mlock=True,
verbose=verbose,
)
if model_name == "llama":
cls = H2OLlamaCpp
model_path = (
env_kwargs.pop("model_path_llama") if model is None else model
)
model_kwargs = get_model_kwargs(
env_kwargs, default_kwargs, cls, exclude_list=["lc_kwargs"]
)
model_kwargs.update(
dict(
model_path=model_path,
callbacks=callbacks,
streaming=streaming,
prompter=prompter,
)
)
llm = cls(**model_kwargs)
llm.client.verbose = verbose
elif model_name == "gpt4all_llama":
cls = H2OGPT4All
model_path = (
env_kwargs.pop("model_path_gpt4all_llama")
if model is None
else model
)
model_kwargs = get_model_kwargs(
env_kwargs, default_kwargs, cls, exclude_list=["lc_kwargs"]
)
model_kwargs.update(
dict(
model=model_path,
backend="llama",
callbacks=callbacks,
streaming=streaming,
prompter=prompter,
)
)
llm = cls(**model_kwargs)
elif model_name == "gptj":
cls = H2OGPT4All
model_path = (
env_kwargs.pop("model_path_gptj") if model is None else model
)
model_kwargs = get_model_kwargs(
env_kwargs, default_kwargs, cls, exclude_list=["lc_kwargs"]
)
model_kwargs.update(
dict(
model=model_path,
backend="gptj",
callbacks=callbacks,
streaming=streaming,
prompter=prompter,
)
)
llm = cls(**model_kwargs)
else:
raise RuntimeError("No such model_name %s" % model_name)
return llm
class H2OGPT4All(gpt4all.GPT4All):
model: Any
prompter: Any
"""Path to the pre-trained GPT4All model file."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in the environment."""
try:
if isinstance(values["model"], str):
from gpt4all import GPT4All as GPT4AllModel
full_path = values["model"]
model_path, delimiter, model_name = full_path.rpartition("/")
model_path += delimiter
values["client"] = GPT4AllModel(
model_name=model_name,
model_path=model_path or None,
model_type=values["backend"],
allow_download=False,
)
if values["n_threads"] is not None:
# set n_threads
values["client"].model.set_thread_count(
values["n_threads"]
)
else:
values["client"] = values["model"]
try:
values["backend"] = values["client"].model_type
except AttributeError:
# The below is for compatibility with GPT4All Python bindings <= 0.2.3.
values["backend"] = values["client"].model.model_type
except ImportError:
raise ValueError(
"Could not import gpt4all python package. "
"Please install it with `pip install gpt4all`."
)
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs,
) -> str:
# Roughly 4 chars per token if natural language
prompt = prompt[-self.n_ctx * 4 :]
# use instruct prompting
data_point = dict(context="", instruction=prompt, input="")
prompt = self.prompter.generate_prompt(data_point)
verbose = False
if verbose:
print("_call prompt: %s" % prompt, flush=True)
# FIXME: GPT4ALl doesn't support yield during generate, so cannot support streaming except via itself to stdout
return super()._call(prompt, stop=stop, run_manager=run_manager)
from langchain.llms import LlamaCpp
class H2OLlamaCpp(LlamaCpp):
model_path: Any
prompter: Any
"""Path to the pre-trained llama.cpp model file."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that llama-cpp-python library is installed."""
if isinstance(values["model_path"], str):
model_path = values["model_path"]
model_param_names = [
"lora_path",
"lora_base",
"n_ctx",
"n_parts",
"seed",
"f16_kv",
"logits_all",
"vocab_only",
"use_mlock",
"n_threads",
"n_batch",
"use_mmap",
"last_n_tokens_size",
]
model_params = {k: values[k] for k in model_param_names}
# For backwards compatibility, only include if non-null.
if values["n_gpu_layers"] is not None:
model_params["n_gpu_layers"] = values["n_gpu_layers"]
try:
from llama_cpp import Llama
values["client"] = Llama(model_path, **model_params)
except ImportError:
raise ModuleNotFoundError(
"Could not import llama-cpp-python library. "
"Please install the llama-cpp-python library to "
"use this embedding model: pip install llama-cpp-python"
)
except Exception as e:
raise ValueError(
f"Could not load Llama model from path: {model_path}. "
f"Received error {e}"
)
else:
values["client"] = values["model_path"]
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs,
) -> str:
verbose = False
# tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate
# still avoid very large prompts, else hit "llama_tokenize: too many tokens" -- may still occur, but is not fatal
prompt = prompt[-self.n_ctx * 4 :]
prompt_tokens = self.client.tokenize(b" " + prompt.encode("utf-8"))
num_prompt_tokens = len(prompt_tokens)
if num_prompt_tokens > self.n_ctx:
# conservative by using int()
chars_per_token = int(len(prompt) / num_prompt_tokens)
prompt = prompt[-self.n_ctx * chars_per_token :]
if verbose:
print(
"reducing tokens, assuming average of %s chars/token" % chars_per_token,
flush=True,
)
prompt_tokens2 = self.client.tokenize(
b" " + prompt.encode("utf-8")
)
num_prompt_tokens2 = len(prompt_tokens2)
print(
"reduced tokens from %d -> %d"
% (num_prompt_tokens, num_prompt_tokens2),
flush=True,
)
# use instruct prompting
data_point = dict(context="", instruction=prompt, input="")
prompt = self.prompter.generate_prompt(data_point)
if verbose:
print("_call prompt: %s" % prompt, flush=True)
if self.streaming:
text_callback = None
if run_manager:
text_callback = partial(
run_manager.on_llm_new_token, verbose=self.verbose
)
# the streamer's parent handler expects to see the prompt first, else output="" and it is lost when prompter has prompt=None
if text_callback:
text_callback(prompt)
text = ""
for token in self.stream(
prompt=prompt, stop=stop, run_manager=run_manager
):
text_chunk = token["choices"][0]["text"]
# self.stream already calls text_callback
# if text_callback:
# text_callback(text_chunk)
text += text_chunk
return text
else:
params = self._get_parameters(stop)
params = {**params, **kwargs}
result = self.client(prompt=prompt, **params)
return result["choices"][0]["text"]
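The `_call` method above uses a two-pass truncation heuristic: a coarse cut assuming roughly 4 chars per token, then a refined cut using the measured chars-per-token ratio. A runnable sketch with a hypothetical one-char-per-token tokenizer standing in for llama.cpp's real tokenizer:

```python
def tokenize(text):
    return list(text)  # hypothetical: 1 token per character

def truncate_prompt(prompt, n_ctx):
    prompt = prompt[-n_ctx * 4:]          # coarse cut: ~4 chars/token
    num_tokens = len(tokenize(prompt))
    if num_tokens > n_ctx:
        # conservative integer ratio, as in the source
        chars_per_token = int(len(prompt) / num_tokens)
        prompt = prompt[-n_ctx * chars_per_token:]
    return prompt

short = truncate_prompt("x" * 100, n_ctx=8)
print(len(short))  # 8
```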

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,225 @@
from __future__ import annotations
from typing import Iterable
from gradio.themes.soft import Soft
from gradio.themes import Color, Size
from gradio.themes.utils import colors, sizes, fonts
h2o_yellow = Color(
name="yellow",
c50="#fffef2",
c100="#fff9e6",
c200="#ffecb3",
c300="#ffe28c",
c400="#ffd659",
c500="#fec925",
c600="#e6ac00",
c700="#bf8f00",
c800="#a67c00",
c900="#664d00",
c950="#403000",
)
h2o_gray = Color(
name="gray",
c50="#f8f8f8",
c100="#e5e5e5",
c200="#cccccc",
c300="#b2b2b2",
c400="#999999",
c500="#7f7f7f",
c600="#666666",
c700="#4c4c4c",
c800="#333333",
c900="#191919",
c950="#0d0d0d",
)
text_xsm = Size(
name="text_xsm",
xxs="4px",
xs="5px",
sm="6px",
md="7px",
lg="8px",
xl="10px",
xxl="12px",
)
spacing_xsm = Size(
name="spacing_xsm",
xxs="1px",
xs="1px",
sm="1px",
md="2px",
lg="3px",
xl="5px",
xxl="7px",
)
radius_xsm = Size(
name="radius_xsm",
xxs="1px",
xs="1px",
sm="1px",
md="2px",
lg="3px",
xl="5px",
xxl="7px",
)
class H2oTheme(Soft):
def __init__(
self,
*,
primary_hue: colors.Color | str = h2o_yellow,
secondary_hue: colors.Color | str = h2o_yellow,
neutral_hue: colors.Color | str = h2o_gray,
spacing_size: sizes.Size | str = sizes.spacing_md,
radius_size: sizes.Size | str = sizes.radius_md,
text_size: sizes.Size | str = sizes.text_lg,
font: fonts.Font
| str
| Iterable[fonts.Font | str] = (
fonts.GoogleFont("Montserrat"),
"ui-sans-serif",
"system-ui",
"sans-serif",
),
font_mono: fonts.Font
| str
| Iterable[fonts.Font | str] = (
fonts.GoogleFont("IBM Plex Mono"),
"ui-monospace",
"Consolas",
"monospace",
),
):
super().__init__(
primary_hue=primary_hue,
secondary_hue=secondary_hue,
neutral_hue=neutral_hue,
spacing_size=spacing_size,
radius_size=radius_size,
text_size=text_size,
font=font,
font_mono=font_mono,
)
super().set(
link_text_color="#3344DD",
link_text_color_hover="#3344DD",
link_text_color_visited="#3344DD",
link_text_color_dark="#74abff",
link_text_color_hover_dark="#a3c8ff",
link_text_color_active_dark="#a3c8ff",
link_text_color_visited_dark="#74abff",
button_primary_text_color="*neutral_950",
button_primary_text_color_dark="*neutral_950",
button_primary_background_fill="*primary_500",
button_primary_background_fill_dark="*primary_500",
block_label_background_fill="*primary_500",
block_label_background_fill_dark="*primary_500",
block_label_text_color="*neutral_950",
block_label_text_color_dark="*neutral_950",
block_title_text_color="*neutral_950",
block_title_text_color_dark="*neutral_950",
block_background_fill_dark="*neutral_950",
body_background_fill="*neutral_50",
body_background_fill_dark="*neutral_900",
background_fill_primary_dark="*block_background_fill",
block_radius="0 0 8px 8px",
checkbox_label_text_color_selected_dark="#000000",
)
class SoftTheme(Soft):
def __init__(
self,
*,
primary_hue: colors.Color | str = colors.indigo,
secondary_hue: colors.Color | str = colors.indigo,
neutral_hue: colors.Color | str = colors.gray,
spacing_size: sizes.Size | str = sizes.spacing_md,
radius_size: sizes.Size | str = sizes.radius_md,
text_size: sizes.Size | str = sizes.text_md,
font: fonts.Font
| str
| Iterable[fonts.Font | str] = (
fonts.GoogleFont("Montserrat"),
"ui-sans-serif",
"system-ui",
"sans-serif",
),
font_mono: fonts.Font
| str
| Iterable[fonts.Font | str] = (
fonts.GoogleFont("IBM Plex Mono"),
"ui-monospace",
"Consolas",
"monospace",
),
):
super().__init__(
primary_hue=primary_hue,
secondary_hue=secondary_hue,
neutral_hue=neutral_hue,
spacing_size=spacing_size,
radius_size=radius_size,
text_size=text_size,
font=font,
font_mono=font_mono,
)
h2o_logo = (
'<svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/svg" width="100%" height="100%"'
' viewBox="0 0 600.28 600.28"><defs><style>.cls-1{fill:#fec925;}.cls-2{fill:#161616;}.cls-3{fill:'
'#54585a;}</style></defs><g id="Fill-1"><rect class="cls-1" width="600.28" height="600.28" '
'rx="23.24"/></g><path class="cls-2" d="M174.33,246.06v92.78H152.86v-38H110.71v38H89.24V246.06h21.'
'47v36.58h42.15V246.06Z"/><path class="cls-2" d="M259.81,321.34v17.5H189.7V324.92l35.78-33.8c8.22-7.'
"82,9.68-12.59,9.68-17.09,0-7.29-5-11.53-14.85-11.53-7.95,0-14.71,3-19.21,9.27L185.46,261.7c7.15-10"
'.47,20.14-17.23,36.84-17.23,20.68,0,34.46,10.6,34.46,27.44,0,9-2.52,17.22-15.51,29.29l-21.33,20.14Z"'
'/><path class="cls-2" d="M268.69,292.45c0-27.57,21.47-48,50.76-48s50.76,20.28,50.76,48-21.6,48-50.'
"76,48S268.69,320,268.69,292.45Zm79.78,0c0-17.63-12.46-29.69-29-29.69s-29,12.06-29,29.69,12.46,29.69"
',29,29.69S348.47,310.08,348.47,292.45Z"/><path class="cls-3" d="M377.23,326.91c0-7.69,5.7-12.73,12.'
'85-12.73s12.86,5,12.86,12.73a12.86,12.86,0,1,1-25.71,0Z"/><path class="cls-3" d="M481.4,298.15v40.'
"69H462.05V330c-3.84,6.49-11.27,9.94-21.74,9.94-16.7,0-26.64-9.28-26.64-21.61,0-12.59,8.88-21.34,30."
"62-21.34h16.43c0-8.87-5.3-14-16.43-14-7.55,0-15.37,2.51-20.54,6.62l-7.43-14.44c7.82-5.57,19.35-8."
"62,30.75-8.62C468.81,266.47,481.4,276.54,481.4,298.15Zm-20.68,18.16V309H446.54c-9.67,0-12.72,3.57-"
'12.72,8.35,0,5.16,4.37,8.61,11.66,8.61C452.37,326,458.34,322.8,460.72,316.31Z"/><path class="cls-3"'
' d="M497.56,246.06c0-6.49,5.17-11.53,12.86-11.53s12.86,4.77,12.86,11.13c0,6.89-5.17,11.93-12.86,'
'11.93S497.56,252.55,497.56,246.06Zm2.52,21.47h20.68v71.31H500.08Z"/></svg>'
)
def get_h2o_title(title, description):
# NOTE: Check full width desktop, smallest width browser desktop, iPhone browsers to ensure no overlap etc.
return f"""<div style="float:left; justify-content:left; height: 80px; width: 195px; margin-top:0px">
{description}
</div>
<div style="display:flex; justify-content:center; margin-bottom:30px; margin-right:330px;">
<div style="height: 60px; width: 60px; margin-right:20px;">{h2o_logo}</div>
<h1 style="line-height:60px">{title}</h1>
</div>
<div style="float:right; height: 80px; width: 80px; margin-top:-100px">
<img src="https://raw.githubusercontent.com/h2oai/h2ogpt/main/docs/h2o-qr.png">
</div>
"""
def get_simple_title(title, description):
return f"""{description}<h1 align="center"> {title}</h1>"""
def get_dark_js():
return """() => {
if (document.querySelectorAll('.dark').length) {
document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
} else {
document.querySelector('body').classList.add('dark');
}
}"""

View File

@@ -0,0 +1,53 @@
def get_css(kwargs) -> str:
if kwargs["h2ocolors"]:
css_code = """footer {visibility: hidden;}
body{background:linear-gradient(#f5f5f5,#e5e5e5);}
body.dark{background:linear-gradient(#000000,#0d0d0d);}
"""
else:
css_code = """footer {visibility: hidden}"""
css_code += make_css_base()
return css_code
def make_css_base() -> str:
return """
@import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap');
body.dark{#warning {background-color: #555555};}
#small_btn {
margin: 0.6em 0em 0.55em 0;
max-width: 20em;
min-width: 5em !important;
height: 5em;
font-size: 14px !important;
}
#prompt-form {
border: 1px solid var(--primary-500) !important;
}
#prompt-form.block {
border-radius: var(--block-radius) !important;
}
#prompt-form textarea {
border: 1px solid rgb(209, 213, 219);
}
#prompt-form label > div {
margin-top: 4px;
}
button.primary:hover {
background-color: var(--primary-600) !important;
transition: .2s;
}
#prompt-form-area {
margin-bottom: 2.5rem;
}
.chatsmall chatbot {font-size: 10px !important}
"""

View File

@@ -0,0 +1,93 @@
from __future__ import annotations

import traceback
from typing import Callable
import os
from gradio_client.client import Job
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
from gradio_client import Client
class GradioClient(Client):
"""
Parent class of gradio client
To handle automatically refreshing client if detect gradio server changed
"""
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
super().__init__(*args, **kwargs)
self.server_hash = self.get_server_hash()
def get_server_hash(self):
"""
Get server hash using super without any refresh action triggered
Returns: git hash of gradio server
"""
return super().submit(api_name="/system_hash").result()
def refresh_client_if_should(self):
# get current hash in order to update api_name -> fn_index map in case gradio server changed
# FIXME: Could add cli api as hash
server_hash = self.get_server_hash()
if self.server_hash != server_hash:
self.refresh_client()
self.server_hash = server_hash
else:
self.reset_session()
def refresh_client(self):
"""
Ensure every client call is independent
Also ensure map between api_name and fn_index is updated in case server changed (e.g. restarted with new code)
Returns:
"""
# need session hash to be new every time, to avoid "generator already executing"
self.reset_session()
client = Client(*self.args, **self.kwargs)
for k, v in client.__dict__.items():
setattr(self, k, v)
def submit(
self,
*args,
api_name: str | None = None,
fn_index: int | None = None,
result_callbacks: Callable | list[Callable] | None = None,
) -> Job:
# Note predict calls submit
try:
self.refresh_client_if_should()
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
except Exception as e:
print("Hit e=%s" % str(e), flush=True)
# force reconfig in case only that
self.refresh_client()
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
# see if immediately failed
e = job.future._exception
if e is not None:
print(
"GR job failed: %s %s"
% (str(e), "".join(traceback.format_tb(e.__traceback__))),
flush=True,
)
# force reconfig in case only that
self.refresh_client()
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
e2 = job.future._exception
if e2 is not None:
print(
"GR job failed again: %s\n%s"
% (
str(e2),
"".join(traceback.format_tb(e2.__traceback__)),
),
flush=True,
)
return job
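`GradioClient` above caches a server fingerprint and rebuilds itself whenever the server changes. A generic, runnable sketch of that self-refreshing-client pattern; the `StubClient` is hypothetical, whereas the real class fingerprints a live gradio server via its `/system_hash` endpoint:

```python
class StubClient:
    server_version = "v1"  # pretend server-side state

    def __init__(self):
        self.version = StubClient.server_version

class RefreshingClient:
    def __init__(self):
        self.client = StubClient()
        self.server_hash = self._get_server_hash()

    def _get_server_hash(self):
        return StubClient.server_version

    def refresh_if_changed(self):
        server_hash = self._get_server_hash()
        if server_hash != self.server_hash:
            self.client = StubClient()  # rebuild against the new server
            self.server_hash = server_hash

rc = RefreshingClient()
StubClient.server_version = "v2"  # simulate a server redeploy
rc.refresh_if_changed()
print(rc.client.version)  # v2
```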

View File

@@ -0,0 +1,185 @@
import os
import math
import gradio as gr
def make_chatbots(output_label0, output_label0_model2, **kwargs):
text_outputs = []
chat_kwargs = []
for model_state_lock in kwargs["model_states"]:
if os.environ.get("DEBUG_MODEL_LOCK"):
model_name = (
model_state_lock["base_model"]
+ " : "
+ model_state_lock["inference_server"]
)
else:
model_name = model_state_lock["base_model"]
output_label = f"h2oGPT [{model_name}]"
min_width = (
250
if kwargs["gradio_size"] in ["small", "large", "medium"]
else 160
)
chat_kwargs.append(
dict(
label=output_label,
visible=kwargs["model_lock"],
elem_classes="chatsmall",
height=kwargs["height"] or 400,
min_width=min_width,
)
)
if kwargs["model_lock_columns"] == -1:
kwargs["model_lock_columns"] = len(kwargs["model_states"])
if kwargs["model_lock_columns"] is None:
kwargs["model_lock_columns"] = 3
ncols = kwargs["model_lock_columns"]
if len(kwargs["model_states"]) == 0:
nrows = 0
else:
nrows = math.ceil(
len(kwargs["model_states"]) / kwargs["model_lock_columns"]
)
if kwargs["model_lock_columns"] == 0:
# not using model_lock
pass
elif nrows <= 1:
with gr.Row():
for chat_kwargs1, model_state_lock in zip(
chat_kwargs, kwargs["model_states"]
):
text_outputs.append(gr.Chatbot(**chat_kwargs1))
elif nrows == len(kwargs["model_states"]):
with gr.Row():
for chat_kwargs1, model_state_lock in zip(
chat_kwargs, kwargs["model_states"]
):
text_outputs.append(gr.Chatbot(**chat_kwargs1))
elif nrows == 2:
with gr.Row():
for mii, (chat_kwargs1, model_state_lock) in enumerate(
zip(chat_kwargs, kwargs["model_states"])
):
if mii >= len(kwargs["model_states"]) / 2:
continue
text_outputs.append(gr.Chatbot(**chat_kwargs1))
with gr.Row():
for mii, (chat_kwargs1, model_state_lock) in enumerate(
zip(chat_kwargs, kwargs["model_states"])
):
if mii < len(kwargs["model_states"]) / 2:
continue
text_outputs.append(gr.Chatbot(**chat_kwargs1))
elif nrows == 3:
with gr.Row():
for mii, (chat_kwargs1, model_state_lock) in enumerate(
zip(chat_kwargs, kwargs["model_states"])
):
if mii >= 1 * len(kwargs["model_states"]) / 3:
continue
text_outputs.append(gr.Chatbot(**chat_kwargs1))
with gr.Row():
for mii, (chat_kwargs1, model_state_lock) in enumerate(
zip(chat_kwargs, kwargs["model_states"])
):
if (
mii < 1 * len(kwargs["model_states"]) / 3
or mii >= 2 * len(kwargs["model_states"]) / 3
):
continue
text_outputs.append(gr.Chatbot(**chat_kwargs1))
with gr.Row():
for mii, (chat_kwargs1, model_state_lock) in enumerate(
zip(chat_kwargs, kwargs["model_states"])
):
if mii < 2 * len(kwargs["model_states"]) / 3:
continue
text_outputs.append(gr.Chatbot(**chat_kwargs1))
elif nrows >= 4:
with gr.Row():
for mii, (chat_kwargs1, model_state_lock) in enumerate(
zip(chat_kwargs, kwargs["model_states"])
):
if mii >= 1 * len(kwargs["model_states"]) / 4:
continue
text_outputs.append(gr.Chatbot(**chat_kwargs1))
with gr.Row():
for mii, (chat_kwargs1, model_state_lock) in enumerate(
zip(chat_kwargs, kwargs["model_states"])
):
if (
mii < 1 * len(kwargs["model_states"]) / 4
or mii >= 2 * len(kwargs["model_states"]) / 4
):
continue
text_outputs.append(gr.Chatbot(**chat_kwargs1))
with gr.Row():
for mii, (chat_kwargs1, model_state_lock) in enumerate(
zip(chat_kwargs, kwargs["model_states"])
):
if (
mii < 2 * len(kwargs["model_states"]) / 4
or mii >= 3 * len(kwargs["model_states"]) / 4
):
continue
text_outputs.append(gr.Chatbot(**chat_kwargs1))
with gr.Row():
for mii, (chat_kwargs1, model_state_lock) in enumerate(
zip(chat_kwargs, kwargs["model_states"])
):
if mii < 3 * len(kwargs["model_states"]) / 4:
continue
text_outputs.append(gr.Chatbot(**chat_kwargs1))
with gr.Row():
text_output = gr.Chatbot(
label=output_label0,
visible=not kwargs["model_lock"],
height=kwargs["height"] or 400,
)
text_output2 = gr.Chatbot(
label=output_label0_model2,
visible=False and not kwargs["model_lock"],
height=kwargs["height"] or 400,
)
return text_output, text_output2, text_outputs
def make_prompt_form(kwargs, LangChainMode):
if kwargs["langchain_mode"] != LangChainMode.DISABLED.value:
extra_prompt_form = ". For summarization, empty submission uses first top_k_docs documents."
else:
extra_prompt_form = ""
if kwargs["input_lines"] > 1:
instruction_label = (
"Shift-Enter to Submit, Enter for more lines%s" % extra_prompt_form
)
else:
instruction_label = (
"Enter to Submit, Shift-Enter for more lines%s" % extra_prompt_form
)
with gr.Row(): # elem_id='prompt-form-area'):
with gr.Column(scale=50):
instruction = gr.Textbox(
lines=kwargs["input_lines"],
label="Ask anything",
placeholder=instruction_label,
info=None,
elem_id="prompt-form",
container=True,
)
with gr.Row():
submit = gr.Button(
value="Submit", variant="primary", scale=0, size="sm"
)
stop_btn = gr.Button(
value="Stop", variant="secondary", scale=0, size="sm"
)
return instruction, submit, stop_btn
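`make_chatbots` above spreads n model chatbots across rows of at most `model_lock_columns` columns, assigning row r the models with index in [r*n/nrows, (r+1)*n/nrows). A compact sketch of just that grid math, independent of gradio:

```python
import math

def layout_rows(n_models, ncols):
    if ncols == -1:          # -1 means one column per model
        ncols = n_models
    nrows = math.ceil(n_models / ncols) if n_models else 0
    rows = []
    for r in range(nrows):
        rows.append([m for m in range(n_models)
                     if r * n_models / nrows <= m < (r + 1) * n_models / nrows])
    return rows

print(layout_rows(5, 3))  # [[0, 1, 2], [3, 4]]
```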

View File

@@ -0,0 +1,622 @@
import os
from apps.stable_diffusion.src.utils.utils import _compile_module
from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType
from stopping import get_stopping
from prompter import Prompter, PromptType
from transformers.generation import (
GenerationConfig,
LogitsProcessorList,
StoppingCriteriaList,
)
import copy
import torch
from transformers import AutoConfig, AutoModelForCausalLM
import gc
from pathlib import Path
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_public_file
from apps.stable_diffusion.src import args
global_device = "cuda"
global_precision = "fp16"
if not args.run_docuchat_web:
args.device = global_device
args.precision = global_precision
class H2OGPTSHARKModel(torch.nn.Module):
def __init__(self):
super().__init__()
model_name = "h2ogpt_falcon_7b"
extended_model_name = (
model_name + "_" + args.precision + "_" + args.device
)
vmfb_path = Path(extended_model_name + ".vmfb")
mlir_path = Path(model_name + "_" + args.precision + ".mlir")
shark_module = None
if not vmfb_path.exists():
if args.device in ["cuda", "cpu"] and args.precision in [
"fp16",
"fp32",
]:
# Downloading VMFB from shark_tank
print("Downloading vmfb from shark tank.")
download_public_file(
"gs://shark_tank/langchain/" + str(vmfb_path),
vmfb_path.absolute(),
single_file=True,
)
else:
if mlir_path.exists():
with open(mlir_path, "rb") as f:
bytecode = f.read()
else:
# Downloading MLIR from shark_tank
download_public_file(
"gs://shark_tank/langchain/" + str(mlir_path),
mlir_path.absolute(),
single_file=True,
)
if mlir_path.exists():
with open(mlir_path, "rb") as f:
bytecode = f.read()
else:
raise ValueError(
f"MLIR not found at {mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
shark_module = SharkInference(
mlir_module=bytecode,
device=args.device,
mlir_dialect="linalg",
)
                print("[DEBUG] generating vmfb.")
shark_module = _compile_module(
shark_module, extended_model_name, []
)
print("Saved newly generated vmfb.")
if shark_module is None:
if vmfb_path.exists():
print("Compiled vmfb found. Loading it from: ", vmfb_path)
shark_module = SharkInference(
None, device=args.device, mlir_dialect="linalg"
)
shark_module.load_module(str(vmfb_path))
print("Compiled vmfb loaded successfully.")
else:
raise ValueError("Unable to download/generate a vmfb.")
self.model = shark_module
def forward(self, input_ids, attention_mask):
result = torch.from_numpy(
self.model(
"forward",
(input_ids.to(device="cpu"), attention_mask.to(device="cpu")),
)
).to(device=args.device)
return result
h2ogpt_model = H2OGPTSHARKModel()
def pad_or_truncate_inputs(
input_ids, attention_mask, max_padding_length=400, do_truncation=False
):
inp_shape = input_ids.shape
if inp_shape[1] < max_padding_length:
# do padding
num_add_token = max_padding_length - inp_shape[1]
padded_input_ids = torch.cat(
[
torch.tensor([[11] * num_add_token]).to(device=args.device),
input_ids,
],
dim=1,
)
padded_attention_mask = torch.cat(
[
torch.tensor([[0] * num_add_token]).to(device=args.device),
attention_mask,
],
dim=1,
)
return padded_input_ids, padded_attention_mask
elif inp_shape[1] > max_padding_length or do_truncation:
# do truncation
num_remove_token = inp_shape[1] - max_padding_length
truncated_input_ids = input_ids[:, num_remove_token:]
truncated_attention_mask = attention_mask[:, num_remove_token:]
return truncated_input_ids, truncated_attention_mask
else:
return input_ids, attention_mask
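The fixed-length contract above can be illustrated without torch. A minimal pure-Python sketch of the same left-pad/left-truncate policy (the pad token id 11 and the length 400 mirror the defaults above; plain lists stand in for tensors):

```python
def pad_or_truncate(ids, mask, max_len=400, pad_id=11):
    """Left-pad with pad_id (mask 0) or left-truncate to exactly max_len."""
    n = len(ids)
    if n < max_len:
        extra = max_len - n
        return [pad_id] * extra + ids, [0] * extra + mask
    # keep the tail: the most recent tokens, where the question sits
    return ids[-max_len:], mask[-max_len:]

ids, mask = pad_or_truncate([1, 2, 3], [1, 1, 1], max_len=5)
assert ids == [11, 11, 1, 2, 3] and mask == [0, 0, 1, 1, 1]
ids, mask = pad_or_truncate(list(range(7)), [1] * 7, max_len=5)
assert ids == [2, 3, 4, 5, 6]
```

Padding on the left (rather than the right) keeps the final real token in the right-most position, which is the position whose logits drive next-token prediction.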
class H2OTextGenerationPipeline(TextGenerationPipeline):
def __init__(
self,
*args,
debug=False,
chat=False,
stream_output=False,
sanitize_bot_response=False,
use_prompter=True,
prompter=None,
prompt_type=None,
prompt_dict=None,
max_input_tokens=2048 - 256,
**kwargs,
):
"""
        HF-like pipeline that also handles instruction prompting and stopping (for some models)
:param args:
:param debug:
:param chat:
:param stream_output:
:param sanitize_bot_response:
        :param use_prompter: Whether to use a prompter. If prompt_type is passed, a prompter will be made.
        :param prompter: an existing Prompter instance, if one is already available
        :param prompt_type: prompt type, e.g. human_bot. See the prompt_type-to-model mapping in prompter.py.
            If use_prompter is set, a prompter will be made and used.
        :param prompt_dict: dict from get_prompt(..., return_dict=True) for prompt_type=custom
:param max_input_tokens:
:param kwargs:
"""
super().__init__(*args, **kwargs)
self.prompt_text = None
self.use_prompter = use_prompter
self.prompt_type = prompt_type
self.prompt_dict = prompt_dict
self.prompter = prompter
if self.use_prompter:
if self.prompter is not None:
assert self.prompter.prompt_type is not None
else:
self.prompter = Prompter(
self.prompt_type,
self.prompt_dict,
debug=debug,
chat=chat,
stream_output=stream_output,
)
self.human = self.prompter.humanstr
self.bot = self.prompter.botstr
self.can_stop = True
else:
self.prompter = None
self.human = None
self.bot = None
self.can_stop = False
self.sanitize_bot_response = sanitize_bot_response
self.max_input_tokens = (
max_input_tokens # not for generate, so ok that not kwargs
)
@staticmethod
def limit_prompt(prompt_text, tokenizer, max_prompt_length=None):
verbose = bool(int(os.getenv("VERBOSE_PIPELINE", "0")))
if hasattr(tokenizer, "model_max_length"):
# model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
model_max_length = tokenizer.model_max_length
if max_prompt_length is not None:
model_max_length = min(model_max_length, max_prompt_length)
# cut at some upper likely limit to avoid excessive tokenization etc
# upper bound of 10 chars/token, e.g. special chars sometimes are long
if len(prompt_text) > model_max_length * 10:
len0 = len(prompt_text)
prompt_text = prompt_text[-model_max_length * 10 :]
if verbose:
print(
                        "Cut off input: %s -> %s" % (len0, len(prompt_text)),
flush=True,
)
else:
# unknown
model_max_length = None
num_prompt_tokens = None
if model_max_length is not None:
# can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
# For https://github.com/h2oai/h2ogpt/issues/192
for trial in range(0, 3):
prompt_tokens = tokenizer(prompt_text)["input_ids"]
num_prompt_tokens = len(prompt_tokens)
if num_prompt_tokens > model_max_length:
# conservative by using int()
chars_per_token = int(len(prompt_text) / num_prompt_tokens)
# keep tail, where question is if using langchain
prompt_text = prompt_text[
-model_max_length * chars_per_token :
]
if verbose:
print(
"reducing %s tokens, assuming average of %s chars/token for %s characters"
% (
num_prompt_tokens,
chars_per_token,
len(prompt_text),
),
flush=True,
)
else:
if verbose:
print(
"using %s tokens with %s chars"
% (num_prompt_tokens, len(prompt_text)),
flush=True,
)
break
return prompt_text, num_prompt_tokens
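The retry loop in limit_prompt can be sketched with a toy whitespace tokenizer (purely illustrative, not the real HF tokenizer): estimate chars per token from one tokenization, keep the tail of the prompt, and re-check, up to a few trials.

```python
def limit_prompt(text, tokenize, max_tokens, trials=3):
    """Iteratively trim a prompt's head until it fits within max_tokens."""
    n = None
    for _ in range(trials):
        n = len(tokenize(text))
        if n <= max_tokens:
            break
        # conservative estimate: int() rounds down, so we cut a bit more
        chars_per_token = int(len(text) / n)
        text = text[-max_tokens * chars_per_token:]  # keep the tail
    return text, n

tok = lambda s: s.split()
text, n = limit_prompt("a " * 100, tok, max_tokens=10)
assert len(tok(text)) <= 10
```

Keeping the tail matters when langchain appends the user question after retrieved context: the question must survive the cut.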
def preprocess(
self,
prompt_text,
prefix="",
handle_long_generation=None,
**generate_kwargs,
):
(
prompt_text,
num_prompt_tokens,
) = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)
data_point = dict(context="", instruction=prompt_text, input="")
if self.prompter is not None:
prompt_text = self.prompter.generate_prompt(data_point)
self.prompt_text = prompt_text
if handle_long_generation is None:
            # disable HF's built-in long-generation handling; input length is
            # managed separately by limit_prompt and pad_or_truncate_inputs
            handle_long_generation = None
return super().preprocess(
prompt_text,
prefix=prefix,
handle_long_generation=handle_long_generation,
**generate_kwargs,
)
def postprocess(
self,
model_outputs,
return_type=ReturnType.FULL_TEXT,
clean_up_tokenization_spaces=True,
):
records = super().postprocess(
model_outputs,
return_type=return_type,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
for rec in records:
if self.use_prompter:
outputs = rec["generated_text"]
outputs = self.prompter.get_response(
outputs,
prompt=self.prompt_text,
sanitize_bot_response=self.sanitize_bot_response,
)
elif self.bot and self.human:
outputs = (
rec["generated_text"]
.split(self.bot)[1]
.split(self.human)[0]
)
else:
outputs = rec["generated_text"]
rec["generated_text"] = outputs
print(
"prompt: %s\noutputs: %s\n\n" % (self.prompt_text, outputs),
flush=True,
)
return records
def generate_new_token(self):
model_inputs = self.model.prepare_inputs_for_generation(
self.input_ids, **self.model_kwargs
)
outputs = h2ogpt_model.forward(
model_inputs["input_ids"], model_inputs["attention_mask"]
)
if args.precision == "fp16":
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
# pre-process distribution
next_token_scores = self.logits_processor(
self.input_ids, next_token_logits
)
next_token_scores = self.logits_warper(
self.input_ids, next_token_scores
)
# sample
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
if self.eos_token_id is not None:
if self.pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_token = (
next_token * self.unfinished_sequences
+ self.pad_token_id * (1 - self.unfinished_sequences)
)
self.input_ids = torch.cat(
[self.input_ids, next_token[:, None]], dim=-1
)
self.model_kwargs["past_key_values"] = None
if "attention_mask" in self.model_kwargs:
attention_mask = self.model_kwargs["attention_mask"]
self.model_kwargs["attention_mask"] = torch.cat(
[
attention_mask,
attention_mask.new_ones((attention_mask.shape[0], 1)),
],
dim=-1,
)
self.truncated_input_ids.append(self.input_ids[:, 0])
self.input_ids = self.input_ids[:, 1:]
self.model_kwargs["attention_mask"] = self.model_kwargs[
"attention_mask"
][:, 1:]
return next_token
def generate_token(self, **generate_kwargs):
self.truncated_input_ids = []
generation_config_ = GenerationConfig.from_model_config(
self.model.config
)
generation_config = copy.deepcopy(generation_config_)
self.model_kwargs = generation_config.update(**generate_kwargs)
logits_processor = LogitsProcessorList()
self.stopping_criteria = (
self.stopping_criteria
if self.stopping_criteria is not None
else StoppingCriteriaList()
)
eos_token_id = generation_config.eos_token_id
generation_config.pad_token_id = eos_token_id
(
inputs_tensor,
model_input_name,
self.model_kwargs,
) = self.model._prepare_model_inputs(
None, generation_config.bos_token_id, self.model_kwargs
)
batch_size = inputs_tensor.shape[0]
self.model_kwargs[
"output_attentions"
] = generation_config.output_attentions
self.model_kwargs[
"output_hidden_states"
] = generation_config.output_hidden_states
self.model_kwargs["use_cache"] = generation_config.use_cache
self.input_ids = (
inputs_tensor
if model_input_name == "input_ids"
else self.model_kwargs.pop("input_ids")
)
input_ids_seq_length = self.input_ids.shape[-1]
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_seq_length
)
self.logits_processor = self.model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
)
self.stopping_criteria = self.model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=self.stopping_criteria,
)
self.logits_warper = self.model._get_logits_warper(generation_config)
(
self.input_ids,
self.model_kwargs,
) = self.model._expand_inputs_for_generation(
input_ids=self.input_ids,
expand_size=generation_config.num_return_sequences, # 1
is_encoder_decoder=self.model.config.is_encoder_decoder, # False
**self.model_kwargs,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id_tensor = (
torch.tensor(eos_token_id).to(device=args.device)
if eos_token_id is not None
else None
)
self.pad_token_id = generation_config.pad_token_id
self.eos_token_id = eos_token_id
output_scores = generation_config.output_scores # False
output_attentions = generation_config.output_attentions # False
output_hidden_states = generation_config.output_hidden_states # False
return_dict_in_generate = (
generation_config.return_dict_in_generate # False
)
# init attention / hidden states / scores tuples
self.scores = (
() if (return_dict_in_generate and output_scores) else None
)
decoder_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
cross_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
decoder_hidden_states = (
() if (return_dict_in_generate and output_hidden_states) else None
)
# keep track of which sequences are already finished
self.unfinished_sequences = torch.ones(
self.input_ids.shape[0],
dtype=torch.long,
device=self.input_ids.device,
)
timesRan = 0
import time
start = time.time()
print("\n")
while True:
next_token = self.generate_new_token()
new_word = self.tokenizer.decode(
next_token.cpu().numpy(),
add_special_tokens=False,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
print(f"{new_word}", end="", flush=True)
# if eos_token was found in one sentence, set sentence to finished
if self.eos_token_id_tensor is not None:
self.unfinished_sequences = self.unfinished_sequences.mul(
next_token.tile(self.eos_token_id_tensor.shape[0], 1)
.ne(self.eos_token_id_tensor.unsqueeze(1))
.prod(dim=0)
)
            timesRan = timesRan + 1
            # stop when each sentence is finished
            if (
                self.unfinished_sequences.max() == 0
                or self.stopping_criteria(self.input_ids, self.scores)
            ):
                break
end = time.time()
print(
"\n\nTime taken is {:.2f} seconds/token\n".format(
(end - start) / timesRan
)
)
self.input_ids = torch.cat(
[
torch.tensor(self.truncated_input_ids)
.to(device=args.device)
.unsqueeze(dim=0),
self.input_ids,
],
dim=-1,
)
torch.cuda.empty_cache()
gc.collect()
return self.input_ids
def _forward(self, model_inputs, **generate_kwargs):
if self.can_stop:
stopping_criteria = get_stopping(
self.prompt_type,
self.prompt_dict,
self.tokenizer,
self.device,
human=self.human,
bot=self.bot,
model_max_length=self.tokenizer.model_max_length,
)
generate_kwargs["stopping_criteria"] = stopping_criteria
# return super()._forward(model_inputs, **generate_kwargs)
return self.__forward(model_inputs, **generate_kwargs)
# FIXME: Copy-paste of original _forward, but removed copy.deepcopy()
# FIXME: https://github.com/h2oai/h2ogpt/issues/172
def __forward(self, model_inputs, **generate_kwargs):
input_ids = model_inputs["input_ids"]
attention_mask = model_inputs.get("attention_mask", None)
# Allow empty prompts
if input_ids.shape[1] == 0:
input_ids = None
attention_mask = None
in_b = 1
else:
in_b = input_ids.shape[0]
prompt_text = model_inputs.pop("prompt_text")
## If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
## generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
# generate_kwargs = copy.deepcopy(generate_kwargs)
prefix_length = generate_kwargs.pop("prefix_length", 0)
if prefix_length > 0:
has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
"generation_config" in generate_kwargs
and generate_kwargs["generation_config"].max_new_tokens
is not None
)
if not has_max_new_tokens:
generate_kwargs["max_length"] = (
generate_kwargs.get("max_length")
or self.model.config.max_length
)
generate_kwargs["max_length"] += prefix_length
has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
"generation_config" in generate_kwargs
and generate_kwargs["generation_config"].min_new_tokens
is not None
)
if not has_min_new_tokens and "min_length" in generate_kwargs:
generate_kwargs["min_length"] += prefix_length
# BS x SL
# pad or truncate the input_ids and attention_mask
max_padding_length = 400
input_ids, attention_mask = pad_or_truncate_inputs(
input_ids, attention_mask, max_padding_length=max_padding_length
)
self.stopping_criteria = generate_kwargs["stopping_criteria"]
generated_sequence = self.generate_token(
input_ids=input_ids,
attention_mask=attention_mask,
**generate_kwargs,
)
out_b = generated_sequence.shape[0]
generated_sequence = generated_sequence.reshape(
in_b, out_b // in_b, *generated_sequence.shape[1:]
)
return {
"generated_sequence": generated_sequence,
"input_ids": input_ids,
"prompt_text": prompt_text,
}

View File

@@ -0,0 +1,247 @@
"""
Based upon ImageCaptionLoader in LangChain version: langchain/document_loaders/image_captions.py,
but accepts a preloaded model to avoid slowness in use and CUDA forking issues.

Loader that loads image captions. By default, the loader utilizes the pre-trained
BLIP image captioning model:
"""
from typing import List, Union, Any, Tuple
import requests
from langchain.docstore.document import Document
from langchain.document_loaders import ImageCaptionLoader
from utils import get_device, NullContext
import pkg_resources
try:
assert pkg_resources.get_distribution("bitsandbytes") is not None
have_bitsandbytes = True
except (pkg_resources.DistributionNotFound, AssertionError):
have_bitsandbytes = False
class H2OImageCaptionLoader(ImageCaptionLoader):
"""Loader that loads the captions of an image"""
def __init__(
self,
path_images: Union[str, List[str]] = None,
blip_processor: str = None,
blip_model: str = None,
caption_gpu=True,
load_in_8bit=True,
# True doesn't seem to work, even though https://huggingface.co/Salesforce/blip2-flan-t5-xxl#in-8-bit-precision-int8
load_half=False,
load_gptq="",
use_safetensors=False,
min_new_tokens=20,
max_tokens=50,
):
        if blip_processor is None or blip_model is None:
blip_processor = "Salesforce/blip-image-captioning-base"
blip_model = "Salesforce/blip-image-captioning-base"
super().__init__(path_images, blip_processor, blip_model)
self.blip_processor = blip_processor
self.blip_model = blip_model
self.processor = None
self.model = None
self.caption_gpu = caption_gpu
self.context_class = NullContext
self.device = "cpu"
self.load_in_8bit = (
load_in_8bit and have_bitsandbytes
) # only for blip2
self.load_half = load_half
self.load_gptq = load_gptq
self.use_safetensors = use_safetensors
self.gpu_id = "auto"
# default prompt
self.prompt = "image of"
self.min_new_tokens = min_new_tokens
self.max_tokens = max_tokens
def set_context(self):
if get_device() == "cuda" and self.caption_gpu:
import torch
n_gpus = (
                torch.cuda.device_count() if torch.cuda.is_available() else 0
)
if n_gpus > 0:
self.context_class = torch.device
self.device = "cuda"
def load_model(self):
try:
import transformers
except ImportError:
raise ValueError(
"`transformers` package not found, please install with "
"`pip install transformers`."
)
self.set_context()
if self.caption_gpu:
if self.gpu_id == "auto":
# blip2 has issues with multi-GPU. Error says need to somehow set language model in device map
# device_map = 'auto'
device_map = {"": 0}
else:
if self.device == "cuda":
device_map = {"": self.gpu_id}
else:
device_map = {"": "cpu"}
else:
device_map = {"": "cpu"}
import torch
with torch.no_grad():
with self.context_class(self.device):
context_class_cast = (
NullContext if self.device == "cpu" else torch.autocast
)
with context_class_cast(self.device):
if "blip2" in self.blip_processor.lower():
from transformers import (
Blip2Processor,
Blip2ForConditionalGeneration,
)
if self.load_half and not self.load_in_8bit:
self.processor = Blip2Processor.from_pretrained(
self.blip_processor, device_map=device_map
).half()
self.model = (
Blip2ForConditionalGeneration.from_pretrained(
self.blip_model, device_map=device_map
).half()
)
else:
self.processor = Blip2Processor.from_pretrained(
self.blip_processor,
load_in_8bit=self.load_in_8bit,
device_map=device_map,
)
self.model = (
Blip2ForConditionalGeneration.from_pretrained(
self.blip_model,
load_in_8bit=self.load_in_8bit,
device_map=device_map,
)
)
else:
from transformers import (
BlipForConditionalGeneration,
BlipProcessor,
)
self.load_half = False # not supported
if self.caption_gpu:
if device_map == "auto":
# Blip doesn't support device_map='auto'
if self.device == "cuda":
if self.gpu_id == "auto":
device_map = {"": 0}
else:
device_map = {"": self.gpu_id}
else:
device_map = {"": "cpu"}
else:
device_map = {"": "cpu"}
self.processor = BlipProcessor.from_pretrained(
self.blip_processor, device_map=device_map
)
self.model = (
BlipForConditionalGeneration.from_pretrained(
self.blip_model, device_map=device_map
)
)
return self
def set_image_paths(self, path_images: Union[str, List[str]]):
"""
Load from a list of image files
"""
if isinstance(path_images, str):
self.image_paths = [path_images]
else:
self.image_paths = path_images
def load(self, prompt=None) -> List[Document]:
if self.processor is None or self.model is None:
self.load_model()
results = []
for path_image in self.image_paths:
caption, metadata = self._get_captions_and_metadata(
model=self.model,
processor=self.processor,
path_image=path_image,
prompt=prompt,
)
doc = Document(page_content=caption, metadata=metadata)
results.append(doc)
return results
def _get_captions_and_metadata(
self, model: Any, processor: Any, path_image: str, prompt=None
) -> Tuple[str, dict]:
"""
Helper function for getting the captions and metadata of an image
"""
if prompt is None:
prompt = self.prompt
try:
from PIL import Image
except ImportError:
raise ValueError(
"`PIL` package not found, please install with `pip install pillow`"
)
try:
if path_image.startswith("http://") or path_image.startswith(
"https://"
):
image = Image.open(
requests.get(path_image, stream=True).raw
).convert("RGB")
else:
image = Image.open(path_image).convert("RGB")
except Exception:
raise ValueError(f"Could not get image data for {path_image}")
import torch
with torch.no_grad():
with self.context_class(self.device):
context_class_cast = (
NullContext if self.device == "cpu" else torch.autocast
)
with context_class_cast(self.device):
if self.load_half:
inputs = processor(
image, prompt, return_tensors="pt"
).half()
else:
inputs = processor(image, prompt, return_tensors="pt")
min_length = len(prompt) // 4 + self.min_new_tokens
self.max_tokens = max(self.max_tokens, min_length)
output = model.generate(
**inputs,
min_length=min_length,
max_length=self.max_tokens,
)
caption: str = processor.decode(
output[0], skip_special_tokens=True
)
prompti = caption.find(prompt)
if prompti >= 0:
caption = caption[prompti + len(prompt) :]
metadata: dict = {"image_path": path_image}
return caption, metadata
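Because conditional BLIP captioning echoes the prompt prefix (the default "image of") back in its output, _get_captions_and_metadata drops it before returning. The prefix-stripping logic in isolation:

```python
def strip_prompt(caption, prompt):
    """Remove an echoed prompt prefix from a generated caption, if present."""
    i = caption.find(prompt)
    return caption[i + len(prompt):] if i >= 0 else caption

# prefix present: it is removed (a leading space may remain, as above)
assert strip_prompt("image of a red barn", "image of") == " a red barn"
# prefix absent: caption is returned unchanged
assert strip_prompt("a red barn", "image of") == "a red barn"
```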

View File

@@ -0,0 +1,106 @@
# for generate (gradio server) and finetune
datasets==2.13.0
sentencepiece==0.1.99
gradio==3.35.2
huggingface_hub==0.15.1
appdirs==1.4.4
fire==0.5.0
docutils==0.20.1
evaluate==0.4.0
rouge_score==0.1.2
sacrebleu==2.3.1
scikit-learn==1.2.2
alt-profanity-check==1.2.2
better-profanity==0.7.0
numpy==1.24.3
pandas==2.0.2
matplotlib==3.7.1
loralib==0.1.1
bitsandbytes==0.39.0
accelerate==0.20.3
git+https://github.com/huggingface/peft.git@0b62b4378b4ce9367932c73540349da9a41bdea8
tokenizers==0.13.3
APScheduler==3.10.1
# optional for generate
pynvml==11.5.0
psutil==5.9.5
boto3==1.26.101
botocore==1.29.101
# optional for finetune
tensorboard==2.13.0
neptune==1.2.0
# for gradio client
gradio_client==0.2.7
beautifulsoup4==4.12.2
markdown==3.4.3
# data and testing
pytest==7.2.2
pytest-xdist==3.2.1
nltk==3.8.1
textstat==0.7.3
# pandoc==2.3
#pypandoc==1.11
pypandoc_binary==1.11
openpyxl==3.1.2
lm_dataformat==0.0.20
bioc==2.0
# falcon
einops==0.6.1
instructorembedding==1.0.1
# for gpt4all .env file, but avoid worrying about imports
python-dotenv==1.0.0
text-generation==0.6.0
# for tokenization when don't have HF tokenizer
tiktoken==0.4.0
# optional: for OpenAI endpoint or embeddings (requires key)
openai==0.27.8
# optional for chat with PDF
langchain==0.0.235
pypdf==3.12.2
# avoid textract, requires old six
#textract==1.6.5
# for HF embeddings
sentence_transformers==2.2.2
# local vector db
chromadb==0.3.25
# server vector db
#pymilvus==2.2.8
# weak url support, if can't install opencv etc. If commenting this one in, comment out unstructured[local-inference] below
# unstructured==0.8.1
# strong support for images
# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
unstructured[local-inference]==0.7.4
#pdf2image==1.16.3
#pytesseract==0.3.10
pillow
pdfminer.six==20221105
urllib3
requests_file
#pdf2image==1.16.3
#pytesseract==0.3.10
tabulate==0.9.0
# FYI pandoc already part of requirements.txt
# JSONLoader, but makes some trouble for some users
# jq==1.4.1
# to check licenses
# Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
pip-licenses==4.3.0
# weaviate vector db
weaviate-client==3.22.1

View File

@@ -0,0 +1,124 @@
from typing import List, Optional, Tuple
import torch
import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from einops import rearrange
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[
torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]
]:
"""Input shape: Batch x Time x Channel
attention_mask: [bsz, q_len]
"""
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
assert past_key_value is None, "past_key_value is not supported"
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
assert not output_attentions, "output_attentions is not supported"
assert not use_cache, "use_cache is not supported"
# Flash attention codes from
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
# transform the data into the format required by flash attention
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask = attention_mask
if key_padding_mask is None:
qkv = rearrange(qkv, "b s ... -> (b s) ...")
max_s = q_len
cu_q_lens = torch.arange(
0,
(bsz + 1) * q_len,
step=q_len,
dtype=torch.int32,
device=qkv.device,
)
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
else:
nheads = qkv.shape[-2]
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
)
output_unpad = flash_attn_unpadded_qkvpacked_func(
x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(
pad_input(
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
indices,
bsz,
q_len,
),
"b s (h d) -> b s h d",
h=nheads,
)
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# [bsz, seq_len]
return attention_mask
def replace_llama_attn_with_flash_attn():
print(
"Replacing original LLaMa attention with flash attention", flush=True
)
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
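replace_llama_attn_with_flash_attn works by assigning new functions onto the class objects themselves, so every existing and future instance picks them up. A minimal sketch of the same monkey-patching mechanism with toy classes (not transformers):

```python
class Attention:
    def forward(self, x):
        return "slow:" + x

def fast_forward(self, x):
    # drop-in replacement with the same signature as Attention.forward
    return "fast:" + x

def patch():
    Attention.forward = fast_forward  # rebinds the method for all instances

a = Attention()
assert a.forward("x") == "slow:x"
patch()
assert a.forward("x") == "fast:x"  # existing instance sees the patch too
```

This is why the patch must run before the model is loaded only as a matter of hygiene, not correctness: rebinding at the class level retroactively affects instances already constructed.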

View File

@@ -0,0 +1,109 @@
import functools
def get_loaders(model_name, reward_type, llama_type=None, load_gptq=""):
# NOTE: Some models need specific new prompt_type
    # E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT"
if load_gptq:
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM
use_triton = False
        model_loader = functools.partial(
            AutoGPTQForCausalLM.from_quantized,
            quantize_config=None,
            use_triton=use_triton,
        )
        return model_loader, AutoTokenizer
if llama_type is None:
llama_type = "llama" in model_name.lower()
if llama_type:
from transformers import LlamaForCausalLM, LlamaTokenizer
return LlamaForCausalLM.from_pretrained, LlamaTokenizer
elif "distilgpt2" in model_name.lower():
from transformers import AutoModelForCausalLM, AutoTokenizer
return AutoModelForCausalLM.from_pretrained, AutoTokenizer
elif "gpt2" in model_name.lower():
from transformers import GPT2LMHeadModel, GPT2Tokenizer
return GPT2LMHeadModel.from_pretrained, GPT2Tokenizer
elif "mbart-" in model_name.lower():
from transformers import (
MBartForConditionalGeneration,
MBart50TokenizerFast,
)
return (
MBartForConditionalGeneration.from_pretrained,
MBart50TokenizerFast,
)
elif (
"t5" == model_name.lower()
or "t5-" in model_name.lower()
or "flan-" in model_name.lower()
):
from transformers import AutoTokenizer, T5ForConditionalGeneration
return T5ForConditionalGeneration.from_pretrained, AutoTokenizer
elif "bigbird" in model_name:
from transformers import (
BigBirdPegasusForConditionalGeneration,
AutoTokenizer,
)
return (
BigBirdPegasusForConditionalGeneration.from_pretrained,
AutoTokenizer,
)
elif (
"bart-large-cnn-samsum" in model_name
or "flan-t5-base-samsum" in model_name
):
from transformers import pipeline
return pipeline, "summarization"
elif (
reward_type
or "OpenAssistant/reward-model".lower() in model_name.lower()
):
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
)
return (
AutoModelForSequenceClassification.from_pretrained,
AutoTokenizer,
)
else:
from transformers import AutoTokenizer, AutoModelForCausalLM
model_loader = AutoModelForCausalLM
tokenizer_loader = AutoTokenizer
return model_loader.from_pretrained, tokenizer_loader
def get_tokenizer(
tokenizer_loader,
tokenizer_base_model,
local_files_only,
resume_download,
use_auth_token,
):
tokenizer = tokenizer_loader.from_pretrained(
tokenizer_base_model,
local_files_only=local_files_only,
resume_download=resume_download,
use_auth_token=use_auth_token,
padding_side="left",
)
tokenizer.pad_token_id = 0 # different from the eos token
# when generating, we will use the logits of right-most token to predict the next token
# so the padding should be on the left,
# e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
tokenizer.padding_side = "left" # Allow batched inference
return tokenizer
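The reason get_tokenizer forces padding_side="left" is that causal generation reads the logits of the right-most position; left padding keeps every sequence's final real token in that slot. A toy batch illustrates (pad id 0, matching pad_token_id above; plain lists stand in for tensors):

```python
def left_pad_batch(seqs, pad_id=0):
    """Pad variable-length token id lists on the left to a common length."""
    width = max(len(s) for s in seqs)
    return [[pad_id] * (width - len(s)) + s for s in seqs]

batch = left_pad_batch([[5, 6], [7, 8, 9]])
assert batch == [[0, 5, 6], [7, 8, 9]]
# the right-most column holds each sequence's final real token:
assert [row[-1] for row in batch] == [6, 9]
```

With right padding, the shorter sequence's last position would hold a pad token, and next-token prediction for that row would be conditioned on padding.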

View File

@@ -0,0 +1,208 @@
import os
import fire
from gpt_langchain import (
path_to_docs,
get_some_dbs_from_hf,
all_db_zips,
some_db_zips,
create_or_update_db,
)
from utils import get_ngpus_vis
def glob_to_db(
user_path,
chunk=True,
chunk_size=512,
verbose=False,
fail_any_exception=False,
n_jobs=-1,
url=None,
enable_captions=True,
captions_model=None,
caption_loader=None,
enable_ocr=False,
):
sources1 = path_to_docs(
user_path,
verbose=verbose,
fail_any_exception=fail_any_exception,
n_jobs=n_jobs,
chunk=chunk,
chunk_size=chunk_size,
url=url,
enable_captions=enable_captions,
captions_model=captions_model,
caption_loader=caption_loader,
enable_ocr=enable_ocr,
)
return sources1
def make_db_main(
use_openai_embedding: bool = False,
hf_embedding_model: str = None,
persist_directory: str = "db_dir_UserData",
user_path: str = "user_path",
url: str = None,
add_if_exists: bool = True,
collection_name: str = "UserData",
verbose: bool = False,
chunk: bool = True,
chunk_size: int = 512,
fail_any_exception: bool = False,
download_all: bool = False,
download_some: bool = False,
download_one: str = None,
download_dest: str = "./",
n_jobs: int = -1,
enable_captions: bool = True,
captions_model: str = "Salesforce/blip-image-captioning-base",
pre_load_caption_model: bool = False,
caption_gpu: bool = True,
enable_ocr: bool = False,
db_type: str = "chroma",
):
"""
# To make UserData db for generate.py, put pdfs, etc. into path user_path and run:
python make_db.py
# once db is made, can use in generate.py like:
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b --langchain_mode=UserData
or zip-up the db_dir_UserData and share:
zip -r db_dir_UserData.zip db_dir_UserData
# To get all db files (except large wiki_full) do:
python make_db.py --download_some=True
# To get a single db file from HF:
python make_db.py --download_one=db_dir_DriverlessAI_docs.zip
:param use_openai_embedding: Whether to use OpenAI embedding
:param hf_embedding_model: HF embedding model to use. Like generate.py, uses 'hkunlp/instructor-large' if have GPUs, else "sentence-transformers/all-MiniLM-L6-v2"
:param persist_directory: where to persist db
    :param user_path: local directory to pull documents from (ignored if url is not None)
    :param url: url to generate documents from (if None, user_path is used instead)
:param add_if_exists: Add to db if already exists, but will not add duplicate sources
:param collection_name: Collection name for new db if not adding
:param verbose: whether to show verbose messages
:param chunk: whether to chunk data
:param chunk_size: chunk size for chunking
:param fail_any_exception: whether to fail if any exception hit during ingestion of files
:param download_all: whether to download all (including 23GB Wikipedia) example databases from h2o.ai HF
:param download_some: whether to download some small example databases from h2o.ai HF
:param download_one: name of a single example database to download from h2o.ai HF
:param download_dest: Destination for downloads
:param n_jobs: Number of cores to use for ingesting multiple files
:param enable_captions: Whether to enable captions on images
:param captions_model: See generate.py
:param pre_load_caption_model: See generate.py
:param caption_gpu: Caption images on GPU if present
:param enable_ocr: Whether to enable OCR on images
:param db_type: Type of db to create. Currently only 'chroma' and 'weaviate' are supported.
:return: None
"""
db = None
# match behavior of main() in generate.py for non-HF case
n_gpus = get_ngpus_vis()
if n_gpus == 0:
if hf_embedding_model is None:
# if no GPUs, use simpler embedding model to avoid cost in time
hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
else:
if hf_embedding_model is None:
# if still None, then set default
hf_embedding_model = "hkunlp/instructor-large"
if download_all:
print("Downloading all (and unzipping): %s" % all_db_zips, flush=True)
get_some_dbs_from_hf(download_dest, db_zips=all_db_zips)
if verbose:
print("DONE", flush=True)
return db, collection_name
elif download_some:
print(
"Downloading some (and unzipping): %s" % some_db_zips, flush=True
)
get_some_dbs_from_hf(download_dest, db_zips=some_db_zips)
if verbose:
print("DONE", flush=True)
return db, collection_name
elif download_one:
print("Downloading %s (and unzipping)" % download_one, flush=True)
get_some_dbs_from_hf(
download_dest, db_zips=[[download_one, "", "Unknown License"]]
)
if verbose:
print("DONE", flush=True)
return db, collection_name
if enable_captions and pre_load_caption_model:
# Preload, else loading can be too slow, or on GPU there can be CUDA context issues
# Inside ingestion, this will disable parallel loading of multiple other kinds of docs
# However, with many images, all of them will be handled more quickly by the preloaded model on GPU
from image_captions import H2OImageCaptionLoader
caption_loader = H2OImageCaptionLoader(
None,
blip_model=captions_model,
blip_processor=captions_model,
caption_gpu=caption_gpu,
).load_model()
else:
if enable_captions:
caption_loader = "gpu" if caption_gpu else "cpu"
else:
caption_loader = False
if verbose:
print("Getting sources", flush=True)
assert (
user_path is not None or url is not None
), "Can't have both user_path and url as None"
if not url:
assert os.path.isdir(user_path), (
"user_path=%s does not exist" % user_path
)
sources = glob_to_db(
user_path,
chunk=chunk,
chunk_size=chunk_size,
verbose=verbose,
fail_any_exception=fail_any_exception,
n_jobs=n_jobs,
url=url,
enable_captions=enable_captions,
captions_model=captions_model,
caption_loader=caption_loader,
enable_ocr=enable_ocr,
)
exceptions = [x for x in sources if x.metadata.get("exception")]
print("Exceptions: %s" % exceptions, flush=True)
sources = [x for x in sources if "exception" not in x.metadata]
assert len(sources) > 0, "No sources found"
db = create_or_update_db(
db_type,
persist_directory,
collection_name,
sources,
use_openai_embedding,
add_if_exists,
verbose,
hf_embedding_model,
)
assert db is not None
if verbose:
print("DONE", flush=True)
return db, collection_name
if __name__ == "__main__":
fire.Fire(make_db_main)

File diff suppressed because it is too large


@@ -0,0 +1,403 @@
"""Load Data from a MediaWiki dump xml."""
import ast
import glob
import pickle
import uuid
from typing import List, Optional
import os
import bz2
import csv
import numpy as np
import pandas as pd
import pytest
from matplotlib import pyplot as plt
from langchain.docstore.document import Document
from langchain.document_loaders import MWDumpLoader
# path where downloaded wiki files exist, to be processed
root_path = "/data/jon/h2o-llm"
def unescape(x):
try:
x = ast.literal_eval(x)
except Exception:
try:
x = x.encode("ascii", "ignore").decode("unicode_escape")
except Exception:
pass
return x
def get_views():
# views = pd.read_csv('wiki_page_views_more_1000month.csv')
views = pd.read_csv("wiki_page_views_more_5000month.csv")
views.index = views["title"]
views = views["views"]
views = views.to_dict()
views = {str(unescape(str(k))): v for k, v in views.items()}
views2 = {k.replace("_", " "): v for k, v in views.items()}
# views has _ but pages has " "
views.update(views2)
return views
class MWDumpDirectLoader(MWDumpLoader):
def __init__(
self,
data: str,
encoding: Optional[str] = "utf8",
title_words_limit=None,
use_views=True,
verbose=True,
):
"""Initialize with file path."""
self.data = data
self.encoding = encoding
self.title_words_limit = title_words_limit
self.verbose = verbose
if use_views:
# self.views = get_views()
# faster to use global shared values
self.views = global_views
else:
self.views = None
def load(self) -> List[Document]:
"""Load from file path."""
import mwparserfromhell
import mwxml
dump = mwxml.Dump.from_page_xml(self.data)
docs = []
for page in dump.pages:
if self.views is not None and page.title not in self.views:
if self.verbose:
print("Skipped %s low views" % page.title, flush=True)
continue
for revision in page:
if self.title_words_limit is not None:
num_words = len(" ".join(page.title.split("_")).split(" "))
if num_words > self.title_words_limit:
if self.verbose:
print("Skipped %s" % page.title, flush=True)
continue
if self.verbose:
if self.views is not None:
print(
"Kept %s views: %s"
% (page.title, self.views[page.title]),
flush=True,
)
else:
print("Kept %s" % page.title, flush=True)
code = mwparserfromhell.parse(revision.text)
text = code.strip_code(
normalize=True, collapse=True, keep_template_params=False
)
title_url = str(page.title).replace(" ", "_")
metadata = dict(
title=page.title,
source="https://en.wikipedia.org/wiki/" + title_url,
id=page.id,
redirect=page.redirect,
views=self.views[page.title]
if self.views is not None
else -1,
)
metadata = {k: v for k, v in metadata.items() if v is not None}
docs.append(Document(page_content=text, metadata=metadata))
return docs
def search_index(search_term, index_filename):
byte_flag = False
data_length = start_byte = 0
index_file = open(index_filename, "r")
csv_reader = csv.reader(index_file, delimiter=":")
for line in csv_reader:
if not byte_flag and search_term == line[2]:
start_byte = int(line[0])
byte_flag = True
elif byte_flag and int(line[0]) != start_byte:
data_length = int(line[0]) - start_byte
break
index_file.close()
return start_byte, data_length
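The index scan above relies on the Wikipedia multistream index format, where each line is `start_byte:page_id:title` and pages sharing a `start_byte` live in the same bz2 block. A minimal, self-contained sketch of the same lookup, adapted to read from any text stream and using hypothetical page data:

```python
import csv
import io

# Fake multistream index: three pages in the block at byte 100,
# then blocks at bytes 900 and 1500.
index_data = """100:1:Apollo
100:2:Apollo 11
100:3:Apollo program
900:4:Zebra
1500:5:Zulu
"""

def search_index(search_term, index_file):
    """Return (start_byte, data_length) of the bz2 block containing search_term.

    data_length is the distance from the block's offset to the next
    distinct offset appearing in the index.
    """
    byte_flag = False
    data_length = start_byte = 0
    for line in csv.reader(index_file, delimiter=":"):
        if not byte_flag and search_term == line[2]:
            start_byte = int(line[0])
            byte_flag = True
        elif byte_flag and int(line[0]) != start_byte:
            data_length = int(line[0]) - start_byte
            break
    return start_byte, data_length

print(search_index("Apollo 11", io.StringIO(index_data)))  # (100, 800)
```

Note that since the delimiter is `:`, a title containing a colon splits into extra fields and `line[2]` only holds the part before the title's first colon; the code above has the same limitation.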
def get_start_bytes(index_filename):
index_file = open(index_filename, "r")
csv_reader = csv.reader(index_file, delimiter=":")
start_bytes = set()
for line in csv_reader:
start_bytes.add(int(line[0]))
index_file.close()
return sorted(start_bytes)
def get_wiki_filenames():
# requires
# wget http://ftp.acc.umu.se/mirror/wikimedia.org/dumps/enwiki/20230401/enwiki-20230401-pages-articles-multistream-index.txt.bz2
base_path = os.path.join(
root_path, "enwiki-20230401-pages-articles-multistream"
)
index_file = "enwiki-20230401-pages-articles-multistream-index.txt"
index_filename = os.path.join(base_path, index_file)
wiki_filename = os.path.join(
base_path, "enwiki-20230401-pages-articles-multistream.xml.bz2"
)
return index_filename, wiki_filename
def get_documents_by_search_term(search_term):
index_filename, wiki_filename = get_wiki_filenames()
start_byte, data_length = search_index(search_term, index_filename)
with open(wiki_filename, "rb") as wiki_file:
wiki_file.seek(start_byte)
data = bz2.BZ2Decompressor().decompress(wiki_file.read(data_length))
loader = MWDumpDirectLoader(data.decode())
documents = loader.load()
return documents
def get_one_chunk(
wiki_filename,
start_byte,
end_byte,
return_file=True,
title_words_limit=None,
use_views=True,
):
data_length = end_byte - start_byte
with open(wiki_filename, "rb") as wiki_file:
wiki_file.seek(start_byte)
data = bz2.BZ2Decompressor().decompress(wiki_file.read(data_length))
loader = MWDumpDirectLoader(
data.decode(), title_words_limit=title_words_limit, use_views=use_views
)
documents1 = loader.load()
if return_file:
base_tmp = "temp_wiki"
if not os.path.isdir(base_tmp):
os.makedirs(base_tmp, exist_ok=True)
filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.pickle")
with open(filename, "wb") as f:
pickle.dump(documents1, f)
return filename
return documents1
from joblib import Parallel, delayed
global_views = get_views()
def get_all_documents(small_test=2, n_jobs=None, use_views=True):
print("DO get all wiki docs: %s" % small_test, flush=True)
index_filename, wiki_filename = get_wiki_filenames()
start_bytes = get_start_bytes(index_filename)
end_bytes = start_bytes[1:]
start_bytes = start_bytes[:-1]
if small_test:
start_bytes = start_bytes[:small_test]
end_bytes = end_bytes[:small_test]
if n_jobs is None:
n_jobs = 5
else:
if n_jobs is None:
n_jobs = os.cpu_count() // 4
# default loky backend leads to namespace conflict problems
return_file = True  # large returns from joblib can hang
documents = Parallel(n_jobs=n_jobs, verbose=10, backend="multiprocessing")(
delayed(get_one_chunk)(
wiki_filename,
start_byte,
end_byte,
return_file=return_file,
use_views=use_views,
)
for start_byte, end_byte in zip(start_bytes, end_bytes)
)
if return_file:
# then documents really are files
files = documents.copy()
documents = []
for fil in files:
with open(fil, "rb") as f:
documents.extend(pickle.load(f))
os.remove(fil)
else:
from functools import reduce
from operator import concat
documents = reduce(concat, documents)
assert isinstance(documents, list)
print("DONE get all wiki docs", flush=True)
return documents
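The `return_file` workaround above (each worker pickles its documents to a temp file and returns only the small file path, since large return values through joblib can hang) can be sketched with the standard library alone; the squaring "work" is a stand-in for document parsing:

```python
import multiprocessing as mp
import os
import pickle
import tempfile

def worker(chunk):
    # Do the heavy work, then spill the (potentially large) result to disk
    # so only a short file path travels back across the process boundary.
    result = [x * x for x in chunk]
    fd, path = tempfile.mkstemp(suffix=".tmp.pickle")
    with os.fdopen(fd, "wb") as f:
        pickle.dump(result, f)
    return path

if __name__ == "__main__":
    chunks = [[1, 2], [3, 4]]
    with mp.Pool(2) as pool:
        files = pool.map(worker, chunks)
    # The parent merges the per-worker files and cleans up, as above.
    documents = []
    for path in files:
        with open(path, "rb") as f:
            documents.extend(pickle.load(f))
        os.remove(path)
    print(documents)  # [1, 4, 9, 16]
```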
def test_by_search_term():
search_term = "Apollo"
assert len(get_documents_by_search_term(search_term)) == 100
search_term = "Abstract (law)"
assert len(get_documents_by_search_term(search_term)) == 100
search_term = "Artificial languages"
assert len(get_documents_by_search_term(search_term)) == 100
def test_start_bytes():
index_filename, wiki_filename = get_wiki_filenames()
assert len(get_start_bytes(index_filename)) == 227850
def test_get_all_documents():
small_test = 20 # 227850
n_jobs = os.cpu_count() // 4
assert (
len(
get_all_documents(
small_test=small_test, n_jobs=n_jobs, use_views=False
)
)
== small_test * 100
)
assert (
len(
get_all_documents(
small_test=small_test, n_jobs=n_jobs, use_views=True
)
)
== 429
)
def get_one_pageviews(fil):
df1 = pd.read_csv(
fil,
sep=" ",
header=None,
names=["region", "title", "views", "foo"],
quoting=csv.QUOTE_NONE,
)
df1.index = df1["title"]
df1 = df1[df1["region"] == "en"]
df1 = df1.drop("region", axis=1)
df1 = df1.drop("foo", axis=1)
df1 = df1.drop("title", axis=1) # already index
base_tmp = "temp_wiki_pageviews"
if not os.path.isdir(base_tmp):
os.makedirs(base_tmp, exist_ok=True)
filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.csv")
df1.to_csv(filename, index=True)
return filename
def test_agg_pageviews(gen_files=False):
if gen_files:
path = os.path.join(
root_path,
"wiki_pageviews/dumps.wikimedia.org/other/pageviews/2023/2023-04",
)
files = glob.glob(os.path.join(path, "pageviews*.gz"))
# files = files[:2] # test
n_jobs = os.cpu_count() // 2
csv_files = Parallel(
n_jobs=n_jobs, verbose=10, backend="multiprocessing"
)(delayed(get_one_pageviews)(fil) for fil in files)
else:
# to continue without redoing above
csv_files = glob.glob(
os.path.join(root_path, "temp_wiki_pageviews/*.csv")
)
df_list = []
for csv_file in csv_files:
print(csv_file)
df1 = pd.read_csv(csv_file)
df_list.append(df1)
df = pd.concat(df_list, axis=0)
df = df.groupby("title")["views"].sum().reset_index()
df.to_csv("wiki_page_views.csv", index=True)
def test_reduce_pageview():
filename = "wiki_page_views.csv"
df = pd.read_csv(filename)
df = df[df["views"] < 1e7]
#
plt.hist(df["views"], bins=100, log=True)
views_avg = np.mean(df["views"])
views_median = np.median(df["views"])
plt.title("Views avg: %s median: %s" % (views_avg, views_median))
plt.savefig(filename.replace(".csv", ".png"))
plt.close()
#
views_limit = 5000
df = df[df["views"] > views_limit]
filename = "wiki_page_views_more_5000month.csv"
df.to_csv(filename, index=True)
#
plt.hist(df["views"], bins=100, log=True)
views_avg = np.mean(df["views"])
views_median = np.median(df["views"])
plt.title("Views avg: %s median: %s" % (views_avg, views_median))
plt.savefig(filename.replace(".csv", ".png"))
plt.close()
@pytest.mark.skip("Only if doing full processing again, some manual steps")
def test_do_wiki_full_all():
# Install other requirements for wiki specific conversion:
# pip install -r reqs_optional/requirements_optional_wikiprocessing.txt
# Use "Transmission" in Ubuntu to get wiki dump using torrent:
# See: https://meta.wikimedia.org/wiki/Data_dump_torrents
# E.g. magnet:?xt=urn:btih:b2c74af2b1531d0b63f1166d2011116f44a8fed0&dn=enwiki-20230401-pages-articles-multistream.xml.bz2&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337
# Get index
os.system(
"wget http://ftp.acc.umu.se/mirror/wikimedia.org/dumps/enwiki/20230401/enwiki-20230401-pages-articles-multistream-index.txt.bz2"
)
# Test that LangChain can get docs from a subset of the full wiki directly using the bzip multistream
test_get_all_documents()
# Check can search wiki multistream
test_by_search_term()
# Test can get all start bytes in index
test_start_bytes()
# Get page views, e.g. for entire month of April 2023
os.system(
"wget -b -m -k -o wget.log -e robots=off https://dumps.wikimedia.org/other/pageviews/2023/2023-04/"
)
# Aggregate page views from many files into single file
test_agg_pageviews(gen_files=True)
# Reduce page views to some limit, so processing of full wiki is not too large
test_reduce_pageview()
# Start generate.py requesting wiki_full in prep. This will use page views as referenced in get_views.
# Note: get_views() is computed once as a global to avoid very slow processing
# WARNING: Requires a lot of memory to handle; used up to 300GB system RAM at peak
"""
python generate.py --langchain_mode='wiki_full' --visible_langchain_modes="['wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']" &> lc_out.log
"""


@@ -0,0 +1,121 @@
import torch
from transformers import StoppingCriteria, StoppingCriteriaList
from enums import PromptType
class StoppingCriteriaSub(StoppingCriteria):
def __init__(
self, stops=[], encounters=[], device="cuda", model_max_length=None
):
super().__init__()
assert (
len(stops) % len(encounters) == 0
), "Number of stops must be a multiple of number of encounters"
self.encounters = encounters
self.stops = [stop.to(device) for stop in stops]
self.num_stops = [0] * len(stops)
self.model_max_length = model_max_length
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
for stopi, stop in enumerate(self.stops):
if torch.all((stop == input_ids[0][-len(stop) :])).item():
self.num_stops[stopi] += 1
if (
self.num_stops[stopi]
>= self.encounters[stopi % len(self.encounters)]
):
# print("Stopped", flush=True)
return True
if (
self.model_max_length is not None
and input_ids[0].shape[0] >= self.model_max_length
):
# critical limit
return True
# print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
# print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
return False
def get_stopping(
prompt_type,
prompt_dict,
tokenizer,
device,
human="<human>:",
bot="<bot>:",
model_max_length=None,
):
# FIXME: prompt_dict unused currently
if prompt_type in [
PromptType.human_bot.name,
PromptType.instruct_vicuna.name,
PromptType.instruct_with_end.name,
]:
if prompt_type == PromptType.human_bot.name:
# encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
# stopping only starts once output is beyond the prompt
# 1 human is enough to trigger, but 2 bots are needed, because the very first one seen will be the bot marker we added
stop_words = [human, bot, "\n" + human, "\n" + bot]
encounters = [1, 2]
elif prompt_type == PromptType.instruct_vicuna.name:
# even the below is not enough; these are generic strings with many ways to encode them
stop_words = [
"### Human:",
"""
### Human:""",
"""
### Human:
""",
"### Assistant:",
"""
### Assistant:""",
"""
### Assistant:
""",
]
encounters = [1, 2]
else:
# some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
stop_words = ["### End"]
encounters = [1]
stop_words_ids = [
tokenizer(stop_word, return_tensors="pt")["input_ids"].squeeze()
for stop_word in stop_words
]
# handle single token case
stop_words_ids = [
x if len(x.shape) > 0 else torch.tensor([x])
for x in stop_words_ids
]
stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
# avoid padding in front of tokens
if (
tokenizer._pad_token
): # use hidden variable to avoid an annoying property-logger bug
stop_words_ids = [
x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x
for x in stop_words_ids
]
# handle fake \n added
stop_words_ids = [
x[1:] if y[0] == "\n" else x
for x, y in zip(stop_words_ids, stop_words)
]
# build stopper
stopping_criteria = StoppingCriteriaList(
[
StoppingCriteriaSub(
stops=stop_words_ids,
encounters=encounters,
device=device,
model_max_length=model_max_length,
)
]
)
else:
stopping_criteria = StoppingCriteriaList()
return stopping_criteria
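The encounter logic above can be illustrated without torch or transformers: generation stops only once a stop sequence has appeared as a suffix of the token stream the required number of times (e.g. the bot marker must be seen twice, because the prompt itself ends with one). A simplified list-based sketch of that rule, not the actual `StoppingCriteria`:

```python
def should_stop(generated_ids, stops, encounters):
    """Return True once some stop sequence has ended the token stream
    `encounters` times; mirrors the suffix match in StoppingCriteriaSub."""
    counts = [0] * len(stops)
    for t in range(1, len(generated_ids) + 1):
        prefix = generated_ids[:t]  # the stream as seen after token t
        for i, stop in enumerate(stops):
            if len(prefix) >= len(stop) and prefix[-len(stop):] == stop:
                counts[i] += 1
                if counts[i] >= encounters[i % len(encounters)]:
                    return True
    return False

# bot marker = [7, 8]; 2 encounters needed since the prompt ends with one
print(should_stop([1, 7, 8, 2, 3, 7, 8], stops=[[7, 8]], encounters=[2]))  # True
print(should_stop([1, 7, 8, 2, 3], stops=[[7, 8]], encounters=[2]))        # False
```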

File diff suppressed because it is too large


@@ -0,0 +1,69 @@
from typing import Any, Dict, List, Union, Optional
import time
import queue
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult
class StreamingGradioCallbackHandler(BaseCallbackHandler):
"""
Similar to H2OTextIteratorStreamer, which is for the HF backend, but here for the LangChain backend
"""
def __init__(self, timeout: Optional[float] = None, block=True):
super().__init__()
self.text_queue = queue.SimpleQueue()
self.stop_signal = None
self.do_stop = False
self.timeout = timeout
self.block = block
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running. Clean the queue."""
while not self.text_queue.empty():
try:
self.text_queue.get(block=False)
except queue.Empty:
continue
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
self.text_queue.put(token)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.text_queue.put(self.stop_signal)
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors."""
self.text_queue.put(self.stop_signal)
def __iter__(self):
return self
def __next__(self):
while True:
try:
value = (
self.stop_signal
) # value looks unused in pycharm, not true
if self.do_stop:
print("hit stop", flush=True)
# could raise or break, maybe best to raise and make parent see if any exception in thread
raise StopIteration()
# break
value = self.text_queue.get(
block=self.block, timeout=self.timeout
)
break
except queue.Empty:
time.sleep(0.01)
if value == self.stop_signal:
raise StopIteration()
else:
return value
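A usage sketch of the queue-backed iterator pattern above: a producer (the LLM callback thread) pushes tokens while the consumer (the Gradio loop) iterates until a stop sentinel arrives. This minimal `TokenStream` is an illustrative stand-in, not the handler itself; it omits the timeout and `do_stop` handling:

```python
import queue
import threading

class TokenStream:
    """Queue-backed token iterator: a producer thread feeds it via callbacks,
    the consumer iterates until the stop sentinel is received."""

    def __init__(self):
        self.q = queue.SimpleQueue()
        self.stop_signal = None  # sentinel marking end of stream

    def on_llm_new_token(self, token):
        self.q.put(token)

    def on_llm_end(self):
        self.q.put(self.stop_signal)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.q.get()  # blocks until the producer supplies something
        if value is self.stop_signal:
            raise StopIteration
        return value

stream = TokenStream()

def producer():
    for tok in ["Hello", " ", "world"]:
        stream.on_llm_new_token(tok)
    stream.on_llm_end()

threading.Thread(target=producer).start()
print("".join(stream))  # Hello world
```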

File diff suppressed because it is too large


@@ -1,14 +1,41 @@
import torch
from transformers import AutoModelForCausalLM
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
class FirstVicuna(torch.nn.Module):
def __init__(self, model_path):
def __init__(
self,
model_path,
precision="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
if precision in ["int4", "int8"]:
print("First Vicuna applying weight quantization...")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(self, input_ids):
op = self.model(input_ids=input_ids, use_cache=True)
@@ -22,12 +49,36 @@ class FirstVicuna(torch.nn.Module):
class SecondVicuna(torch.nn.Module):
def __init__(self, model_path):
def __init__(
self,
model_path,
precision="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
if precision in ["int4", "int8"]:
print("Second Vicuna applying weight quantization...")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(
self,


@@ -1,612 +0,0 @@
from apps.language_models.src.model_wrappers.vicuna_model import (
FirstVicuna,
SecondVicuna,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.utils import (
get_vmfb_from_path,
)
from io import BytesIO
from pathlib import Path
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx, get_f16_inputs
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, AutoModelForCausalLM
import re
import torch
import torch_mlir
import os
class Vicuna(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
max_num_tokens=512,
device="cuda",
precision="fp32",
first_vicuna_mlir_path=None,
second_vicuna_mlir_path=None,
first_vicuna_vmfb_path=None,
second_vicuna_vmfb_path=None,
load_mlir_from_shark_tank=True,
low_device_memory=False,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_length = 256
self.device = device
if not load_mlir_from_shark_tank and precision in ["int4", "int8"]:
print(
"int4 and int8 are only available from the SHARK tank; please set --load_mlir_from_shark_tank. Falling back to fp32."
)
precision = "fp32"
self.precision = precision
self.first_vicuna_vmfb_path = first_vicuna_vmfb_path
self.second_vicuna_vmfb_path = second_vicuna_vmfb_path
self.first_vicuna_mlir_path = first_vicuna_mlir_path
self.second_vicuna_mlir_path = second_vicuna_mlir_path
self.load_mlir_from_shark_tank = load_mlir_from_shark_tank
self.low_device_memory = low_device_memory
self.first_vic = None
self.second_vic = None
if self.first_vicuna_mlir_path is None:
self.first_vicuna_mlir_path = self.get_model_path()
if self.second_vicuna_mlir_path is None:
self.second_vicuna_mlir_path = self.get_model_path("second")
if self.first_vicuna_vmfb_path is None:
self.first_vicuna_vmfb_path = self.get_model_path(suffix="vmfb")
if self.second_vicuna_vmfb_path is None:
self.second_vicuna_vmfb_path = self.get_model_path(
"second", "vmfb"
)
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
def get_model_path(self, model_number="first", suffix="mlir"):
safe_device = "_".join(self.device.split("-"))
if suffix == "mlir":
return Path(f"{model_number}_vicuna_{self.precision}.{suffix}")
return Path(
f"{model_number}_vicuna_{self.precision}_{safe_device}.{suffix}"
)
def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path, use_fast=False
)
return tokenizer
def get_src_model(self):
kwargs = {"torch_dtype": torch.float}
vicuna_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
return vicuna_model
def compile_first_vicuna(self):
vmfb = get_vmfb_from_path(
self.first_vicuna_vmfb_path, self.device, "tm_tensor"
)
if vmfb is not None:
return vmfb
# Compilation path needs some more work before it is functional
print(
f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with\n"
f"[DEBUG] mlir path {self.first_vicuna_mlir_path} {'exists' if self.first_vicuna_mlir_path.exists() else 'does not exist'}"
)
if self.first_vicuna_mlir_path.exists():
with open(self.first_vicuna_mlir_path, "rb") as f:
bytecode = f.read()
else:
mlir_generated = False
if self.load_mlir_from_shark_tank:
if self.precision in ["fp32", "fp16", "int8", "int4"]:
# download MLIR from shark_tank
download_public_file(
f"gs://shark_tank/vicuna/unsharded/mlir/{self.first_vicuna_mlir_path.name}",
self.first_vicuna_mlir_path.absolute(),
single_file=True,
)
if self.first_vicuna_mlir_path.exists():
with open(self.first_vicuna_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.first_vicuna_mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
else:
print(
f"Only fp32/fp16/int8/int4 mlir added to tank, generating {self.precision} mlir on device."
)
if not mlir_generated:
compilation_prompt = "0" * 17
compilation_input_ids = self.tokenizer(
compilation_prompt
).input_ids
compilation_input_ids = torch.tensor(
compilation_input_ids
).reshape([1, 19])
firstVicunaCompileInput = (compilation_input_ids,)
model = FirstVicuna(self.hf_model_path)
print(f"[DEBUG] generating torchscript graph")
ts_graph = import_with_fx(
model,
firstVicunaCompileInput,
is_f16=self.precision == "fp16",
f16_input_mask=[False, False],
mlir_type="torchscript",
)
del model
print(f"[DEBUG] generating torch mlir")
firstVicunaCompileInput = list(firstVicunaCompileInput)
firstVicunaCompileInput[0] = torch_mlir.TensorPlaceholder.like(
firstVicunaCompileInput[0], dynamic_axes=[1]
)
firstVicunaCompileInput = tuple(firstVicunaCompileInput)
module = torch_mlir.compile(
ts_graph,
[*firstVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph
def remove_constant_dim(line):
if "19x" in line:
line = re.sub("19x", "?x", line)
line = re.sub(
r"tensor.empty\(\)", "tensor.empty(%dim)", line
)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
r"tensor.empty\(%dim\)",
"tensor.empty(%dim, %dim)",
line,
)
if "arith.cmpi" in line:
line = re.sub("c19", "dim", line)
if " 19," in line:
line = re.sub(" 19,", " %dim,", line)
return line
module = str(module)
new_lines = []
print(f"[DEBUG] rewriting torch_mlir file")
for line in module.splitlines():
line = remove_constant_dim(line)
if "%0 = tensor.empty(%dim) : tensor<?xi64>" in line:
new_lines.append(
"%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>"
)
if (
"%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>"
in line
):
continue
new_lines.append(line)
module = "\n".join(new_lines)
print(f"[DEBUG] converting to bytecode")
del new_lines
module = module.encode("UTF-8")
module = BytesIO(module)
bytecode = module.read()
del module
print(f"[DEBUG] writing mlir to file")
f_ = open(self.first_vicuna_mlir_path, "wb")
f_.write(bytecode)
f_.close()
shark_module = SharkInference(
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
)
path = shark_module.save_module(
self.first_vicuna_vmfb_path.parent.absolute(),
self.first_vicuna_vmfb_path.stem,
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
print("Saved first vic vmfb at ", str(path))
shark_module.load_module(path)
return shark_module
def compile_second_vicuna(self):
vmfb = get_vmfb_from_path(
self.second_vicuna_vmfb_path, self.device, "tm_tensor"
)
if vmfb is not None:
return vmfb
# Compilation path needs some more work before it is functional
print(
f"[DEBUG] mlir path {self.second_vicuna_mlir_path} {'exists' if self.second_vicuna_mlir_path.exists() else 'does not exist'}"
)
if self.second_vicuna_mlir_path.exists():
with open(self.second_vicuna_mlir_path, "rb") as f:
bytecode = f.read()
else:
mlir_generated = False
if self.load_mlir_from_shark_tank:
if self.precision in ["fp32", "fp16", "int8", "int4"]:
# download MLIR from shark_tank
download_public_file(
f"gs://shark_tank/vicuna/unsharded/mlir/{self.second_vicuna_mlir_path.name}",
self.second_vicuna_mlir_path.absolute(),
single_file=True,
)
if self.second_vicuna_mlir_path.exists():
with open(self.second_vicuna_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.second_vicuna_mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
else:
print(
"Only fp32/fp16/int8/int4 mlir added to tank, generating mlir on device."
)
if not mlir_generated:
compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64)
pkv = tuple(
(torch.zeros([1, 32, 19, 128], dtype=torch.float32))
for _ in range(64)
)
secondVicunaCompileInput = (compilation_input_ids,) + pkv
model = SecondVicuna(self.hf_model_path)
ts_graph = import_with_fx(
model,
secondVicunaCompileInput,
is_f16=self.precision == "fp16",
f16_input_mask=[False] + [True] * 64,
mlir_type="torchscript",
)
if self.precision == "fp16":
secondVicunaCompileInput = get_f16_inputs(
secondVicunaCompileInput,
True,
f16_input_mask=[False] + [True] * 64,
)
secondVicunaCompileInput = list(secondVicunaCompileInput)
for i in range(len(secondVicunaCompileInput)):
if i != 0:
secondVicunaCompileInput[
i
] = torch_mlir.TensorPlaceholder.like(
secondVicunaCompileInput[i], dynamic_axes=[2]
)
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
module = torch_mlir.compile(
ts_graph,
[*secondVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
def remove_constant_dim(line):
if "c19_i64" in line:
line = re.sub("c19_i64", "dim_i64", line)
if "19x" in line:
line = re.sub("19x", "?x", line)
line = re.sub(
r"tensor.empty\(\)", "tensor.empty(%dim)", line
)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
r"tensor.empty\(%dim\)",
"tensor.empty(%dim, %dim)",
line,
)
if "arith.cmpi" in line:
line = re.sub("c19", "dim", line)
if " 19," in line:
line = re.sub(" 19,", " %dim,", line)
if "20x" in line:
line = re.sub("20x", "?x", line)
line = re.sub(
r"tensor.empty\(\)", "tensor.empty(%dimp1)", line
)
if " 20," in line:
line = re.sub(" 20,", " %dimp1,", line)
return line
module_str = str(module)
new_lines = []
for line in module_str.splitlines():
if "%c19_i64 = arith.constant 19 : i64" in line:
new_lines.append("%c2 = arith.constant 2 : index")
new_lines.append(
f"%dim_4_int = tensor.dim %arg1, %c2 : tensor<1x32x?x128x{'f16' if self.precision == 'fp16' else 'f32'}>"
)
new_lines.append(
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
)
continue
if "%c2 = arith.constant 2 : index" in line:
continue
if "%c20_i64 = arith.constant 20 : i64" in line:
new_lines.append("%c1_i64 = arith.constant 1 : i64")
new_lines.append(
"%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
)
new_lines.append(
"%dimp1 = arith.index_cast %c20_i64 : i64 to index"
)
continue
line = remove_constant_dim(line)
new_lines.append(line)
module_str = "\n".join(new_lines)
bytecode = module_str.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
bytecode = bytecode_stream.read()
f_ = open(self.second_vicuna_mlir_path, "wb")
f_.write(bytecode)
f_.close()
shark_module = SharkInference(
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
)
path = shark_module.save_module(
self.second_vicuna_vmfb_path.parent.absolute(),
self.second_vicuna_vmfb_path.stem,
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
print("Saved vmfb at ", str(path))
shark_module.load_module(self.second_vicuna_vmfb_path)
# self.shark_module = shark_module
return shark_module
def compile(self):
# Cannot load both models in memory at once due to memory
# constraints, so on-demand compilation is used until there
# is enough space for both models.
# Testing: DO NOT download vmfbs if not found. Modify later.
# download vmfbs for A100
if (
not self.first_vicuna_vmfb_path.exists()
and self.device in ["cuda", "cpu"]
and self.precision in ["fp32", "fp16"]
):
# combinations that are still in the works
if not (self.device == "cuda" and self.precision == "fp16"):
# Will generate vmfb on device
pass
else:
download_public_file(
f"gs://shark_tank/vicuna/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}",
self.first_vicuna_vmfb_path.absolute(),
single_file=True,
)
else:
# get first vic
# TODO: Remove after testing to avoid memory overload
# fvic_shark_model = self.compile_first_vicuna()
pass
if (
not self.second_vicuna_vmfb_path.exists()
and self.device in ["cuda", "cpu"]
and self.precision in ["fp32", "fp16"]
):
# combinations that are still in the works
if not (self.device == "cuda" and self.precision == "fp16"):
# Will generate vmfb on device
pass
else:
download_public_file(
f"gs://shark_tank/vicuna/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}",
self.second_vicuna_vmfb_path.absolute(),
single_file=True,
)
else:
# get second vic
# TODO: Remove after testing to avoid memory overload
# svic_shark_model = self.compile_second_vicuna()
pass
return None
# return tuple of shark_modules once mem is supported
# return fvic_shark_model, svic_shark_model
def decode_tokens(self, res_tokens):
for i in range(len(res_tokens)):
if not isinstance(res_tokens[i], int):
res_tokens[i] = int(res_tokens[i][0])
res_str = self.tokenizer.decode(res_tokens)
return res_str
def generate(self, prompt, cli=False):
# TODO: refactor for cleaner integration
import gc
if not self.low_device_memory:
if self.first_vic is None:
self.first_vic = self.compile_first_vicuna()
if self.second_vic is None:
self.second_vic = self.compile_second_vicuna()
res_tokens = []
params = {
"prompt": prompt,
"is_first": True,
"fv": self.compile_first_vicuna()
if self.first_vic is None
else self.first_vic,
}
generated_token_op = self.generate_new_token(params=params)
token = generated_token_op["token"]
logits = generated_token_op["logits"]
pkv = generated_token_op["pkv"]
detok = generated_token_op["detok"]
yield detok
res_tokens.append(token)
if cli:
print(f"Assistant: {detok}", end=" ", flush=True)
# Clear First Vic from Memory (main and cuda)
if self.low_device_memory:
del params
torch.cuda.empty_cache()
gc.collect()
for _ in range(self.max_num_tokens - 2):
params = {
"prompt": None,
"is_first": False,
"logits": logits,
"pkv": pkv,
"sv": self.compile_second_vicuna()
if self.second_vic is None
else self.second_vic,
}
generated_token_op = self.generate_new_token(params=params)
token = generated_token_op["token"]
logits = generated_token_op["logits"]
pkv = generated_token_op["pkv"]
detok = generated_token_op["detok"]
if token == 2:
break
res_tokens.append(token)
if detok == "<0x0A>":
if cli:
print("\n", end="", flush=True)
else:
if cli:
print(f"{detok}", end=" ", flush=True)
if len(res_tokens) % 3 == 0:
part_str = self.decode_tokens(res_tokens)
yield part_str
if self.device == "cuda":
# free the second-vic reference held in params before emptying the cache
# (the original `del sec_vic` referenced an undefined name)
del params, pkv, logits
torch.cuda.empty_cache()
gc.collect()
res_str = self.decode_tokens(res_tokens)
# print(f"[DEBUG] final output : \n{res_str}")
yield res_str
def generate_new_token(self, params, debug=False):
def forward_first(first_vic, prompt, cache_outputs=False):
input_ids = self.tokenizer(prompt).input_ids
input_id_len = len(input_ids)
input_ids = torch.tensor(input_ids)
input_ids = input_ids.reshape([1, input_id_len])
firstVicunaInput = (input_ids,)
assert first_vic is not None
output_first_vicuna = first_vic("forward", firstVicunaInput)
output_first_vicuna_tensor = torch.tensor(output_first_vicuna[1:])
logits_first_vicuna = torch.tensor(output_first_vicuna[0])
if cache_outputs:
torch.save(
logits_first_vicuna, "logits_first_vicuna_tensor.pt"
)
torch.save(
output_first_vicuna_tensor, "output_first_vicuna_tensor.pt"
)
token = torch.argmax(
torch.tensor(logits_first_vicuna)[:, -1, :], dim=1
)
return token, logits_first_vicuna, output_first_vicuna_tensor
def forward_second(sec_vic, inputs=None, load_inputs=False):
if inputs is not None:
logits = inputs[0]
pkv = inputs[1:]
elif load_inputs:
pkv = torch.load("output_first_vicuna_tensor.pt")
pkv = tuple(torch.tensor(x) for x in pkv)
logits = torch.load("logits_first_vicuna_tensor.pt")
else:
print(
"Either inputs must be given, or load_inputs must be true"
)
return None
token = torch.argmax(torch.tensor(logits)[:, -1, :], dim=1)
token = token.to(torch.int64).reshape([1, 1])
secondVicunaInput = (token,) + tuple(pkv)
secondVicunaOutput = sec_vic("forward", secondVicunaInput)
new_pkv = secondVicunaOutput[1:]
new_logits = secondVicunaOutput[0]
new_token = torch.argmax(torch.tensor(new_logits)[:, -1, :], dim=1)
return new_token, new_logits, new_pkv
is_first = params["is_first"]
if is_first:
prompt = params["prompt"]
fv = params["fv"]
token, logits, pkv = forward_first(
fv, # self.shark_model[0],
prompt=prompt,
cache_outputs=False,
)
else:
_logits = params["logits"]
_pkv = params["pkv"]
inputs = (_logits,) + tuple(_pkv)
sv = params["sv"]
token, logits, pkv = forward_second(
sv, # self.shark_model[1],
inputs=inputs,
load_inputs=False,
)
detok = self.tokenizer.decode(token)
if debug:
print(
f"[DEBUG] is_first: {is_first} |"
f" token : {token} | detok : {detok}"
)
ret_dict = {
"token": token,
"logits": logits,
"pkv": pkv,
"detok": detok,
}
return ret_dict
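Both forward passes above pick the next token by plain greedy decoding: argmax over the logits of the last sequence position. A minimal pure-Python sketch of that step (the nested list stands in for the `[batch, seq, vocab]` torch tensor used here):

```python
def greedy_next_token(logits):
    """Pick the highest-scoring vocabulary id from the last sequence
    position of the first batch entry (greedy decoding step)."""
    last = logits[0][-1]  # scores for the final position
    return max(range(len(last)), key=last.__getitem__)


# toy vocab of size 2: last position scores [0.3, 0.2] -> id 0
print(greedy_next_token([[[0.1, 0.9], [0.3, 0.2]]]))  # -> 0
```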
def autocomplete(self, prompt):
# use First vic alone to complete a story / prompt / sentence.
pass


@@ -1,686 +0,0 @@
from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
FirstVicunaLayer,
SecondVicunaLayer,
CompiledVicunaLayer,
ShardedVicunaModel,
LMHead,
LMHeadCompiled,
VicunaEmbedding,
VicunaEmbeddingCompiled,
VicunaNorm,
VicunaNormCompiled,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from shark.shark_importer import import_with_fx
from io import BytesIO
from pathlib import Path
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
from torch_mlir import TensorPlaceholder
import re
import torch
import torch_mlir
import os
import json
class Vicuna(SharkLLMBase):
# Class representing Sharded Vicuna Model
def __init__(
self,
model_name,
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
max_num_tokens=512,
device="cuda",
precision="fp32",
config_json=None,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_length = 256
self.device = device
self.precision = precision
self.tokenizer = self.get_tokenizer()
self.config = config_json
self.shark_model = self.compile(device=device)
def get_tokenizer(self):
# Retrieve the tokenizer from Huggingface
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path, use_fast=False
)
return tokenizer
def get_src_model(self):
# Retrieve the torch model from Huggingface
kwargs = {"torch_dtype": torch.float}
vicuna_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
return vicuna_model
def write_in_dynamic_inputs0(self, module, dynamic_input_size):
# Current solution for ensuring mlir files support dynamic inputs
# TODO find a more elegant way to implement this
new_lines = []
for line in module.splitlines():
line = re.sub(f"{dynamic_input_size}x", "?x", line)
if "?x" in line:
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
)
if "arith.cmpi" in line:
line = re.sub(f"c{dynamic_input_size}", "dim", line)
new_lines.append(line)
new_module = "\n".join(new_lines)
return new_module
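The rewriting above can be exercised on a single toy line to see what the substitutions do. This is only a sketch of the same regex pattern, assuming a sample static size of 137 (matching SAMPLE_INPUT_LEN used later in this file); the input line is invented, not taken from a real module:

```python
import re


def make_dim_dynamic(line, size):
    """Replace a hard-coded dimension `size` in one line of MLIR text
    with a dynamic `?` marker, mirroring write_in_dynamic_inputs0."""
    line = re.sub(f"{size}x", "?x", line)
    if "?x" in line:
        # dynamic tensors need their dim passed to tensor.empty
        line = re.sub(r"tensor\.empty\(\)", "tensor.empty(%dim)", line)
        line = re.sub(f" {size},", " %dim,", line)
    return line


print(make_dim_dynamic("%0 = tensor.empty() : tensor<1x137x4096xf32>", 137))
# -> %0 = tensor.empty(%dim) : tensor<1x?x4096xf32>
```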
def write_in_dynamic_inputs1(self, module, dynamic_input_size):
new_lines = []
for line in module.splitlines():
if "dim_42 =" in line:
continue
if f"%c{dynamic_input_size}_i64 =" in line:
new_lines.append(
"%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>"
)
new_lines.append(
f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64"
)
continue
line = re.sub(f"{dynamic_input_size}x", "?x", line)
line = re.sub(f"%c{dynamic_input_size}_i64", "%dim_42_i64", line)
if "?x" in line:
line = re.sub(
"tensor.empty\(\)", "tensor.empty(%dim_42)", line
)
line = re.sub(f" {dynamic_input_size},", " %dim_42,", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim_42\)",
"tensor.empty(%dim_42, %dim_42)",
line,
)
if "arith.cmpi" in line:
line = re.sub(f"c{dynamic_input_size}", "dim_42", line)
new_lines.append(line)
new_module = "\n".join(new_lines)
return new_module
def combine_mlir_scripts(
self, first_vicuna_mlir, second_vicuna_mlir, output_name
):
maps1 = []
maps2 = []
constants = set()
f1 = []
f2 = []
for line in first_vicuna_mlir.splitlines():
if re.search("#map\d*\s*=", line):
maps1.append(line)
elif re.search("arith.constant", line):
constants.add(line)
elif not re.search("module", line):
line = re.sub("forward", "first_vicuna_forward", line)
f1.append(line)
f1 = f1[:-1]
for i, map_line in enumerate(maps1):
map_var = map_line.split(" ")[0]
map_line = re.sub(f"{map_var}(?!\d)", map_var + "_0", map_line)
maps1[i] = map_line
f1 = [
re.sub(f"{map_var}(?!\d)", map_var + "_0", func_line)
for func_line in f1
]
for line in second_vicuna_mlir.splitlines():
if re.search("#map\d*\s*=", line):
maps2.append(line)
elif "global_seed" in line:
continue
elif re.search("arith.constant", line):
constants.add(line)
elif not re.search("module", line):
line = re.sub("forward", "second_vicuna_forward", line)
f2.append(line)
f2 = f2[:-1]
for i, map_line in enumerate(maps2):
map_var = map_line.split(" ")[0]
map_line = re.sub(f"{map_var}(?!\d)", map_var + "_1", map_line)
maps2[i] = map_line
f2 = [
re.sub(f"{map_var}(?!\d)", map_var + "_1", func_line)
for func_line in f2
]
module_start = (
'module attributes {torch.debug_module_name = "_lambda"} {'
)
module_end = "}"
global_vars = []
vnames = []
vdtypes = []
global_var_loading1 = []
global_var_loading2 = []
for constant in list(constants):
vname, vbody = constant.split("=")
vname = re.sub("%", "", vname)
vname = vname.strip()
vbody = re.sub("arith.constant", "", vbody)
vbody = vbody.strip()
vdtype = vbody.split(":")[1].strip()
fixed_vdtype = vdtype
vdtypes.append(vdtype)
vdtype = re.sub("\d{1,}x", "?x", vdtype)
vnames.append(vname)
global_vars.append(
f"ml_program.global public @{vname}({vbody}) : {fixed_vdtype}"
)
global_var_loading1.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
)
global_var_loading2.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
)
new_f1, new_f2 = [], []
for line in f1:
if "func.func" in line:
new_f1.append(line)
for global_var in global_var_loading1:
new_f1.append(global_var)
else:
new_f1.append(line)
for line in f2:
if "func.func" in line:
new_f2.append(line)
for global_var in global_var_loading2:
new_f2.append(global_var)
else:
new_f2.append(line)
f1 = new_f1
f2 = new_f2
whole_string = "\n".join(
maps1
+ maps2
+ [module_start]
+ global_vars
+ f1
+ f2
+ [module_end]
)
f_ = open(output_name, "w+")
f_.write(whole_string)
f_.close()
return whole_string
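The `#map` renaming above hinges on the negative lookahead `(?!\d)`, which keeps `#map1` from also matching inside `#map10` while the two modules' maps get distinct suffixes. A minimal sketch of that trick (the map names here are hypothetical):

```python
import re


def suffix_map(text, map_var, suffix):
    """Append `suffix` to every occurrence of `map_var` that is NOT
    followed by another digit, so "#map1" never matches inside "#map12"."""
    return re.sub(rf"{map_var}(?!\d)", map_var + suffix, text)


print(suffix_map("#map1 = ...; uses #map1 and #map12", "#map1", "_0"))
# -> #map1_0 = ...; uses #map1_0 and #map12
```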
def compile_vicuna_layer(
self,
vicuna_layer,
hidden_states,
attention_mask,
position_ids,
past_key_value0=None,
past_key_value1=None,
):
# Compile a hidden decoder layer of vicuna
if past_key_value0 is None and past_key_value1 is None:
model_inputs = (hidden_states, attention_mask, position_ids)
else:
model_inputs = (
hidden_states,
attention_mask,
position_ids,
past_key_value0,
past_key_value1,
)
mlir_bytecode = import_with_fx(
vicuna_layer,
model_inputs,
is_f16=self.precision == "fp16",
f16_input_mask=[False, False],
mlir_type="torchscript",
)
return mlir_bytecode
def get_device_index(self, layer_string):
# Get the device index from the config file
# In the event that different device indices are assigned to
# different parts of a layer, a majority vote will be taken and
# everything will be run on the most commonly used device
if self.config is None:
return None
idx_votes = {}
for key in self.config.keys():
if re.search(layer_string, key):
if int(self.config[key]["gpu"]) in idx_votes.keys():
idx_votes[int(self.config[key]["gpu"])] += 1
else:
idx_votes[int(self.config[key]["gpu"])] = 1
device_idx = max(idx_votes, key=idx_votes.get)
return device_idx
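The majority vote above can be reproduced with a plain dictionary tally. This sketch mirrors the structure `get_device_index` reads (config keys matched by a layer pattern, each entry carrying a "gpu" index); the config contents are invented for illustration:

```python
import re


def majority_device(config, layer_pattern):
    """Tally the 'gpu' index of every config entry whose key matches
    layer_pattern and return the most common one (majority vote)."""
    votes = {}
    for key, entry in config.items():
        if re.search(layer_pattern, key):
            idx = int(entry["gpu"])
            votes[idx] = votes.get(idx, 0) + 1
    return max(votes, key=votes.get) if votes else None


cfg = {
    "layers.0.attn": {"gpu": "1"},
    "layers.0.mlp": {"gpu": "1"},
    "layers.0.norm": {"gpu": "0"},
}
print(majority_device(cfg, r"layers\.0"))  # -> 1
```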
def compile_lmhead(
self, lmh, hidden_states, device="cpu", device_idx=None
):
# compile the lm head of the vicuna model
# This can be used for both first and second vicuna, so only needs to be run once
mlir_path = Path(f"lmhead.mlir")
vmfb_path = Path(f"lmhead.vmfb")
if mlir_path.exists():
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
else:
hidden_states = torch_mlir.TensorPlaceholder.like(
hidden_states, dynamic_axes=[1]
)
module = torch_mlir.compile(
lmh,
(hidden_states,),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(mlir_path, "wb")
f_.write(bytecode)
f_.close()
shark_module = SharkInference(
bytecode,
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
mmap=False,
)
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
else:
shark_module.save_module(module_name="lmhead")
shark_module.load_module(vmfb_path)
compiled_module = LMHeadCompiled(shark_module)
return compiled_module
def compile_norm(self, fvn, hidden_states, device="cpu", device_idx=None):
# compile the normalization layer of the vicuna model
# This can be used for both first and second vicuna, so only needs to be run once
mlir_path = Path(f"norm.mlir")
vmfb_path = Path(f"norm.vmfb")
if mlir_path.exists():
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
else:
hidden_states = torch_mlir.TensorPlaceholder.like(
hidden_states, dynamic_axes=[1]
)
module = torch_mlir.compile(
fvn,
(hidden_states,),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(mlir_path, "wb")
f_.write(bytecode)
f_.close()
shark_module = SharkInference(
bytecode,
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
mmap=False,
)
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
else:
shark_module.save_module(module_name="norm")
shark_module.load_module(vmfb_path)
compiled_module = VicunaNormCompiled(shark_module)
return compiled_module
def compile_embedding(self, fve, input_ids, device="cpu", device_idx=None):
# compile the embedding layer of the vicuna model
# This can be used for both first and second vicuna, so only needs to be run once
mlir_path = Path(f"embedding.mlir")
vmfb_path = Path(f"embedding.vmfb")
if mlir_path.exists():
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
else:
input_ids = torch_mlir.TensorPlaceholder.like(
input_ids, dynamic_axes=[1]
)
module = torch_mlir.compile(
fve,
(input_ids,),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(mlir_path, "wb")
f_.write(bytecode)
f_.close()
shark_module = SharkInference(
bytecode,
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
mmap=False,
)
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
else:
shark_module.save_module(module_name="embedding")
shark_module.load_module(vmfb_path)
compiled_module = VicunaEmbeddingCompiled(shark_module)
return compiled_module
def compile_to_vmfb_one_model(
self, inputs0, layers0, inputs1, layers1, device="cpu"
):
mlirs, modules = [], []
assert len(layers0) == len(layers1)
for layer0, layer1, idx in zip(layers0, layers1, range(len(layers0))):
mlir_path = Path(f"{idx}_full.mlir")
vmfb_path = Path(f"{idx}_full.vmfb")
# if vmfb_path.exists():
# continue
if mlir_path.exists():
# print(f"Found layer {idx} mlir")
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
mlirs.append(bytecode)
else:
hidden_states_placeholder0 = TensorPlaceholder.like(
inputs0[0], dynamic_axes=[1]
)
attention_mask_placeholder0 = TensorPlaceholder.like(
inputs0[1], dynamic_axes=[3]
)
position_ids_placeholder0 = TensorPlaceholder.like(
inputs0[2], dynamic_axes=[1]
)
hidden_states_placeholder1 = TensorPlaceholder.like(
inputs1[0], dynamic_axes=[1]
)
attention_mask_placeholder1 = TensorPlaceholder.like(
inputs1[1], dynamic_axes=[3]
)
position_ids_placeholder1 = TensorPlaceholder.like(
inputs1[2], dynamic_axes=[1]
)
pkv0_placeholder = TensorPlaceholder.like(
inputs1[3], dynamic_axes=[2]
)
pkv1_placeholder = TensorPlaceholder.like(
inputs1[4], dynamic_axes=[2]
)
print(f"Compiling layer {idx} mlir")
ts_g = self.compile_vicuna_layer(
layer0, inputs0[0], inputs0[1], inputs0[2]
)
module0 = torch_mlir.compile(
ts_g,
(
hidden_states_placeholder0,
inputs0[1],
inputs0[2],
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
module0 = self.write_in_dynamic_inputs0(str(module0), 137)
ts_g = self.compile_vicuna_layer(
layer1,
inputs1[0],
inputs1[1],
inputs1[2],
inputs1[3],
inputs1[4],
)
module1 = torch_mlir.compile(
ts_g,
(
inputs1[0],
attention_mask_placeholder1,
inputs1[2],
pkv0_placeholder,
pkv1_placeholder,
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
module1 = self.write_in_dynamic_inputs1(str(module1), 138)
module_combined = self.combine_mlir_scripts(
module0, module1, f"{idx}_full.mlir"
)
mlirs.append(module_combined)
if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
device_idx = self.get_device_index(
f"first_vicuna.model.model.layers.{idx}[\s.$]"
)
module = SharkInference(
None,
device=device,
device_idx=idx % 4,
mlir_dialect="tm_tensor",
mmap=False,
)
module.load_module(vmfb_path)
else:
print(f"Compiling layer {idx} vmfb")
device_idx = self.get_device_index(
f"first_vicuna.model.model.layers.{idx}[\s.$]"
)
module = SharkInference(
mlirs[idx],
device=device,
device_idx=idx % 4,
mlir_dialect="tm_tensor",
mmap=False,
)
module.save_module(
module_name=f"{idx}_full",
extra_args=[
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
module.load_module(vmfb_path)
modules.append(module)
return mlirs, modules
def get_sharded_model(self, device="cpu"):
# SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is
# currently an incredibly hacky process; please don't change it.
SAMPLE_INPUT_LEN = 137
vicuna_model = self.get_src_model()
placeholder_input0 = (
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64),
)
placeholder_input1 = (
torch.zeros([1, 1, 4096]),
torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]),
torch.zeros([1, 1], dtype=torch.int64),
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
)
norm = VicunaNorm(vicuna_model.model.norm)
device_idx = self.get_device_index(
r"vicuna\.model\.model\.norm(?:\.|\s|$)"
)
print(device_idx)
norm = self.compile_norm(
norm,
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
device=self.device,
device_idx=device_idx,
)
embeddings = VicunaEmbedding(vicuna_model.model.embed_tokens)
device_idx = self.get_device_index(
r"vicuna\.model\.model\.embed_tokens(?:\.|\s|$)"
)
print(device_idx)
embeddings = self.compile_embedding(
embeddings,
(torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64)),
device=self.device,
device_idx=device_idx,
)
lmhead = LMHead(vicuna_model.lm_head)
device_idx = self.get_device_index(
r"vicuna\.model\.lm_head(?:\.|\s|$)"
)
print(device_idx)
lmhead = self.compile_lmhead(
lmhead,
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
device=self.device,
device_idx=device_idx,
)
layers0 = [
FirstVicunaLayer(layer) for layer in vicuna_model.model.layers
]
layers1 = [
SecondVicunaLayer(layer) for layer in vicuna_model.model.layers
]
_, modules = self.compile_to_vmfb_one_model(
placeholder_input0,
layers0,
placeholder_input1,
layers1,
device=device,
)
shark_layers = [CompiledVicunaLayer(m) for m in modules]
sharded_model = ShardedVicunaModel(
vicuna_model,
shark_layers,
lmhead,
embeddings,
norm,
)
return sharded_model
def compile(self, device="cpu"):
return self.get_sharded_model(device=device)
def generate(self, prompt, cli=False):
# TODO: refactor for cleaner integration
tokens_generated = []
_past_key_values = None
_token = None
detoks_generated = []
for iteration in range(self.max_num_tokens):
params = {
"prompt": prompt,
"is_first": iteration == 0,
"token": _token,
"past_key_values": _past_key_values,
}
generated_token_op = self.generate_new_token(params=params)
_token = generated_token_op["token"]
_past_key_values = generated_token_op["past_key_values"]
_detok = generated_token_op["detok"]
if _token == 2:
break
detoks_generated.append(_detok)
tokens_generated.append(_token)
for i in range(len(tokens_generated)):
if not isinstance(tokens_generated[i], int):
tokens_generated[i] = int(tokens_generated[i][0])
result_output = self.tokenizer.decode(tokens_generated)
return result_output
def generate_new_token(self, params):
is_first = params["is_first"]
if is_first:
prompt = params["prompt"]
input_ids = self.tokenizer(prompt).input_ids
input_id_len = len(input_ids)
input_ids = torch.tensor(input_ids)
input_ids = input_ids.reshape([1, input_id_len])
output = self.shark_model.forward(input_ids, is_first=is_first)
else:
token = params["token"]
past_key_values = params["past_key_values"]
input_ids = [token]
input_id_len = len(input_ids)
input_ids = torch.tensor(input_ids)
input_ids = input_ids.reshape([1, input_id_len])
output = self.shark_model.forward(
input_ids, past_key_values=past_key_values, is_first=is_first
)
_logits = output["logits"]
_past_key_values = output["past_key_values"]
_token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
_detok = self.tokenizer.decode(_token)
ret_dict = {
"token": _token,
"detok": _detok,
"past_key_values": _past_key_values,
}
print(f" token : {_token} | detok : {_detok}")
return ret_dict
def autocomplete(self, prompt):
# use First vic alone to complete a story / prompt / sentence.
pass


@@ -58,11 +58,8 @@ def main():
ondemand=args.ondemand,
)
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
for current_batch in range(args.batch_count):
if current_batch > 0:
seed = -1
seed = utils.sanitize_seed(seed)
start_time = time.time()
generated_imgs = inpaint_obj.generate_images(
args.prompts,
@@ -76,7 +73,7 @@ def main():
args.inpaint_full_res_padding,
args.steps,
args.guidance_scale,
seed,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
@@ -90,7 +87,10 @@ def main():
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
)
text_output += f"seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)


@@ -51,11 +51,8 @@ def main():
ondemand=args.ondemand,
)
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
for current_batch in range(args.batch_count):
if current_batch > 0:
seed = -1
seed = utils.sanitize_seed(seed)
start_time = time.time()
generated_imgs = outpaint_obj.generate_images(
args.prompts,
@@ -74,7 +71,7 @@ def main():
args.width,
args.steps,
args.guidance_scale,
seed,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
@@ -88,7 +85,10 @@ def main():
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
)
text_output += f"seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)


@@ -223,7 +223,8 @@ def lora_train(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
"Please provide either custom model or huggingface model ID, both must not be "
"empty.",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:


@@ -42,11 +42,8 @@ def main():
ondemand=args.ondemand,
)
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
for current_batch in range(args.batch_count):
if current_batch > 0:
seed = -1
seed = utils.sanitize_seed(seed)
start_time = time.time()
generated_imgs = txt2img_obj.generate_images(
args.prompts,
@@ -56,7 +53,7 @@ def main():
args.width,
args.steps,
args.guidance_scale,
seed,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
@@ -70,7 +67,12 @@ def main():
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
)
text_output += (
f"seed={seeds[current_batch]}, size={args.height}x{args.width}"
)
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)


@@ -21,7 +21,7 @@ if __name__ == "__main__":
print("Flag --img_path is required.")
exit()
# When the models get uploaded, it should be default to False.
# When the models get uploaded, it should be defaulted to False.
args.import_mlir = True
cpu_scheduling = not args.scheduler.startswith("Shark")


@@ -1,58 +1,13 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
from PyInstaller.utils.hooks import collect_submodules
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
datas = []
datas += collect_data_files('torch')
datas += copy_metadata('torch')
datas += copy_metadata('tqdm')
datas += copy_metadata('regex')
datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += copy_metadata('Pillow')
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('opencv-python')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += collect_data_files('tkinter')
datas += collect_data_files('webview')
datas += collect_data_files('sentencepiece')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),
( 'src/utils/resources/opt_flags.json', 'resources' ),
( 'src/utils/resources/base_model.json', 'resources' ),
( 'web/ui/css/*', 'ui/css' ),
( 'web/ui/logos/*', 'logos' )
]
from apps.stable_diffusion.shark_studio_imports import pathex, datas, hiddenimports
binaries = []
block_cipher = None
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
a = Analysis(
['web/index.py'],
pathex=['.'],
pathex=pathex,
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
@@ -74,7 +29,7 @@ exe = EXE(
a.zipfiles,
a.datas,
[],
name='shark_sd',
name='nodai_shark_studio',
debug=False,
bootloader_ignore_signals=False,
strip=False,


@@ -29,6 +29,7 @@ datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += collect_data_files('py-cpuinfo')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),


@@ -0,0 +1,64 @@
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
from PyInstaller.utils.hooks import collect_submodules
import sys
sys.setrecursionlimit(sys.getrecursionlimit() * 5)
# python path for pyinstaller
pathex = [".", "./apps/language_models/langchain"]
# datafiles for pyinstaller
datas = []
datas += collect_data_files("torch")
datas += copy_metadata("torch")
datas += copy_metadata("tqdm")
datas += copy_metadata("regex")
datas += copy_metadata("requests")
datas += copy_metadata("packaging")
datas += copy_metadata("filelock")
datas += copy_metadata("numpy")
datas += copy_metadata("importlib_metadata")
datas += copy_metadata("torch-mlir")
datas += copy_metadata("omegaconf")
datas += copy_metadata("safetensors")
datas += copy_metadata("Pillow")
datas += copy_metadata("sentencepiece")
datas += copy_metadata("pyyaml")
datas += collect_data_files("tokenizers")
datas += collect_data_files("tiktoken")
datas += collect_data_files("accelerate")
datas += collect_data_files("diffusers")
datas += collect_data_files("transformers")
datas += collect_data_files("pytorch_lightning")
datas += collect_data_files("opencv_python")
datas += collect_data_files("skimage")
datas += collect_data_files("gradio")
datas += collect_data_files("gradio_client")
datas += collect_data_files("iree")
datas += collect_data_files("google_cloud_storage")
datas += collect_data_files("shark")
datas += collect_data_files("tkinter")
datas += collect_data_files("webview")
datas += collect_data_files("sentencepiece")
datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += [
("src/utils/resources/prompts.json", "resources"),
("src/utils/resources/model_db.json", "resources"),
("src/utils/resources/opt_flags.json", "resources"),
("src/utils/resources/base_model.json", "resources"),
("web/ui/css/*", "ui/css"),
("web/ui/logos/*", "logos"),
]
# hidden imports for pyinstaller
hiddenimports = ["shark", "shark.shark_inference", "apps"]
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [
x for x in collect_submodules("transformers") if "tests" not in x
]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]


@@ -45,6 +45,7 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
new_shape.append(width * mul_val)
elif "/" in shape[i]:
import math
div_val = int(shape[i].split("/")[1])
if "batch_size" in shape[i]:
new_shape.append(math.ceil(batch_size / div_val))
@@ -59,7 +60,9 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
def check_compilation(model, model_name):
if not model:
raise Exception(f"Could not compile {model_name}. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues")
raise Exception(
f"Could not compile {model_name}. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues"
)
class SharkifyStableDiffusionModel:
@@ -97,16 +100,22 @@ class SharkifyStableDiffusionModel:
if "civitai" in custom_weights:
weights_id = custom_weights.split("/")[-1]
# TODO: use model name and identify file type by civitai rest api
weights_path = str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
weights_path = (
str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
)
if not os.path.isfile(weights_path):
subprocess.run(["wget", custom_weights, "-O", weights_path])
subprocess.run(
["wget", custom_weights, "-O", weights_path]
)
custom_weights = get_path_to_diffusers_checkpoint(weights_path)
self.custom_weights = weights_path
else:
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
custom_weights = get_path_to_diffusers_checkpoint(
custom_weights
)
self.model_id = model_id if custom_weights == "" else custom_weights
# TODO: remove the following line when stable-diffusion-2-1 works
if self.model_id == "stabilityai/stable-diffusion-2-1":
@@ -126,7 +135,7 @@ class SharkifyStableDiffusionModel:
+ "_"
+ precision
)
print(f'use_tuned? sharkify: {use_tuned}')
print(f"use_tuned? sharkify: {use_tuned}")
self.use_tuned = use_tuned
if use_tuned:
self.model_name = self.model_name + "_tuned"
@@ -163,14 +172,24 @@ class SharkifyStableDiffusionModel:
def get_extended_name_for_all_model(self):
model_name = {}
sub_model_list = ["clip", "unet", "unet512", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
sub_model_list = [
"clip",
"unet",
"unet512",
"stencil_unet",
"vae",
"vae_encode",
"stencil_adaptor",
]
index = 0
for model in sub_model_list:
sub_model = model
model_config = self.model_name
if "vae" == model:
if self.custom_vae != "":
model_config = model_config + get_path_stem(self.custom_vae)
model_config = model_config + get_path_stem(
self.custom_vae
)
if self.base_vae:
sub_model = "base_vae"
if "stencil_adaptor" == model and self.use_stencil is not None:
@@ -197,7 +216,11 @@ class SharkifyStableDiffusionModel:
tensor = None
if isinstance(shape, list):
clean_shape = replace_shape_str(
shape, self.max_len, self.width, self.height, self.batch_size
shape,
self.max_len,
self.width,
self.height,
self.batch_size,
)
if dtype == torch.int64:
tensor = torch.randint(1, 3, tuple(clean_shape))
@@ -209,10 +232,12 @@ class SharkifyStableDiffusionModel:
sys.exit("shape isn't specified correctly.")
input_map.append(tensor)
return input_map
def get_vae_encode(self):
class VaeEncodeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False
):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_id,
@@ -226,7 +251,11 @@ class SharkifyStableDiffusionModel:
vae_encode = VaeEncodeModel()
inputs = tuple(self.inputs["vae_encode"])
is_f16 = True if not self.is_upscaler and self.precision == "fp16" else False
is_f16 = (
True
if not self.is_upscaler and self.precision == "fp16"
else False
)
shark_vae_encode, vae_encode_mlir = compile_through_fx(
vae_encode,
inputs,
@@ -243,7 +272,13 @@ class SharkifyStableDiffusionModel:
def get_vae(self):
class VaeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, base_vae=self.base_vae, custom_vae=self.custom_vae, low_cpu_mem_usage=False):
def __init__(
self,
model_id=self.model_id,
base_vae=self.base_vae,
custom_vae=self.custom_vae,
low_cpu_mem_usage=False,
):
super().__init__()
self.vae = None
if custom_vae == "":
@@ -279,7 +314,11 @@ class SharkifyStableDiffusionModel:
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
inputs = tuple(self.inputs["vae"])
is_f16 = True if not self.is_upscaler and self.precision == "fp16" else False
is_f16 = (
True
if not self.is_upscaler and self.precision == "fp16"
else False
)
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
if self.debug:
os.makedirs(save_dir, exist_ok=True)
@@ -303,7 +342,10 @@ class SharkifyStableDiffusionModel:
def get_controlled_unet(self):
class ControlledUnetModel(torch.nn.Module):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
@@ -316,12 +358,43 @@ class SharkifyStableDiffusionModel:
self.in_channels = self.unet.in_channels
self.train(False)
def forward( self, latent, timestep, text_embedding, guidance_scale, control1,
control2, control3, control4, control5, control6, control7,
control8, control9, control10, control11, control12, control13,
def forward(
self,
latent,
timestep,
text_embedding,
guidance_scale,
control1,
control2,
control3,
control4,
control5,
control6,
control7,
control8,
control9,
control10,
control11,
control12,
control13,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
db_res_samples = tuple([ control1, control2, control3, control4, control5, control6, control7, control8, control9, control10, control11, control12,])
db_res_samples = tuple(
[
control1,
control2,
control3,
control4,
control5,
control6,
control7,
control8,
control9,
control10,
control11,
control12,
]
)
mb_res_samples = control13
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
@@ -342,7 +415,25 @@ class SharkifyStableDiffusionModel:
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True,]
input_mask = [
True,
True,
True,
False,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
]
shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
unet,
inputs,
@@ -386,16 +477,23 @@ class SharkifyStableDiffusionModel:
stencil_image = torch.cat(
[stencil_image_input] * 2
) # needs to be same as controlledUNET latents
down_block_res_samples, mid_block_res_sample = self.cnet.forward(
(
down_block_res_samples,
mid_block_res_sample,
) = self.cnet.forward(
latents,
timestep,
encoder_hidden_states=text_embedding,
controlnet_cond=stencil_image,
return_dict=False,
)
return tuple(list(down_block_res_samples) + [mid_block_res_sample])
return tuple(
list(down_block_res_samples) + [mid_block_res_sample]
)
scnet = StencilControlNetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
scnet = StencilControlNetModel(
low_cpu_mem_usage=self.low_cpu_mem_usage
)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["stencil_adaptor"])
@@ -417,7 +515,12 @@ class SharkifyStableDiffusionModel:
def get_unet(self, use_large=False):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
def __init__(
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
@@ -428,15 +531,24 @@ class SharkifyStableDiffusionModel:
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.config.in_channels
self.train(False)
if(args.attention_slicing is not None and args.attention_slicing != "none"):
if(args.attention_slicing.isdigit()):
self.unet.set_attention_slice(int(args.attention_slicing))
if (
args.attention_slicing is not None
and args.attention_slicing != "none"
):
if args.attention_slicing.isdigit():
self.unet.set_attention_slice(
int(args.attention_slicing)
)
else:
self.unet.set_attention_slice(args.attention_slicing)
# TODO: Instead of flattening the `control` try to use the list.
def forward(
self, latent, timestep, text_embedding, guidance_scale,
self,
latent,
timestep,
text_embedding,
guidance_scale,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latents = torch.cat([latent] * 2)
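The attention-slicing branch above accepts either a digit string or a named mode from `args.attention_slicing`. A hedged, torch-free sketch of that normalization (the helper name is hypothetical):

```python
def normalize_attention_slicing(value):
    # Mirrors the branch above: None or "none" disables slicing,
    # a digit string becomes an integer slice size, and any other
    # string (e.g. "auto") is passed through to set_attention_slice
    # unchanged.
    if value is None or value == "none":
        return None
    return int(value) if value.isdigit() else value

normalize_attention_slicing("4")  # -> 4
```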
@@ -452,16 +564,22 @@ class SharkifyStableDiffusionModel:
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
if(use_large):
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (inputs[0],
inputs = (
inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3])
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet512"])
inputs[3],
)
save_dir = os.path.join(
self.sharktank_dir, self.model_name["unet512"]
)
else:
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
save_dir = os.path.join(
self.sharktank_dir, self.model_name["unet"]
)
input_mask = [True, True, True, False]
if self.debug:
os.makedirs(
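The `use_large` path above pads the text-embedding input (`inputs[2]`) up to sequence length 512 before compiling the `unet512` variant. A pure-Python sketch of how that pad tuple is built, assuming the same shapes as the code above (the helper name is hypothetical):

```python
def embedding_pad_spec(shape, target=512):
    # Builds the tuple passed to torch.nn.functional.pad: pairs of
    # (left, right) padding starting from the LAST dimension, so the
    # final pair extends dim 1 (the sequence length) up to `target`.
    pad = (0, 0) * (len(shape) - 2)
    return pad + (0, target - shape[1])

embedding_pad_spec((2, 77, 768))  # -> (0, 0, 0, 435), i.e. 77 -> 512 on dim 1
```

Note that `F.pad`'s tuple is read back-to-front, which is why the zero pairs for the trailing dimensions come first.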
@@ -489,7 +607,9 @@ class SharkifyStableDiffusionModel:
def get_unet_upscaler(self, use_large=False):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
@@ -512,13 +632,15 @@ class SharkifyStableDiffusionModel:
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
if(use_large):
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (inputs[0],
inputs = (
inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3])
inputs[3],
)
input_mask = [True, True, True, False]
model_name = "unet512" if use_large else "unet"
shark_unet, unet_mlir = compile_through_fx(
@@ -538,7 +660,12 @@ class SharkifyStableDiffusionModel:
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
def __init__(
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
@@ -546,7 +673,9 @@ class SharkifyStableDiffusionModel:
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.text_encoder, use_lora, "text_encoder")
update_lora_weight(
self.text_encoder, use_lora, "text_encoder"
)
def forward(self, input):
return self.text_encoder(input)[0]
@@ -585,16 +714,24 @@ class SharkifyStableDiffusionModel:
vae_checkpoint = None
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
if custom_vae.endswith(".ckpt"):
vae_checkpoint = torch.load(self.custom_vae, map_location="cpu")
vae_checkpoint = torch.load(
self.custom_vae, map_location="cpu"
)
else:
vae_checkpoint = safetensors.torch.load_file(self.custom_vae, device="cpu")
vae_checkpoint = safetensors.torch.load_file(
self.custom_vae, device="cpu"
)
if "state_dict" in vae_checkpoint:
vae_checkpoint = vae_checkpoint["state_dict"]
try:
vae_checkpoint = convert_original_vae(vae_checkpoint)
finally:
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
vae_dict = {
k: v
for k, v in vae_checkpoint.items()
if k[0:4] != "loss" and k not in vae_ignore_keys
}
return vae_dict
def compile_unet_variants(self, model, use_large=False):
@@ -603,7 +740,10 @@ class SharkifyStableDiffusionModel:
return self.get_unet_upscaler(use_large=use_large)
# TODO: Plug the experimental "int8" support at right place.
elif self.use_quantize == "int8":
from apps.stable_diffusion.src.models.opt_params import get_unet
from apps.stable_diffusion.src.models.opt_params import (
get_unet,
)
return get_unet()
else:
return self.get_unet(use_large=use_large)
@@ -612,7 +752,9 @@ class SharkifyStableDiffusionModel:
def vae_encode(self):
try:
self.inputs["vae_encode"] = self.get_input_info_for(base_models["vae_encode"])
self.inputs["vae_encode"] = self.get_input_info_for(
base_models["vae_encode"]
)
compiled_vae_encode, vae_encode_mlir = self.get_vae_encode()
check_compilation(compiled_vae_encode, "Vae Encode")
@@ -641,18 +783,28 @@ class SharkifyStableDiffusionModel:
unet_inputs = base_models[model]
if self.base_model_id != "":
self.inputs["unet"] = self.get_input_info_for(unet_inputs[self.base_model_id])
compiled_unet, unet_mlir = self.compile_unet_variants(model, use_large=use_large)
self.inputs["unet"] = self.get_input_info_for(
unet_inputs[self.base_model_id]
)
compiled_unet, unet_mlir = self.compile_unet_variants(
model, use_large=use_large
)
else:
for model_id in unet_inputs:
self.base_model_id = model_id
self.inputs["unet"] = self.get_input_info_for(unet_inputs[model_id])
self.inputs["unet"] = self.get_input_info_for(
unet_inputs[model_id]
)
try:
compiled_unet, unet_mlir = self.compile_unet_variants(model, use_large=use_large)
compiled_unet, unet_mlir = self.compile_unet_variants(
model, use_large=use_large
)
except Exception as e:
print(e)
print("Retrying with a different base model configuration")
print(
"Retrying with a different base model configuration"
)
continue
# -- Once a successful compilation has taken place we'd want to store
@@ -675,7 +827,11 @@ class SharkifyStableDiffusionModel:
def vae(self):
try:
vae_input = base_models["vae"]["vae_upscaler"] if self.is_upscaler else base_models["vae"]["vae"]
vae_input = (
base_models["vae"]["vae_upscaler"]
if self.is_upscaler
else base_models["vae"]["vae"]
)
self.inputs["vae"] = self.get_input_info_for(vae_input)
is_base_vae = self.base_vae
@@ -693,7 +849,9 @@ class SharkifyStableDiffusionModel:
def controlnet(self):
try:
self.inputs["stencil_adaptor"] = self.get_input_info_for(base_models["stencil_adaptor"])
self.inputs["stencil_adaptor"] = self.get_input_info_for(
base_models["stencil_adaptor"]
)
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net()
check_compilation(compiled_stencil_adaptor, "Stencil")


@@ -17,9 +17,13 @@ hf_model_variant_map = {
"stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
"CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
"runwayml/stable-diffusion-inpainting": ["stablediffusion", "inpaint_v1"],
"stabilityai/stable-diffusion-2-inpainting": ["stablediffusion", "inpaint_v2"],
"stabilityai/stable-diffusion-2-inpainting": [
"stablediffusion",
"inpaint_v2",
],
}
# TODO: Add the quantized model as a part model_db.json.
# This is currently in experimental phase.
def get_quantize_model():
@@ -27,9 +31,12 @@ def get_quantize_model():
model_key = "unet_int8"
iree_flags = get_opt_flags("unet", precision="fp16")
if args.height != 512 and args.width != 512 and args.max_length != 77:
sys.exit("The int8 quantized model currently requires the height and width to be 512, and max_length to be 77")
sys.exit(
"The int8 quantized model currently requires the height and width to be 512, and max_length to be 77"
)
return bucket_key, model_key, iree_flags
def get_variant_version(hf_model_id):
return hf_model_variant_map[hf_model_id]


@@ -15,6 +15,11 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -38,6 +43,11 @@ class Image2ImagePipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,


@@ -14,6 +14,11 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -37,6 +42,11 @@ class InpaintPipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,


@@ -14,6 +14,11 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -38,6 +43,11 @@ class OutpaintPipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,


@@ -14,6 +14,12 @@ from diffusers import (
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -38,6 +44,12 @@ class StencilPipeline(StableDiffusionPipeline):
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,


@@ -13,6 +13,10 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -34,6 +38,10 @@ class Text2ImagePipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,


@@ -17,6 +17,9 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -67,6 +70,11 @@ class UpscalerPipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
low_res_scheduler: Union[
DDIMScheduler,
@@ -78,6 +86,10 @@ class UpscalerPipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,


@@ -15,6 +15,9 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
@@ -48,6 +51,10 @@ class StableDiffusionPipeline:
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
@@ -67,7 +74,8 @@ class StableDiffusionPipeline:
self.import_mlir = import_mlir
self.use_lora = use_lora
self.ondemand = ondemand
# TODO: Find a better workaround for fetching base_model_id early enough for CLIPTokenizer.
# TODO: Find a better workaround for fetching base_model_id early
# enough for CLIPTokenizer.
try:
self.tokenizer = get_tokenizer()
except:
@@ -82,7 +90,8 @@ class StableDiffusionPipeline:
if self.import_mlir or self.use_lora:
if not self.import_mlir:
print(
"Warning: LoRA provided but import_mlir not specified. Importing MLIR anyways."
"Warning: LoRA provided but import_mlir not specified. "
"Importing MLIR anyways."
)
self.text_encoder = self.sd_model.clip()
else:
@@ -310,6 +319,10 @@ class StableDiffusionPipeline:
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
import_mlir: bool,
model_id: str,
@@ -394,16 +407,21 @@ class StableDiffusionPipeline:
prompt (`str` or `list(int)`):
prompt to be encoded
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
The prompt or prompts not to guide the image generation.
Ignored when not using guidance
(i.e., ignored if `guidance_scale` is less than `1`).
model_max_length (int):
SHARK: pass the max length instead of relying on pipe.tokenizer.model_max_length
SHARK: pass the max length instead of relying on
pipe.tokenizer.model_max_length
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not,
SHARK: must be set to True as we always expect neg embeddings (defaulted to True)
SHARK: must be set to True as we always expect neg embeddings
(defaulted to True)
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
SHARK: max_embeddings_multiples>1 produce a tensor shape error (defaulted to 1)
The max multiple length of prompt embeddings compared to the
max output length of the text encoder.
SHARK: max_embeddings_multiples>1 produces a tensor shape error
(defaulted to 1)
num_images_per_prompt (`int`):
number of images that should be generated per prompt
SHARK: num_images_per_prompt is not used (defaulted to 1)
@@ -422,9 +440,11 @@ class StableDiffusionPipeline:
negative_prompt = [negative_prompt] * batch_size
if batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
f"`negative_prompt`: "
f"{negative_prompt} has batch size {len(negative_prompt)}, "
f"but `prompt`: {prompt} has batch size {batch_size}. "
f"Please make sure that passed `negative_prompt` matches "
"the batch size of `prompt`."
)
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
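The check above broadcasts a single negative prompt across the batch and otherwise insists that the two lists line up. A minimal sketch of the same rule (the helper name is hypothetical):

```python
def validate_negative_prompts(prompts, negative_prompts):
    # Mirrors the validation above: one negative prompt is repeated
    # for the whole batch; an explicit list must match the batch size.
    if isinstance(negative_prompts, str):
        negative_prompts = [negative_prompts] * len(prompts)
    if len(prompts) != len(negative_prompts):
        raise ValueError(
            "`negative_prompt` batch size must match `prompt` batch size"
        )
    return negative_prompts

validate_negative_prompts(["a cat", "a dog"], "blurry")  # -> ["blurry", "blurry"]
```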
@@ -437,14 +457,36 @@ class StableDiffusionPipeline:
)
# SHARK: we are not using num_images_per_prompt
# bs_embed, seq_len, _ = text_embeddings.shape
# text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
# text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# text_embeddings = text_embeddings.repeat(
# 1,
# num_images_per_prompt,
# 1
# )
# text_embeddings = (
# text_embeddings.view(
# bs_embed * num_images_per_prompt,
# seq_len,
# -1
# )
# )
if do_classifier_free_guidance:
# SHARK: we are not using num_images_per_prompt
# bs_embed, seq_len, _ = uncond_embeddings.shape
# uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
# uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# uncond_embeddings = (
# uncond_embeddings.repeat(
# 1,
# num_images_per_prompt,
# 1
# )
# )
# uncond_embeddings = (
# uncond_embeddings.view(
# bs_embed * num_images_per_prompt,
# seq_len,
# -1
# )
# )
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
if text_embeddings.shape[1] > model_max_length:
@@ -486,7 +528,8 @@ re_attention = re.compile(
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Parses a string with attention tokens and returns a list of pairs:
text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
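The weighting convention described in the docstring above can be illustrated with a tiny hypothetical helper (not part of the repo): each `(...)` level multiplies attention by 1.1, and an explicit `(abc:3.12)` overrides the multiplier outright.

```python
def nested_weight(depth, explicit=None):
    # Per the docstring above: each "(...)" nesting level multiplies
    # the enclosed text's attention by 1.1, while "(abc:3.12)" sets
    # the multiplier explicitly.
    return explicit if explicit is not None else 1.1 ** depth

nested_weight(2)        # -> ~1.21, i.e. "((abc))"
nested_weight(1, 3.12)  # -> 3.12, i.e. "(abc:3.12)"
```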
@@ -713,6 +756,12 @@ def get_unweighted_text_embeddings(
return text_embeddings
# This function deals with NoneType values occurring in tokens after padding.
# It switches out None for 49407, since truncating None values causes matrix dimension errors.
def filter_nonetype_tokens(tokens: List[List]):
return [[49407 if token is None else token for token in tokens[0]]]
def get_weighted_text_embeddings(
pipe: StableDiffusionPipeline,
prompt: Union[str, List[str]],
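Restating `filter_nonetype_tokens` from above as a self-contained snippet with a usage example: 49407 is CLIP's end-of-text token id, so substituting it for `None` lets the padded list become a `LongTensor` without dimension errors.

```python
def filter_nonetype_tokens(tokens):
    # Replace None entries (left behind by padding) with 49407,
    # CLIP's end-of-text token id, so the nested list can be
    # converted to a LongTensor.
    return [[49407 if token is None else token for token in tokens[0]]]

filter_nonetype_tokens([[49406, 320, None, None]])
# -> [[49406, 320, 49407, 49407]]
```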
@@ -804,6 +853,10 @@ def get_weighted_text_embeddings(
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# FIXME: This is a hacky fix for tokenizer padding with None values
prompt_tokens = filter_nonetype_tokens(prompt_tokens)
# prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu")
if uncond_prompt is not None:
@@ -816,6 +869,10 @@ def get_weighted_text_embeddings(
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# FIXME: This is a hacky fix for tokenizer padding with None values
uncond_tokens = filter_nonetype_tokens(uncond_tokens)
# uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
uncond_tokens = torch.tensor(
uncond_tokens, dtype=torch.long, device="cpu"


@@ -8,6 +8,9 @@ from diffusers import (
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
@@ -38,9 +41,28 @@ def get_schedulers(model_id):
)
schedulers[
"DPMSolverMultistep"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", algorithm_type="dpmsolver"
)
schedulers[
"DPMSolverMultistep++"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
)
schedulers[
"DPMSolverMultistepKarras"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
use_karras_sigmas=True,
)
schedulers[
"DPMSolverMultistepKarras++"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
algorithm_type="dpmsolver++",
use_karras_sigmas=True,
)
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
model_id,
@@ -62,5 +84,21 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
schedulers[
"DPMSolverSinglestep"
] = DPMSolverSinglestepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"KDPM2AncestralDiscrete"
] = KDPM2AncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["SharkEulerDiscrete"].compile()
return schedulers


@@ -28,6 +28,7 @@ from apps.stable_diffusion.src.utils.utils import (
fetch_and_update_base_model_id,
get_path_to_diffusers_checkpoint,
sanitize_seed,
batch_seeds,
get_path_stem,
get_extended_name,
get_generated_imgs_path,


@@ -5,4 +5,7 @@
["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"],
["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"],
["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"],
["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"]]
["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"],
["A photo of a beach, sunset, calm, beautiful landscape, waves, water"],
["(a large body of water with snowy mountains in the background), (fog, foggy, rolling fog), (clouds, cloudy, rolling clouds), dramatic sky and landscape, extraordinary landscape, (beautiful snow capped mountain background), (forest, dirt path)"],
["a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smokes coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"]]


@@ -131,11 +131,32 @@ def load_lower_configs(base_model_id=None):
"v1_4",
"v1_5",
]:
config_name = f"{args.annotation_model}_{version}_{args.max_length}_{args.precision}_{device}_{spec}_{args.width}x{args.height}.json"
config_name = (
f"{args.annotation_model}_"
f"{version}_"
f"{args.max_length}_"
f"{args.precision}_"
f"{device}_"
f"{spec}_"
f"{args.width}x{args.height}.json"
)
elif spec in ["rdna2"] and version in ["v2_1", "v2_1base", "v1_4"]:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}_{args.width}x{args.height}.json"
config_name = (
f"{args.annotation_model}_"
f"{version}_"
f"{args.precision}_"
f"{device}_"
f"{spec}_"
f"{args.width}x{args.height}.json"
)
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
config_name = (
f"{args.annotation_model}_"
f"{version}_"
f"{args.precision}_"
f"{device}_"
f"{spec}.json"
)
full_gs_url = config_bucket + config_name
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
@@ -180,9 +201,22 @@ def dump_after_mlir(input_mlir, use_winograd):
device, device_spec_args = get_device_args()
if use_winograd:
preprocess_flag = "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
preprocess_flag = (
"--iree-preprocessing-pass-pipeline=builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32},"
"iree-linalg-ext-convert-conv2d-to-winograd))"
)
else:
preprocess_flag = "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
preprocess_flag = (
"--iree-preprocessing-pass-pipeline=builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32}))"
)
dump_module = ireec.compile_str(
input_mlir,


@@ -19,48 +19,56 @@ p = argparse.ArgumentParser(
)
##############################################################################
### Stable Diffusion Params
# Stable Diffusion Params
##############################################################################
p.add_argument(
"-a",
"--app",
default="txt2img",
help="which app to use, one of: txt2img, img2img, outpaint, inpaint",
help="Which app to use, one of: txt2img, img2img, outpaint, inpaint.",
)
p.add_argument(
"-p",
"--prompts",
nargs="+",
default=["cyberpunk forest by Salvador Dali"],
help="text of which images to be generated.",
default=[
"a photo taken of the front of a super-car drifting on a road near "
"mountains at high speeds with smokes coming off the tires, front "
"angle, front point of view, trees in the mountains of the "
"background, ((sharp focus))"
],
help="Text prompt(s) from which images are generated.",
)
p.add_argument(
"--negative_prompts",
nargs="+",
default=["trees, green"],
help="text you don't want to see in the generated image.",
default=[
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), "
"blurry, ugly, blur, oversaturated, cropped"
],
help="Text you don't want to see in the generated image.",
)
p.add_argument(
"--img_path",
type=str,
help="Path to the image input for img2img/inpainting",
help="Path to the image input for img2img/inpainting.",
)
p.add_argument(
"--steps",
type=int,
default=50,
help="the no. of steps to do the sampling.",
help="The number of steps to do the sampling.",
)
p.add_argument(
"--seed",
type=int,
default=-1,
help="the seed to use. -1 for a random one.",
help="The seed to use. -1 for a random one.",
)
p.add_argument(
@@ -68,7 +76,7 @@ p.add_argument(
type=int,
default=1,
choices=range(1, 4),
help="the number of inferences to be made in a single `batch_count`.",
help="The number of inferences to be made in a single `batch_count`.",
)
p.add_argument(
@@ -76,7 +84,7 @@ p.add_argument(
type=int,
default=512,
choices=range(128, 769, 8),
help="the height of the output image.",
help="The height of the output image.",
)
p.add_argument(
@@ -84,84 +92,86 @@ p.add_argument(
type=int,
default=512,
choices=range(128, 769, 8),
help="the width of the output image.",
help="The width of the output image.",
)
p.add_argument(
"--guidance_scale",
type=float,
default=7.5,
help="the value to be used for guidance scaling.",
help="The value to be used for guidance scaling.",
)
p.add_argument(
"--noise_level",
type=int,
default=20,
help="the value to be used for noise level of upscaler.",
help="The value to be used for noise level of upscaler.",
)
p.add_argument(
"--max_length",
type=int,
default=64,
help="max length of the tokenizer output, options are 64 and 77.",
help="Max length of the tokenizer output, options are 64 and 77.",
)
p.add_argument(
"--max_embeddings_multiples",
type=int,
default=5,
help="The max multiple length of prompt embeddings compared to the max output length of text encoder.",
help="The max multiple length of prompt embeddings compared to the max "
"output length of text encoder.",
)
p.add_argument(
"--strength",
type=float,
default=0.8,
help="the strength of change applied on the given input image for img2img",
help="The strength of change applied on the given input image for "
"img2img.",
)
##############################################################################
### Stable Diffusion Training Params
# Stable Diffusion Training Params
##############################################################################
p.add_argument(
"--lora_save_dir",
type=str,
default="models/lora/",
help="Directory to save the lora fine tuned model",
help="Directory to save the lora fine tuned model.",
)
p.add_argument(
"--training_images_dir",
type=str,
default="models/lora/training_images/",
help="Directory containing images that are an example of the prompt",
help="Directory containing images that are an example of the prompt.",
)
p.add_argument(
"--training_steps",
type=int,
default=2000,
help="The no. of steps to train",
help="The number of steps to train.",
)
##############################################################################
### Inpainting and Outpainting Params
# Inpainting and Outpainting Params
##############################################################################
p.add_argument(
"--mask_path",
type=str,
help="Path to the mask image input for inpainting",
help="Path to the mask image input for inpainting.",
)
p.add_argument(
"--inpaint_full_res",
default=False,
action=argparse.BooleanOptionalAction,
help="If inpaint only masked area or whole picture",
help="Whether to inpaint only the masked area or the whole picture.",
)
p.add_argument(
@@ -169,7 +179,7 @@ p.add_argument(
type=int,
default=32,
choices=range(0, 257, 4),
help="Number of pixels for only masked padding",
help="Number of pixels for only masked padding.",
)
p.add_argument(
@@ -177,7 +187,7 @@ p.add_argument(
type=int,
default=128,
choices=range(8, 257, 8),
help="Number of expended pixels for one direction for outpainting",
help="Number of expanded pixels in one direction for outpainting.",
)
p.add_argument(
@@ -185,89 +195,92 @@ p.add_argument(
type=int,
default=8,
choices=range(0, 65),
help="Number of blur pixels for outpainting",
help="Number of blur pixels for outpainting.",
)
p.add_argument(
"--left",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend left for outpainting",
help="Whether to expand left for outpainting.",
)
p.add_argument(
"--right",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend right for outpainting",
help="Whether to expand right for outpainting.",
)
p.add_argument(
"--top",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend top for outpainting",
help="Whether to expand top for outpainting.",
)
p.add_argument(
"--bottom",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend bottom for outpainting",
help="Whether to expand bottom for outpainting.",
)
p.add_argument(
"--noise_q",
type=float,
default=1.0,
help="Fall-off exponent for outpainting (lower=higher detail) (min=0.0, max=4.0)",
help="Fall-off exponent for outpainting (lower=higher detail) "
"(min=0.0, max=4.0).",
)
p.add_argument(
"--color_variation",
type=float,
default=0.05,
help="Color variation for outpainting (min=0.0, max=1.0)",
help="Color variation for outpainting (min=0.0, max=1.0).",
)
##############################################################################
### Model Config and Usage Params
# Model Config and Usage Params
##############################################################################
p.add_argument(
"--device", type=str, default="vulkan", help="device to run the model."
"--device", type=str, default="vulkan", help="Device to run the model."
)
p.add_argument(
"--precision", type=str, default="fp16", help="precision to run the model."
"--precision", type=str, default="fp16", help="Precision to run the model."
)
p.add_argument(
"--import_mlir",
default=False,
action=argparse.BooleanOptionalAction,
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
help="Imports the model from torch module to shark_module otherwise "
"downloads the model from shark_tank.",
)
p.add_argument(
"--load_vmfb",
default=True,
action=argparse.BooleanOptionalAction,
help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
help="Attempts to load the model from a precompiled flatbuffer "
"and compiles + saves it if not found.",
)
p.add_argument(
"--save_vmfb",
default=False,
action=argparse.BooleanOptionalAction,
help="saves the compiled flatbuffer to the local directory",
help="Saves the compiled flatbuffer to the local directory.",
)
p.add_argument(
"--use_tuned",
default=True,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available",
help="Download and use the tuned version of the model if available.",
)
p.add_argument(
@@ -281,28 +294,42 @@ p.add_argument(
"--scheduler",
type=str,
default="SharkEulerDiscrete",
help="other supported schedulers are [PNDM, DDIM, LMSDiscrete, EulerDiscrete, DPMSolverMultistep]",
help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, "
"DPMSolverMultistep, DPMSolverMultistep++, DPMSolverMultistepKarras, "
"DPMSolverMultistepKarras++, EulerDiscrete, EulerAncestralDiscrete, "
"DEISMultistep, KDPM2AncestralDiscrete, DPMSolverSinglestep, DDPM, "
"HeunDiscrete].",
)
p.add_argument(
"--output_img_format",
type=str,
default="png",
help="specify the format in which output image is save. Supported options: jpg / png",
help="Specify the format in which the output image is saved. "
"Supported options: jpg / png.",
)
p.add_argument(
"--output_dir",
type=str,
default=None,
help="Directory path to save the output images and json",
help="Directory path to save the output images and json.",
)
p.add_argument(
"--batch_count",
type=int,
default=1,
help="number of batch to be generated with random seeds in single execution",
help="Number of batches to be generated with random seeds in "
"single execution.",
)
p.add_argument(
"--repeatable_seeds",
default=False,
action=argparse.BooleanOptionalAction,
help="The seed of the first batch will be used as the rng seed to "
"generate the subsequent seeds for subsequent batches in that run.",
)
p.add_argument(
@@ -316,7 +343,8 @@ p.add_argument(
"--custom_vae",
type=str,
default="",
help="HuggingFace repo-id or path to SD model's checkpoint whose Vae needs to be plugged in.",
help="HuggingFace repo-id or path to SD model's checkpoint whose VAE "
"needs to be plugged in.",
)
p.add_argument(
@@ -330,14 +358,15 @@ p.add_argument(
"--low_cpu_mem_usage",
default=False,
action=argparse.BooleanOptionalAction,
help="Use the accelerate package to reduce cpu memory consumption",
help="Use the accelerate package to reduce cpu memory consumption.",
)
p.add_argument(
"--attention_slicing",
type=str,
default="none",
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', or an integer)",
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', "
"or an integer).",
)
p.add_argument(
@@ -350,216 +379,242 @@ p.add_argument(
"--use_lora",
type=str,
default="",
help="Use standalone LoRA weight using a HF ID or a checkpoint file (~3 MB)",
help="Use standalone LoRA weight using a HF ID or a checkpoint "
"file (~3 MB).",
)
p.add_argument(
"--use_quantize",
type=str,
default="none",
help="""Runs the quantized version of stable diffusion model. This is currently in experimental phase.
Currently, only runs the stable-diffusion-2-1-base model in int8 quantization.""",
help="Runs the quantized version of stable diffusion model. "
"This is currently in experimental phase. "
"Currently, only runs the stable-diffusion-2-1-base model in "
"int8 quantization.",
)
p.add_argument(
"--ondemand",
default=False,
action=argparse.BooleanOptionalAction,
help="Load and unload models for low VRAM",
help="Load and unload models for low VRAM.",
)
p.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication tokens for models like Llama2.",
)
##############################################################################
### IREE - Vulkan supported flags
# IREE - Vulkan supported flags
##############################################################################
p.add_argument(
"--iree_vulkan_target_triple",
type=str,
default="",
help="Specify target triple for vulkan",
help="Specify target triple for vulkan.",
)
p.add_argument(
"--iree_metal_target_platform",
type=str,
default="",
help="Specify target triple for metal",
help="Specify target triple for metal.",
)
p.add_argument(
"--vulkan_debug_utils",
default=False,
action=argparse.BooleanOptionalAction,
help="Profiles vulkan device and collects the .rdc info",
help="Profiles vulkan device and collects the .rdc info.",
)
p.add_argument(
"--vulkan_large_heap_block_size",
default="2073741824",
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
help="Flag for setting VMA preferredLargeHeapBlockSize for "
"vulkan device. Default is 2073741824 (~2GB).",
)
p.add_argument(
"--vulkan_validation_layers",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for disabling vulkan validation layers when benchmarking",
help="Flag for disabling vulkan validation layers when benchmarking.",
)
##############################################################################
### Misc. Debug and Optimization flags
# Misc. Debug and Optimization flags
##############################################################################
p.add_argument(
"--use_compiled_scheduler",
default=True,
action=argparse.BooleanOptionalAction,
help="use the default scheduler precompiled into the model if available",
help="Use the default scheduler precompiled into the model if available.",
)
p.add_argument(
"--local_tank_cache",
default="",
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
help="Specify where to save downloaded shark_tank artifacts. "
"If this is not set, the default is ~/.local/shark_tank/.",
)
p.add_argument(
"--dump_isa",
default=False,
action="store_true",
help="When enabled call amdllpc to get ISA dumps. use with dispatch benchmarks.",
help="When enabled, calls amdllpc to get ISA dumps. "
"Use with dispatch benchmarks.",
)
p.add_argument(
"--dispatch_benchmarks",
default=None,
help='dispatches to return benchamrk data on. use "All" for all, and None for none.',
help="Dispatches to return benchmark data on. "
'Use "All" for all, and None for none.',
)
p.add_argument(
"--dispatch_benchmarks_dir",
default="temp_dispatch_benchmarks",
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
help="Directory where you want to store dispatch data "
'generated with "--dispatch_benchmarks".',
)
p.add_argument(
"--enable_rgp",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for inserting debug frames between iterations for use with rgp.",
help="Flag for inserting debug frames between iterations "
"for use with rgp.",
)
p.add_argument(
"--hide_steps",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for hiding the details of iteration/sec for each step.",
help="Flag for hiding the details of iteration/sec for each step.",
)
p.add_argument(
"--warmup_count",
type=int,
default=0,
help="flag setting warmup count for clip and vae [>= 0].",
help="Flag setting warmup count for CLIP and VAE [>= 0].",
)
p.add_argument(
"--clear_all",
default=False,
action=argparse.BooleanOptionalAction,
help="flag to clear all mlir and vmfb from common locations. Recompiling will take several minutes",
help="Flag to clear all mlir and vmfb from common locations. "
"Recompiling will take several minutes.",
)
p.add_argument(
"--save_metadata_to_json",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for whether or not to save a generation information json file with the image.",
help="Flag for whether or not to save a generation information "
"json file with the image.",
)
p.add_argument(
"--write_metadata_to_png",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for whether or not to save generation information in PNG chunk text to generated images.",
help="Flag for whether or not to save generation information in "
"PNG chunk text to generated images.",
)
p.add_argument(
"--import_debug",
default=False,
action=argparse.BooleanOptionalAction,
help="if import_mlir is True, saves mlir via the debug option in shark importer. Does nothing if import_mlir is false (the default)",
help="If import_mlir is True, saves mlir via the debug option "
"in shark importer. Does nothing if import_mlir is false (the default).",
)
##############################################################################
### Web UI flags
# Web UI flags
##############################################################################
p.add_argument(
"--progress_bar",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for removing the progress bar animation during image generation",
help="Flag for removing the progress bar animation during "
"image generation.",
)
p.add_argument(
"--ckpt_dir",
type=str,
default="",
help="Path to directory where all .ckpts are stored in order to populate them in the web UI",
help="Path to directory where all .ckpts are stored in order to populate "
"them in the web UI.",
)
# TODO: replace API flag when these can be run together
p.add_argument(
"--ui",
type=str,
default="app" if os.name == "nt" else "web",
help="one of: [api, app, web]",
help="One of: [api, app, web].",
)
p.add_argument(
"--share",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for generating a public URL",
help="Flag for generating a public URL.",
)
p.add_argument(
"--server_port",
type=int,
default=8080,
help="flag for setting server port",
help="Flag for setting server port.",
)
p.add_argument(
"--api",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for enabling rest API",
help="Flag for enabling rest API.",
)
p.add_argument(
"--output_gallery",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for removing the output gallery tab, and avoid exposing images under --output_dir in the UI",
help="Flag for removing the output gallery tab, and avoid exposing "
"images under --output_dir in the UI.",
)
p.add_argument(
"--output_gallery_followlinks",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for whether the output gallery tab in the UI should follow symlinks when listing subdirectorys under --output_dir",
help="Flag for whether the output gallery tab in the UI should "
"follow symlinks when listing subdirectories under --output_dir.",
)
##############################################################################
### SD model auto-annotation flags
# SD model auto-annotation flags
##############################################################################
p.add_argument(
"--annotation_output",
type=path_expand,
default="./",
help="Directory to save the annotated mlir file",
help="Directory to save the annotated mlir file.",
)
p.add_argument(
@@ -573,33 +628,43 @@ p.add_argument(
"--save_annotation",
default=False,
action=argparse.BooleanOptionalAction,
help="Save annotated mlir file",
help="Save annotated mlir file.",
)
##############################################################################
### SD model auto-tuner flags
# SD model auto-tuner flags
##############################################################################
p.add_argument(
"--tuned_config_dir",
type=path_expand,
default="./",
help="Directory to save the tuned config file",
help="Directory to save the tuned config file.",
)
p.add_argument(
"--num_iters",
type=int,
default=400,
help="Number of iterations for tuning",
help="Number of iterations for tuning.",
)
p.add_argument(
"--search_op",
type=str,
default="all",
help="Op to be optimized, options are matmul, bmm, conv and all",
help="Op to be optimized, options are matmul, bmm, conv and all.",
)
##############################################################################
# DocuChat Flags
##############################################################################
p.add_argument(
"--run_docuchat_web",
default=False,
action=argparse.BooleanOptionalAction,
help="Specifies whether the docuchat's web version is running or not.",
)
args, unknown = p.parse_known_args()
if args.import_debug:
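The flag definitions above lean on two argparse features used throughout this file: `argparse.BooleanOptionalAction`, which auto-generates a paired `--no-<flag>` form, and `parse_known_args()`, which collects unrecognized flags instead of exiting with an error. A small self-contained sketch (flag names here are illustrative):

```python
import argparse

p = argparse.ArgumentParser()
# BooleanOptionalAction generates a paired --share / --no-share flag.
p.add_argument("--share", default=False, action=argparse.BooleanOptionalAction)
p.add_argument("--server_port", type=int, default=8080)

# parse_known_args() returns unrecognized arguments instead of erroring
# out -- handy when several components read the same command line.
args, unknown = p.parse_known_args(["--share", "--unknown-flag", "x"])
assert args.share is True
assert unknown == ["--unknown-flag", "x"]
```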

View File

@@ -8,7 +8,12 @@ from datetime import datetime as dt
from csv import DictWriter
from pathlib import Path
import numpy as np
from random import randint
from random import (
randint,
seed as seed_random,
getstate as random_getstate,
setstate as random_setstate,
)
import tempfile
import torch
from safetensors.torch import load_file
@@ -32,6 +37,7 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
import requests
from io import BytesIO
from omegaconf import OmegaConf
from cpuinfo import get_cpu_info
def get_extended_name(model_name):
@@ -80,7 +86,9 @@ def _compile_module(shark_module, model_name, extra_args=[]):
# Downloads the model from shark_tank and returns the shark_module.
def get_shark_model(tank_url, model_name, extra_args=[]):
def get_shark_model(tank_url, model_name, extra_args=None):
if extra_args is None:
extra_args = []
from shark.parser import shark_args
# Set local shark_tank cache directory.
@@ -112,13 +120,15 @@ def compile_through_fx(
save_dir=tempfile.gettempdir(),
debug=False,
generate_vmfb=True,
extra_args=[],
extra_args=None,
base_model_id=None,
model_name=None,
precision=None,
return_mlir=False,
device=None,
):
if extra_args is None:
extra_args = []
if not return_mlir and model_name is not None:
vmfb_path = get_vmfb_path_name(extended_model_name)
if os.path.isfile(vmfb_path):
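Replacing `extra_args=[]` with `extra_args=None` plus an in-function default avoids Python's shared-mutable-default pitfall: a default list is created once, at function definition time, and reused across every call. A minimal demonstration:

```python
def broken(extra_args=[]):
    # The SAME list object is reused across calls.
    extra_args.append("flag")
    return extra_args

def fixed(extra_args=None):
    # A fresh list is created on every call.
    if extra_args is None:
        extra_args = []
    extra_args.append("flag")
    return extra_args

assert broken() == ["flag"]
assert broken() == ["flag", "flag"]   # state leaked between calls
assert fixed() == ["flag"]
assert fixed() == ["flag"]            # no leakage
```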
@@ -204,13 +214,15 @@ def get_device_mapping(driver, key_combination=3):
specific devices for execution
Args:
driver (str): execution driver (vulkan, cuda, rocm, etc)
key_combination (int, optional): choice for mapping value for device name.
key_combination (int, optional): choice for mapping value for
device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Returns:
dict: map to possible device names user can input mapped to desired combination of name/path.
dict: map to possible device names user can input mapped to desired
combination of name/path.
"""
from shark.iree_utils._common import iree_device_map
@@ -224,7 +236,7 @@ def get_device_mapping(driver, key_combination=3):
if key_combination == 2:
return dev_dict["name"]
if key_combination == 3:
return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
return dev_dict["name"], f"{driver}://{dev_dict['path']}"
# mapping driver name to default device (driver://0)
device_map[f"{driver}"] = get_output_value(device_list[0])
@@ -237,10 +249,12 @@ def get_device_mapping(driver, key_combination=3):
def map_device_to_name_path(device, key_combination=3):
"""Gives the appropriate device data (supported name/path) for user selected execution device
"""Gives the appropriate device data (supported name/path) for user
selected execution device
Args:
device (str): user
key_combination (int, optional): choice for mapping value for device name.
key_combination (int, optional): choice for mapping value for
device name.
1 : path
2 : name
3 : (name, path)
@@ -248,7 +262,8 @@ def map_device_to_name_path(device, key_combination=3):
Raises:
ValueError:
Returns:
str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value
str / tuple: returns the mapping str or tuple of mapping str for
the device depending on key_combination value
"""
driver = device.split("://")[0]
device_map = get_device_mapping(driver, key_combination)
@@ -271,7 +286,8 @@ def set_init_device_flags():
if triple is not None:
args.iree_vulkan_target_triple = triple
print(
f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
f"Found device {device_name}. Using target triple "
f"{args.iree_vulkan_target_triple}."
)
elif "cuda" in args.device:
args.device = "cuda"
@@ -280,9 +296,10 @@ def set_init_device_flags():
if not args.iree_metal_target_platform:
triple = get_metal_target_triple(device_name)
if triple is not None:
args.iree_metal_target_platform = triple
args.iree_metal_target_platform = triple.split("-")[-1]
print(
f"Found device {device_name}. Using target triple {args.iree_metal_target_platform}."
f"Found device {device_name}. Using target triple "
f"{args.iree_metal_target_platform}."
)
elif "cpu" in args.device:
args.device = "cpu"
@@ -380,7 +397,8 @@ def set_init_device_flags():
if args.use_tuned:
print(
f"Using tuned models for {base_model_id}(fp16) on device {args.device}."
f"Using tuned models for {base_model_id}(fp16) on "
f"device {args.device}."
)
else:
print("Tuned models are currently not supported for this setting.")
@@ -438,8 +456,12 @@ def get_available_devices():
except:
print(f"{driver_name} devices are not available.")
else:
cpu_name = get_cpu_info()["brand_raw"]
for i, device in enumerate(device_list_dict):
device_list.append(f"{device['name']} => {driver_name}://{i}")
device_name = (
cpu_name if device["name"] == "default" else device["name"]
)
device_list.append(f"{device_name} => {driver_name}://{i}")
return device_list
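The change above substitutes the CPU brand string (from `py-cpuinfo`'s `get_cpu_info()["brand_raw"]`) for the unhelpful `"default"` device label. A sketch with hypothetical device records and a hard-coded brand string standing in for the library call:

```python
# Hypothetical device records as returned by an IREE driver query.
device_list_dict = [{"name": "default"}, {"name": "AMD Radeon RX 7900"}]
driver_name = "cpu-task"
cpu_name = "12th Gen Intel Core i7"  # stand-in for get_cpu_info()["brand_raw"]

device_list = []
for i, device in enumerate(device_list_dict):
    # Swap the generic "default" label for the real CPU brand string.
    device_name = cpu_name if device["name"] == "default" else device["name"]
    device_list.append(f"{device_name} => {driver_name}://{i}")

assert device_list == [
    "12th Gen Intel Core i7 => cpu-task://0",
    "AMD Radeon RX 7900 => cpu-task://1",
]
```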
set_iree_runtime_flags()
@@ -531,10 +553,10 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
from_safetensors = (
True if custom_weights.lower().endswith(".safetensors") else False
)
# EMA weights usually yield higher quality images for inference but non-EMA weights have
# been yielding better results in our case.
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if they want to go for EMA
# weight extraction or not.
# EMA weights usually yield higher quality images for inference but
# non-EMA weights have been yielding better results in our case.
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if
# they want to go for EMA weight extraction or not.
extract_ema = False
print(
"Loading diffusers' pipeline from original stable diffusion checkpoint"
@@ -555,7 +577,10 @@ def convert_original_vae(vae_checkpoint):
for key in list(vae_checkpoint.keys()):
vae_state_dict["first_stage_model." + key] = vae_checkpoint.get(key)
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
config_url = (
"https://raw.githubusercontent.com/CompVis/stable-diffusion/"
"main/configs/stable-diffusion/v1-inference.yaml"
)
original_config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(original_config_file)
vae_config = create_vae_diffusers_config(original_config, image_size=512)
@@ -676,7 +701,7 @@ def update_lora_weight(model, use_lora, model_name):
# `fetch_and_update_base_model_id` is a resource utility function which
# helps maintaining mapping of the model to run with its base model.
# helps to maintain mapping of the model to run with its base model.
# If `base_model` is "", then this function tries to fetch the base model
# info for the `model_to_run`.
def fetch_and_update_base_model_id(model_to_run, base_model=""):
@@ -693,13 +718,15 @@ def fetch_and_update_base_model_id(model_to_run, base_model=""):
return base_model
elif base_model == "":
return base_model
# Update JSON data to contain an entry mapping model_to_run with base_model.
# Update JSON data to contain an entry mapping model_to_run with
# base_model.
json_data.update(data)
with open(variants_path, "w", encoding="utf-8") as jsonFile:
json.dump(json_data, jsonFile)
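`fetch_and_update_base_model_id` persists the model-to-base-model mapping by merging the new entry into the loaded JSON data and rewriting the file. A condensed, self-contained sketch (the path and entries here are hypothetical):

```python
import json
import os
import tempfile

# Hypothetical variants file and entries.
variants_path = os.path.join(tempfile.gettempdir(), "variants.json")

json_data = {"some-model": "base-v1"}            # mapping loaded from disk
data = {"my-finetune": "stable-diffusion-2-1"}   # new entry to record

# Merge the new entry and persist the whole mapping.
json_data.update(data)
with open(variants_path, "w", encoding="utf-8") as jsonFile:
    json.dump(json_data, jsonFile)

with open(variants_path, encoding="utf-8") as f:
    assert json.load(f)["my-finetune"] == "stable-diffusion-2-1"
```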
# Generate and return a new seed if the provided one is not in the supported range (including -1)
# Generate and return a new seed if the provided one is not in the
# supported range (including -1)
def sanitize_seed(seed):
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
@@ -708,6 +735,28 @@ def sanitize_seed(seed):
return seed
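`sanitize_seed` replaces any seed outside the unsigned 32-bit range (including the `-1` "random" sentinel) with a fresh in-range seed. A stdlib-only sketch, using `2**32 - 1` in place of `np.iinfo(np.uint32).max`:

```python
import random

# Equivalent of np.iinfo(np.uint32): valid seeds are 0 .. 2**32 - 1.
UINT32_MIN, UINT32_MAX = 0, 2**32 - 1

def sanitize_seed(seed):
    # Out-of-range seeds (including the sentinel -1) are replaced with a
    # random seed in the supported range.
    if seed < UINT32_MIN or seed > UINT32_MAX:
        seed = random.randint(UINT32_MIN, UINT32_MAX)
    return seed

assert sanitize_seed(42) == 42
assert UINT32_MIN <= sanitize_seed(-1) <= UINT32_MAX
```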
# Generate a set of seeds, using as the first seed of the set,
# optionally using it as the rng seed for subsequent seeds in the set
def batch_seeds(seed, batch_count, repeatable=False):
# use the passed seed as the initial seed of the batch
seeds = [sanitize_seed(seed)]
if repeatable:
# use the initial seed as the rng generator seed
saved_random_state = random_getstate()
seed_random(seed)
# generate the additional seeds
for i in range(1, batch_count):
seeds.append(sanitize_seed(-1))
if repeatable:
# reset the rng back to normal
random_setstate(saved_random_state)
return seeds
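The new `batch_seeds` helper makes batches reproducible on demand: with `repeatable=True` it seeds the global RNG, generates the remaining seeds, then restores the saved RNG state so callers are unaffected. A condensed sketch (this version seeds from the sanitized first seed, a slight simplification of the code above):

```python
import random

UINT32_MAX = 2**32 - 1

def sanitize_seed(seed):
    # Hypothetical stand-in for the module's sanitize_seed().
    return seed if 0 <= seed <= UINT32_MAX else random.randint(0, UINT32_MAX)

def batch_seeds(seed, batch_count, repeatable=False):
    # Use the passed seed as the initial seed of the batch.
    seeds = [sanitize_seed(seed)]
    if repeatable:
        # Save the RNG state, then seed it so the rest of the batch
        # is derived deterministically from the first seed.
        saved_state = random.getstate()
        random.seed(seeds[0])
    seeds += [sanitize_seed(-1) for _ in range(1, batch_count)]
    if repeatable:
        # Restore the RNG so callers see no side effect.
        random.setstate(saved_state)
    return seeds

# The same first seed reproduces the whole batch when repeatable=True.
assert batch_seeds(123, 4, repeatable=True) == batch_seeds(123, 4, repeatable=True)
```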
# clear all the cached objects to recompile cleanly.
def clear_all():
print("CLEARING ALL, EXPECT SEVERAL MINUTES TO RECOMPILE")
@@ -718,7 +767,8 @@ def clear_all():
for vmfb in vmfbs:
if os.path.exists(vmfb):
os.remove(vmfb)
# Temporary workaround of deleting yaml files to incorporate diffusers' pipeline.
# Temporary workaround of deleting yaml files to incorporate
# diffusers' pipeline.
# TODO: Remove this once we have better weight updation logic.
inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"]
for yaml in inference_yaml:
@@ -747,7 +797,9 @@ def get_generated_imgs_todays_subdir() -> str:
# save output images and the inputs corresponding to it.
def save_output_img(output_img, img_seed, extra_info={}):
def save_output_img(output_img, img_seed, extra_info=None):
if extra_info is None:
extra_info = {}
generated_imgs_path = Path(
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
)
@@ -779,17 +831,25 @@ def save_output_img(output_img, img_seed, extra_info={}):
if args.write_metadata_to_png:
pngInfo.add_text(
"parameters",
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps: {args.steps},"
f"Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed},"
f"Size: {args.width}x{args.height}, Model: {img_model}, VAE: {img_vae}, LoRA: {img_lora}",
f"{args.prompts[0]}"
f"\nNegative prompt: {args.negative_prompts[0]}"
f"\nSteps: {args.steps}, "
f"Sampler: {args.scheduler}, "
f"CFG scale: {args.guidance_scale}, "
f"Seed: {img_seed}, "
f"Size: {args.width}x{args.height}, "
f"Model: {img_model}, "
f"VAE: {img_vae}, "
f"LoRA: {img_lora}",
)
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
if args.output_img_format not in ["png", "jpg"]:
print(
f"[ERROR] Format {args.output_img_format} is not supported yet."
"Image saved as png instead. Supported formats: png / jpg"
f"[ERROR] Format {args.output_img_format} is not "
"supported yet. Image saved as png instead. "
"Supported formats: png / jpg."
)
# To be as low-impact as possible to the existing CSV format, we append
@@ -832,16 +892,27 @@ def save_output_img(output_img, img_seed, extra_info={}):
def get_generation_text_info(seeds, device):
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch_count={args.batch_count}, batch_size={args.batch_size}, max_length={args.max_length}"
text_output += (
f"\nmodel_id={args.hf_model_id}, " f"ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, " f"device={device}"
text_output += (
f"\nsteps={args.steps}, "
f"guidance_scale={args.guidance_scale}, "
f"seed={seeds}"
)
text_output += (
f"\nsize={args.height}x{args.width}, "
f"batch_count={args.batch_count}, "
f"batch_size={args.batch_size}, "
f"max_length={args.max_length}"
)
return text_output
# For stencil, the input image can be of any size but we need to ensure that
# it conforms with our model contraints :-
# For stencil, the input image can be of any size, but we need to ensure that
# it conforms with our model constraints :-
# Both width and height should be in the range of [128, 768] and multiple of 8.
# This utility function performs the transformation on the input image while
# also maintaining the aspect ratio before sending it to the stencil pipeline.

View File

@@ -0,0 +1,51 @@
# -*- mode: python ; coding: utf-8 -*-
from apps.stable_diffusion.shark_studio_imports import pathex, datas, hiddenimports
binaries = []
block_cipher = None
a = Analysis(
['web\\index.py'],
pathex=pathex,
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
[],
exclude_binaries=True,
name='studio_bundle',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)
coll = COLLECT(
exe,
a.binaries,
a.zipfiles,
a.datas,
strip=False,
upx=True,
upx_exclude=[],
name='studio_bundle',
)

View File

@@ -7,7 +7,7 @@ if sys.platform == "darwin":
import torch_mlir
import shutil
import PIL, transformers, sentencepiece # ensures inclusion in pyinstaller exe generation
from apps.stable_diffusion.src import args, clear_all
import apps.stable_diffusion.web.utils.global_obj as global_obj
@@ -26,9 +26,10 @@ def launch_app(address):
window = Tk()
# get screen width and height of display and make it more reasonably
# sized as we aren't making it full-screen or maximized
width = int(window.winfo_screenwidth() * 0.81)
height = int(window.winfo_screenheight() * 0.91)
webview.create_window(
"SHARK AI Studio",
url=address,
@@ -49,7 +50,9 @@ if __name__ == "__main__":
upscaler_api,
inpaint_api,
outpaint_api,
llm_chat_api,
)
from fastapi import FastAPI, APIRouter
import uvicorn
@@ -62,8 +65,19 @@ if __name__ == "__main__":
app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
# chat APIs needed for compatibility with multiple extensions using OpenAI API
app.add_api_route(
"/v1/chat/completions", llm_chat_api, methods=["post"]
)
app.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
app.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
app.add_api_route("/completions", llm_chat_api, methods=["post"])
app.add_api_route(
"/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
)
app.include_router(APIRouter())
uvicorn.run(app, host="0.0.0.0", port=args.server_port)
sys.exit(0)
# Setup to use shark_tmp for gradio's temporary image files and clear any
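The routes registered above mirror the OpenAI chat-completions surface, so existing OpenAI-style clients and extensions can point at the SHARK server. A minimal sketch of the request body such a client would send (field names follow the public OpenAI schema; the model name is a placeholder, not a value confirmed by this diff):

```python
import json

# Hypothetical body for POST /v1/chat/completions on the SHARK server;
# "codegen" is a placeholder model name.
payload = {
    "model": "codegen",
    "messages": [
        {"role": "user", "content": "Summarize this repository."},
    ],
}

# A client would POST this JSON to http://<host>:<server_port>/v1/chat/completions
body = json.dumps(payload)
decoded = json.loads(body)
print(decoded["messages"][0]["role"])
```

Since all five routes above share the `llm_chat_api` handler, a client only needs to match whichever path its SDK hardcodes.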
@@ -101,6 +115,7 @@ if __name__ == "__main__":
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
h2ogpt_web,
img2img_web,
img2img_custom_model,
img2img_hf_model_id,
@@ -226,6 +241,8 @@ if __name__ == "__main__":
upscaler_status,
]
)
with gr.TabItem(label="DocuChat(Experimental)", id=9):
h2ogpt_web.render()
# send to buttons
register_button_click(

View File

@@ -74,7 +74,11 @@ from apps.stable_diffusion.web.ui.model_manager import (
modelmanager_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.lora_train_ui import lora_train_web
from apps.stable_diffusion.web.ui.stablelm_ui import (
stablelm_chat,
llm_chat_api,
)
from apps.stable_diffusion.web.ui.h2ogpt import h2ogpt_web
from apps.stable_diffusion.web.ui.outputgallery_ui import (
outputgallery_web,
outputgallery_tab_select,

View File

@@ -117,16 +117,12 @@ body {
padding: 0 var(--size-4) !important;
}
.container {
padding-top: var(--size-5) !important;
}
#ui_title {
padding: var(--size-2) 0 0 var(--size-1);
}
#top_logo {
color: transparent;
background-color: transparent;
border-radius: 0 !important;
border: 0;
@@ -227,10 +223,19 @@ footer {
}
/* Hide the download icon from the nod logo */
#top_logo button {
display: none;
}
/* workarounds for container=false not currently working for dropdowns */
.dropdown_no_container {
padding: 0 !important;
}
#output_subdir_container :first-child {
border: none;
}
/* reduced animation load when generating */
.generating {
animation-play-state: paused !important;
@@ -247,7 +252,7 @@ footer {
line-height: var(--line-xs)
}
.output_icon_button {
max-width: 30px;
align-self: end;
padding-bottom: 8px;

View File

@@ -0,0 +1,249 @@
import gradio as gr
import torch
import os
from pathlib import Path
from transformers import (
AutoModelForCausalLM,
)
from apps.stable_diffusion.web.ui.utils import available_devices
from apps.language_models.langchain.enums import (
DocumentChoices,
LangChainAction,
)
import apps.language_models.langchain.gen as gen
from apps.stable_diffusion.src import args
def user(message, history):
# Append the user's message to the conversation history
return "", history + [[message, ""]]
sharkModel = 0
h2ogpt_model = 0
# NOTE: Each `model_name` should have its own start message
start_message = """
SHARK DocuChat
Chat with an AI, contextualized with provided files.
"""
def create_prompt(history):
system_message = start_message
conversation = "".join(["".join([item[0], item[1]]) for item in history])
msg = system_message + conversation
msg = msg.strip()
return msg
def chat(curr_system_message, history, device, precision):
args.run_docuchat_web = True
global h2ogpt_model
global h2ogpt_tokenizer
global model_state
global langchain
global userpath_selector
if h2ogpt_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu"
elif "task" in device:
device = "cpu"
elif "vulkan" in device:
device = "vulkan"
else:
print("unrecognized device")
args.device = device
args.precision = precision
from apps.language_models.langchain.gen import Langchain
langchain = Langchain(device, precision)
h2ogpt_model, h2ogpt_tokenizer, _ = langchain.get_model(
load_8bit=True
if device == "cuda"
else False, # load model in 8bit if device is cuda to save memory
load_gptq="",
use_safetensors=False,
infer_devices=True,
device=device,
base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
inference_server="",
tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
lora_weights="",
gpu_id=0,
reward_type=None,
local_files_only=False,
resume_download=True,
use_auth_token=False,
trust_remote_code=True,
offload_folder=None,
compile_model=False,
verbose=False,
)
model_state = dict(
model=h2ogpt_model,
tokenizer=h2ogpt_tokenizer,
device=device,
base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
lora_weights="",
inference_server="",
prompt_type=None,
prompt_dict=None,
)
prompt = create_prompt(history)
output = langchain.evaluate(
model_state=model_state,
my_db_state=None,
instruction=prompt,
iinput="",
context="",
stream_output=True,
prompt_type="prompt_answer",
prompt_dict={
"promptA": "",
"promptB": "",
"PreInstruct": "<|prompt|>",
"PreInput": None,
"PreResponse": "<|answer|>",
"terminate_response": [
"<|prompt|>",
"<|answer|>",
"<|endoftext|>",
],
"chat_sep": "<|endoftext|>",
"chat_turn_sep": "<|endoftext|>",
"humanstr": "<|prompt|>",
"botstr": "<|answer|>",
"generates_leading_space": False,
},
temperature=0.1,
top_p=0.75,
top_k=40,
num_beams=1,
max_new_tokens=256,
min_new_tokens=0,
early_stopping=False,
max_time=180,
repetition_penalty=1.07,
num_return_sequences=1,
do_sample=False,
chat=True,
instruction_nochat=prompt,
iinput_nochat="",
langchain_mode="UserData",
langchain_action=LangChainAction.QUERY.value,
top_k_docs=3,
chunk=True,
chunk_size=512,
document_choice=[DocumentChoices.All_Relevant.name],
concurrency_count=1,
memory_restriction_level=2,
raise_generate_gpu_exceptions=False,
chat_context="",
use_openai_embedding=False,
use_openai_model=False,
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
db_type="chroma",
n_jobs=-1,
first_para=False,
max_max_time=60 * 2,
model_state0=model_state,
model_lock=True,
user_path=userpath_selector.value,
)
for partial_text in output:
history[-1][1] = partial_text["response"]
yield history
return history
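The `prompt_dict` passed to `evaluate()` wraps each turn in falcon-style special tokens. A small sketch of how one user turn is rendered under those settings (the helper name is illustrative, not part of the code above):

```python
# Illustrative helper: renders one chat turn using the PreInstruct /
# PreResponse markers from the prompt_dict above.
def render_turn(user_msg: str) -> str:
    pre_instruct = "<|prompt|>"
    pre_response = "<|answer|>"
    return f"{pre_instruct}{user_msg}{pre_response}"

print(render_turn("Summarize the uploaded PDF."))
# The model's reply follows the <|answer|> marker and generation stops at
# one of the terminate_response strings such as <|endoftext|>.
```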
with gr.Blocks(title="H2OGPT") as h2ogpt_web:
with gr.Row():
supported_devices = available_devices
enabled = len(supported_devices) > 0
# show cpu-task device first in list for chatbot
supported_devices = supported_devices[-1:] + supported_devices[:-1]
supported_devices = [x for x in supported_devices if "sync" not in x]
print(supported_devices)
device = gr.Dropdown(
label="Device",
value=supported_devices[0]
if enabled
else "Only CUDA Supported for now",
choices=supported_devices,
interactive=enabled,
)
precision = gr.Radio(
label="Precision",
value="fp16",
choices=[
"int4",
"int8",
"fp16",
"fp32",
],
visible=True,
)
userpath_selector = gr.Textbox(
label="Document Directory",
value=str(
os.path.abspath("apps/language_models/langchain/user_path/")
),
interactive=True,
container=True,
)
chatbot = gr.Chatbot(height=500)
with gr.Row():
with gr.Column():
msg = gr.Textbox(
label="Chat Message Box",
placeholder="Chat Message Box",
show_label=False,
interactive=enabled,
container=False,
)
with gr.Column():
with gr.Row():
submit = gr.Button("Submit", interactive=enabled)
stop = gr.Button("Stop", interactive=enabled)
clear = gr.Button("Clear", interactive=enabled)
system_msg = gr.Textbox(
start_message, label="System Message", interactive=False, visible=False
)
submit_event = msg.submit(
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot, device, precision],
outputs=[chatbot],
queue=True,
)
submit_click_event = submit.click(
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot, device, precision],
outputs=[chatbot],
queue=True,
)
stop.click(
fn=None,
inputs=None,
outputs=None,
cancels=[submit_event, submit_click_event],
queue=False,
)
clear.click(lambda: None, None, [chatbot], queue=False)

View File

@@ -66,6 +66,7 @@ def img2img_inf(
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
repeatable_seeds: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -104,7 +105,8 @@ def img2img_inf(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
@@ -132,7 +134,8 @@ def img2img_inf(
image, width, height = resize_stencil(image)
elif "Shark" in args.scheduler:
print(
f"Shark schedulers are not supported. Switching to EulerDiscrete "
f"scheduler"
)
args.scheduler = "EulerDiscrete"
cpu_scheduling = not args.scheduler.startswith("Shark")
@@ -228,12 +231,11 @@ def img2img_inf(
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
extra_info = {"STRENGTH": strength}
text_output = ""
for current_batch in range(batch_count):
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
@@ -244,7 +246,7 @@ def img2img_inf(
steps,
strength,
guidance_scale,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
@@ -252,9 +254,10 @@ def img2img_inf(
args.max_embeddings_multiples,
use_stencil=use_stencil,
)
total_time = time.time() - start_time
text_output = get_generation_text_info(
seeds[: current_batch + 1], device
)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
@@ -263,7 +266,7 @@ def img2img_inf(
else:
save_output_img(
out_imgs[0],
seeds[current_batch],
extra_info,
)
generated_imgs.extend(out_imgs)
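This hunk replaces the per-iteration `sanitize_seed(-1)` calls with a single `utils.batch_seeds(seed, batch_count, repeatable_seeds)` call, so the full seed sequence exists up front and can be sliced (`seeds[: current_batch + 1]`) for partial-progress output. The diff shows only the call site; a plausible sketch of such a helper, assuming `-1` means "pick a random seed" and that `repeatable_seeds` derives a deterministic sequence from the first seed:

```python
import random

UINT32_MAX = 2**32 - 1

def sanitize_seed(seed):
    # Assumption based on the removed sanitize_seed(-1) calls:
    # -1 means "pick a fresh random seed"; clamp into uint32 range.
    if seed == -1:
        seed = random.randint(0, UINT32_MAX)
    return int(seed) % (UINT32_MAX + 1)

def batch_seeds(seed, batch_count, repeatable_seeds):
    first = sanitize_seed(seed)
    if repeatable_seeds:
        # Deterministic: the whole batch can be reproduced from the first seed.
        rng = random.Random(first)
        return [first] + [
            rng.randint(0, UINT32_MAX) for _ in range(batch_count - 1)
        ]
    # Non-repeatable: randomize every image after the first, matching the
    # old per-iteration behavior this hunk removes.
    return [first] + [sanitize_seed(-1) for _ in range(batch_count - 1)]
```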
@@ -308,7 +311,9 @@ def img2img_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}.'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = img2img_inf(
@@ -340,6 +345,7 @@ def img2img_api(
lora_weights="None",
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
)
# Converts generator type to subscriptable
@@ -362,13 +368,21 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
show_label=False,
interactive=False,
elem_id="top_logo",
width=150,
height=50,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
# janky fix for overflowing text
i2i_model_info = (str(get_custom_model_path())).replace(
"\\", "\n\\"
)
i2i_model_info = f"Custom Model Path: {i2i_model_info}"
img2img_custom_model = gr.Dropdown(
label=f"Models",
info=i2i_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
@@ -379,13 +393,23 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
)
img2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: SG161222/Realistic_Vision_V1.3, "
"https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
)
# janky fix for overflowing text
i2i_vae_info = (str(get_custom_model_path("vae"))).replace(
"\\", "\n\\"
)
i2i_vae_info = f"VAE Path: {i2i_vae_info}"
custom_vae = gr.Dropdown(
label=f"Custom VAE Models",
info=i2i_vae_info,
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
@@ -397,13 +421,13 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=2,
elem_id="negative_prompt_box",
)
@@ -412,7 +436,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
source="upload",
tool="sketch",
type="pil",
height=300,
)
with gr.Accordion(label="Stencil Options", open=False):
with gr.Row():
@@ -475,15 +500,24 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
i2i_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
i2i_lora_info = f"LoRA Path: {i2i_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standalone LoRA Weights",
info=i2i_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
@@ -533,16 +567,18 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
visible=False,
)
with gr.Row():
with gr.Column(scale=3):
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
with gr.Column(scale=3):
strength = gr.Slider(
0,
1,
value=args.strength,
step=0.01,
label="Denoising Strength",
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
@@ -566,6 +602,11 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="Batch Count",
interactive=True,
)
repeatable_seeds = gr.Checkbox(
args.repeatable_seeds,
label="Repeatable Seeds",
)
with gr.Row():
batch_size = gr.Slider(
1,
4,
@@ -575,7 +616,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
interactive=False,
visible=False,
)
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
@@ -587,16 +627,15 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
choices=available_devices,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -604,9 +643,12 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="Generated images",
show_label=False,
elem_id="gallery",
columns=2,
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
@@ -648,6 +690,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
lora_weights,
lora_hf_id,
ondemand,
repeatable_seeds,
],
outputs=[img2img_gallery, std_output, img2img_status],
show_progress="minimal" if args.progress_bar else "none",

View File

@@ -64,6 +64,7 @@ def inpaint_inf(
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
repeatable_seeds: int,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -92,7 +93,8 @@ def inpaint_inf(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
@@ -179,14 +181,12 @@ def inpaint_inf(
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
image = image_dict["image"]
mask_image = image_dict["mask"]
text_output = ""
for current_batch in range(batch_count):
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
@@ -199,26 +199,27 @@ def inpaint_inf(
inpaint_full_res_padding,
steps,
guidance_scale,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = get_generation_text_info(
seeds[: current_batch + 1], device
)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], seeds[current_batch])
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Inpaint", current_batch + 1, batch_count, batch_size
)
return generated_imgs, text_output
@@ -258,7 +259,9 @@ def inpaint_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}.'
)
init_image = decode_base64_to_image(InputData["image"])
mask = decode_base64_to_image(InputData["mask"])
@@ -289,6 +292,7 @@ def inpaint_api(
lora_weights="None",
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
)
# Converts generator type to subscriptable
@@ -311,13 +315,23 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
show_label=False,
interactive=False,
elem_id="top_logo",
width=150,
height=50,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
# janky fix for overflowing text
inpaint_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
inpaint_model_info = (
f"Custom Model Path: {inpaint_model_info}"
)
inpaint_custom_model = gr.Dropdown(
label=f"Models",
info=inpaint_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
@@ -330,13 +344,23 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
)
inpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
"https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
)
# janky fix for overflowing text
inpaint_vae_info = (
str(get_custom_model_path("vae"))
).replace("\\", "\n\\")
inpaint_vae_info = f"VAE Path: {inpaint_vae_info}"
custom_vae = gr.Dropdown(
label=f"Custom VAE Models",
info=inpaint_vae_info,
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
@@ -348,13 +372,13 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=2,
elem_id="negative_prompt_box",
)
@@ -363,19 +387,29 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
source="upload",
tool="sketch",
type="pil",
height=350,
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
inpaint_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
inpaint_lora_info = f"LoRA Path: {inpaint_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standalone LoRA Weights",
info=inpaint_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
@@ -465,6 +499,11 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
label="Batch Count",
interactive=True,
)
repeatable_seeds = gr.Checkbox(
args.repeatable_seeds,
label="Repeatable Seeds",
)
with gr.Row():
batch_size = gr.Slider(
1,
4,
@@ -474,7 +513,6 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
interactive=False,
visible=False,
)
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
@@ -486,16 +524,15 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
choices=available_devices,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -503,9 +540,12 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
label="Generated images",
show_label=False,
elem_id="gallery",
columns=[2],
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
@@ -548,6 +588,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
lora_weights,
lora_hf_id,
ondemand,
repeatable_seeds,
],
outputs=[inpaint_gallery, std_output, inpaint_status],
show_progress="minimal" if args.progress_bar else "none",

View File

@@ -24,15 +24,25 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
show_label=False,
interactive=False,
elem_id="top_logo",
width=150,
height=50,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
with gr.Column(scale=10):
with gr.Row():
# janky fix for overflowing text
train_lora_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
train_lora_model_info = (
f"Custom Model Path: {train_lora_model_info}"
)
custom_model = gr.Dropdown(
label=f"Models",
info=train_lora_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
@@ -43,22 +53,33 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models "
"dropdown on the left and enter model ID here "
"e.g: SG161222/Realistic_Vision_V1.3",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
# janky fix for overflowing text
train_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
train_lora_info = f"LoRA Path: {train_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standalone LoRA weights to initialize weights",
info=train_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use a "
"standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID to initialize weights",
lines=3,
@@ -74,7 +95,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=2,
elem_id="prompt_box",
)
with gr.Accordion(label="Advanced Options", open=False):

View File

@@ -19,7 +19,10 @@ def get_hf_list(num_of_models=20):
def get_civit_list(num_of_models=50):
path = (
f"https://civitai.com/api/v1/models?limit="
f"{num_of_models}&types=Checkpoint"
)
headers = {"Content-Type": "application/json"}
raw_json = requests.get(path, headers=headers).json()
models = list(raw_json.items())[0][1]
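`get_civit_list` extracts the model list with `list(raw_json.items())[0][1]`, which relies on `"items"` being the first key in the returned JSON object. A small sketch of the same extraction keyed by name, using the documented `{"items": [...], "metadata": {...}}` response shape (the sample payload below is made up for illustration):

```python
# Made-up sample mirroring the Civitai /api/v1/models response shape.
raw_json = {
    "items": [{"name": "ExampleCheckpoint", "type": "Checkpoint"}],
    "metadata": {"totalItems": 1},
}

# Keyed access does not depend on key order in the JSON object.
models = raw_json.get("items", [])
print([m["name"] for m in models])
```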
@@ -79,7 +82,7 @@ with gr.Blocks() as model_web:
type="value",
label="Model Source",
)
model_number = gr.Slider(
1,
100,
value=10,
@@ -111,9 +114,9 @@ with gr.Blocks() as model_web:
modelmanager_sendto_outpaint = gr.Button(value="SendTo Outpaint")
modelmanager_sendto_upscaler = gr.Button(value="SendTo Upscaler")
def get_model_list(model_source, model_number):
if model_source == "Hugging Face":
hf_model_list = get_hf_list(model_number)
models = []
for model in hf_model_list:
# TODO: add model info
@@ -124,7 +127,7 @@ with gr.Blocks() as model_web:
gr.Row.update(visible=True),
)
elif model_source == "Civitai":
civit_model_list = get_civit_list(model_number)
models = []
for model in civit_model_list:
image = get_image_from_model(model)
@@ -148,7 +151,7 @@ with gr.Blocks() as model_web:
get_model_btn.click(
fn=get_model_list,
inputs=[model_source, model_number],
outputs=[
hf_models,
civit_models,

View File

@@ -29,7 +29,6 @@ from apps.stable_diffusion.src.utils import (
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
@@ -65,6 +64,7 @@ def outpaint_inf(
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
repeatable_seeds: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -92,7 +92,8 @@ def outpaint_inf(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
@@ -177,8 +178,7 @@ def outpaint_inf(
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
left = True if "left" in directions else False
right = True if "right" in directions else False
@@ -186,9 +186,7 @@ def outpaint_inf(
bottom = True if "down" in directions else False
text_output = ""
for current_batch in range(batch_count):
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
@@ -206,26 +204,27 @@ def outpaint_inf(
width,
steps,
guidance_scale,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = get_generation_text_info(
seeds[: current_batch + 1], device
)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], seeds[current_batch])
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Outpaint", current_batch + 1, batch_count, batch_size
)
return generated_imgs, text_output, ""
@@ -265,7 +264,9 @@ def outpaint_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = outpaint_inf(
@@ -298,6 +299,7 @@ def outpaint_api(
lora_weights="None",
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
)
# Convert Generator to Subscriptable
@@ -320,13 +322,23 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
show_label=False,
interactive=False,
elem_id="top_logo",
width=150,
height=50,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
# janky fix for overflowing text
outpaint_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
outpaint_model_info = (
f"Custom Model Path: {outpaint_model_info}"
)
outpaint_custom_model = gr.Dropdown(
label=f"Models",
info=outpaint_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
@@ -339,13 +351,23 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
)
outpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
"https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID or Civitai model download URL",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
)
# janky fix for overflowing text
outpaint_vae_info = (
str(get_custom_model_path("vae"))
).replace("\\", "\n\\")
outpaint_vae_info = f"VAE Path: {outpaint_vae_info}"
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
label="Custom VAE Models",
info=outpaint_vae_info,
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
@@ -357,31 +379,42 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
lines=2,
elem_id="negative_prompt_box",
)
outpaint_init_image = gr.Image(
label="Input Image", type="pil"
).style(height=300)
label="Input Image",
type="pil",
height=300,
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
outpaint_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
outpaint_lora_info = f"LoRA Path: {outpaint_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
label="Standalone LoRA Weights",
info=outpaint_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
@@ -493,6 +526,12 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
label="Batch Count",
interactive=True,
)
repeatable_seeds = gr.Checkbox(
args.repeatable_seeds,
label="Repeatable Seeds",
)
with gr.Row():
batch_size = gr.Slider(
1,
4,
@@ -502,7 +541,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
@@ -514,16 +552,15 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
choices=available_devices,
)
with gr.Row():
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -531,9 +568,12 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
label="Generated images",
show_label=False,
elem_id="gallery",
).style(columns=[2], object_fit="contain")
columns=[2],
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at {get_generated_imgs_path()}",
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
@@ -576,6 +616,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
lora_weights,
lora_hf_id,
ondemand,
repeatable_seeds,
],
outputs=[outpaint_gallery, std_output, outpaint_status],
show_progress="minimal" if args.progress_bar else "none",


@@ -1,6 +1,8 @@
import glob
import gradio as gr
import os
import subprocess
import sys
from PIL import Image
from apps.stable_diffusion.src import args
@@ -38,14 +40,14 @@ def output_subdirs() -> list[str]:
)
]
# It is less confusing to always including the subdir that will take any images generated
# today even if it doesn't exist yet
# It is less confusing to always include the subdir that will take any
# images generated today even if it doesn't exist yet
if get_generated_imgs_todays_subdir() not in relative_paths:
relative_paths.append(get_generated_imgs_todays_subdir())
# sort subdirectories so that that the date named ones we probably created in this or
# previous sessions come first, sorted with the most recent first. Other subdirs are listed
# after.
# sort subdirectories so that the date named ones we probably
# created in this or previous sessions come first, sorted with the most
# recent first. Other subdirs are listed after.
generated_paths = sorted(
[path for path in relative_paths if path.isnumeric()], reverse=True
)
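The sorting above puts the numeric, date-named subdirectories first (most recent first), with other subdirectories trailing. A minimal standalone sketch of that ordering (directory names are illustrative):

```python
# Order output subdirectories: date-named (numeric) dirs first,
# most recent first; everything else keeps its original order after.
def order_subdirs(relative_paths: list[str]) -> list[str]:
    generated = sorted(
        [p for p in relative_paths if p.isnumeric()], reverse=True
    )
    other = [p for p in relative_paths if not p.isnumeric()]
    return generated + other

# Example: today's dir sorts ahead of older ones, named dirs trail.
print(order_subdirs(["favourites", "20230719", "20230721"]))
# → ['20230721', '20230719', 'favourites']
```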
@@ -66,7 +68,8 @@ with gr.Blocks() as outputgallery_web:
nod_logo = Image.open(nodlogo_loc)
with gr.Row(elem_id="outputgallery_gallery"):
# needed to workaround gradio issue: https://github.com/gradio-app/gradio/issues/2907
# needed to workaround gradio issue:
# https://github.com/gradio-app/gradio/issues/2907
dev_null = gr.Textbox("", visible=False)
gallery_files = gr.State(value=[])
@@ -88,33 +91,56 @@ with gr.Blocks() as outputgallery_web:
value=gallery_files.value,
visible=False,
show_label=True,
).style(columns=4)
columns=2,
)
with gr.Column(scale=4):
with gr.Box():
with gr.Row():
with gr.Column(scale=16, min_width=160):
with gr.Column(
scale=15,
min_width=160,
elem_id="output_subdir_container",
):
subdirectories = gr.Dropdown(
label=f"Subdirectories of {output_dir}",
type="value",
choices=subdirectory_paths.value,
value="",
interactive=True,
).style(container=False)
elem_classes="dropdown_no_container",
)
with gr.Column(
scale=1, min_width=32, elem_id="output_refresh_button"
scale=1,
min_width=32,
elem_classes="output_icon_button",
):
open_subdir = gr.Button(
variant="secondary",
value="\U0001F5C1", # unicode open folder
interactive=False,
size="sm",
)
with gr.Column(
scale=1,
min_width=32,
elem_classes="output_icon_button",
):
refresh = gr.Button(
variant="secondary",
value="\u21BB", # unicode clockwise arrow circle
).style(size="sm")
size="sm",
)
image_columns = gr.Slider(
label="Columns shown", value=4, minimum=1, maximum=16, step=1
)
outputgallery_filename = gr.Textbox(
label="Filename", value="None", interactive=False
).style(show_copy_button=True)
label="Filename",
value="None",
interactive=False,
show_copy_button=True,
)
with gr.Accordion(
label="Parameter Information", open=False
@@ -133,31 +159,36 @@ with gr.Blocks() as outputgallery_web:
value="Txt2Img",
interactive=False,
elem_classes="outputgallery_sendto",
).style(size="sm")
size="sm",
)
outputgallery_sendto_img2img = gr.Button(
value="Img2Img",
interactive=False,
elem_classes="outputgallery_sendto",
).style(size="sm")
size="sm",
)
outputgallery_sendto_inpaint = gr.Button(
value="Inpaint",
interactive=False,
elem_classes="outputgallery_sendto",
).style(size="sm")
size="sm",
)
outputgallery_sendto_outpaint = gr.Button(
value="Outpaint",
interactive=False,
elem_classes="outputgallery_sendto",
).style(size="sm")
size="sm",
)
outputgallery_sendto_upscaler = gr.Button(
value="Upscaler",
interactive=False,
elem_classes="outputgallery_sendto",
).style(size="sm")
size="sm",
)
# --- Event handlers
@@ -191,17 +222,32 @@ with gr.Blocks() as outputgallery_web:
),
]
def on_open_subdir(subdir):
subdir_path = os.path.normpath(os.path.join(output_dir, subdir))
if os.path.isdir(subdir_path):
if sys.platform == "linux":
subprocess.run(["xdg-open", subdir_path])
elif sys.platform == "darwin":
subprocess.run(["open", subdir_path])
elif sys.platform == "win32":
os.startfile(subdir_path)
def on_refresh(current_subdir: str) -> list:
# get an up to date subdirectory list
# get an up-to-date subdirectory list
refreshed_subdirs = output_subdirs()
# get the images using either the current subdirectory or the most recent valid one
# get the images using either the current subdirectory or the most
# recent valid one
new_subdir = (
current_subdir
if current_subdir in refreshed_subdirs
else refreshed_subdirs[0]
)
new_images = outputgallery_filenames(new_subdir)
new_label = f"{len(new_images)} images in {os.path.join(output_dir, new_subdir)}"
new_label = (
f"{len(new_images)} images in "
f"{os.path.join(output_dir, new_subdir)}"
)
return [
gr.Dropdown.update(
@@ -220,17 +266,22 @@ with gr.Blocks() as outputgallery_web:
]
def on_new_image(subdir, subdir_paths, status) -> list:
# prevent error triggered when an image generates before the tab has even been selected
# prevent error triggered when an image generates before the tab
# has even been selected
subdir_paths = (
subdir_paths
if len(subdir_paths) > 0
else [get_generated_imgs_todays_subdir()]
)
# only update if the current subdir is the most recent one as new images only go there
# only update if the current subdir is the most recent one as
# new images only go there
if subdir_paths[0] == subdir:
new_images = outputgallery_filenames(subdir)
new_label = f"{len(new_images)} images in {os.path.join(output_dir, subdir)} - {status}"
new_label = (
f"{len(new_images)} images in "
f"{os.path.join(output_dir, subdir)} - {status}"
)
return [
new_images,
@@ -245,19 +296,27 @@ with gr.Blocks() as outputgallery_web:
),
]
else:
# otherwise change nothing, (only untyped gradio gr.update() does this)
# otherwise change nothing,
# (only untyped gradio gr.update() does this)
return [gr.update(), gr.update(), gr.update()]
def on_select_image(images: list[str], evt: gr.SelectData) -> list:
# evt.index is an index into the full list of filenames for the current subdirectory
# evt.index is an index into the full list of filenames for
# the current subdirectory
filename = images[evt.index]
params = displayable_metadata(filename)
if params:
return [
filename,
list(map(list, params["parameters"].items())),
]
if params["source"] == "missing":
return [
"Could not find this image file, refresh the gallery and update the images",
[["Status", "File missing"]],
]
else:
return [
filename,
list(map(list, params["parameters"].items())),
]
return [
filename,
@@ -267,7 +326,8 @@ with gr.Blocks() as outputgallery_web:
def on_outputgallery_filename_change(filename: str) -> list:
exists = filename != "None" and os.path.exists(filename)
return [
# disable or enable each of the sendto button based on whether an image is selected
# disable or enable each of the sendto buttons based on whether
# an image is selected
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
@@ -276,17 +336,23 @@ with gr.Blocks() as outputgallery_web:
gr.Button.update(interactive=exists),
]
# The time first our tab is selected we need to do an initial refresh to populate
# the subdirectory select box and the images from the most recent subdirectory.
# The first time our tab is selected we need to do an initial refresh
# to populate the subdirectory select box and the images from the most
# recent subdirectory.
#
# We do it at this point rather than setting this up in the controls' definitions
# as when you refresh the browser you always get what was *initially* set, which
# won't include any new subdirectories or images that might have created since
# the application was started. Doing it this way means a browser refresh/reload
# always gets the most up to date data.
def on_select_tab(subdir_paths):
# We do it at this point rather than setting this up in the controls'
# definitions as when you refresh the browser you always get what was
# *initially* set, which won't include any new subdirectories or images
# that might have been created since the application was started. Doing it
# this way means a browser refresh/reload always gets the most
# up-to-date data.
def on_select_tab(subdir_paths, request: gr.Request):
local_client = request.headers["host"].startswith(
"127.0.0.1:"
) or request.headers["host"].startswith("localhost:")
if len(subdir_paths) == 0:
return on_refresh("")
return on_refresh("") + [gr.update(interactive=local_client)]
else:
return (
# Change nothing, (only untyped gr.update() does this)
@@ -295,13 +361,14 @@ with gr.Blocks() as outputgallery_web:
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
# Unfortunately as of gradio 3.22.0 gr.update against Galleries doesn't support
# things set with .style, nor the elem_classes kwarg so we have to directly set
# things up via JavaScript if we want the client to take notice of any of our
# changes to the number of columns after it decides to put them back to the
# original number when we change something
# Unfortunately as of gradio 3.34.0 gr.update against Galleries doesn't
# support things set with .style, nor the elem_classes kwarg, so we have
# to directly set things up via JavaScript if we want the client to take
# notice of our changes to the number of columns after it decides to put
# them back to the original number when we change something
def js_set_columns_in_browser(timeout_length):
return f"""
(new_cols) => {{
@@ -318,32 +385,36 @@ with gr.Blocks() as outputgallery_web:
# --- Wire handlers up to the actions
# - Many actions reset the number of columns shown in the gallery on the browser end,
# so we have to set them back to what we think they should be after the initial
# action.
# - None of the actions on this tab trigger inference, and we want the user to be able
# to do them whilst other tabs have ongoing inference running. Waiting in the queue
# behind inference jobs would mean the UI can't fully respond until the inference tasks
# complete, hence queue=False on all of these.
# Many actions reset the number of columns shown in the gallery on the
# browser end, so we have to set them back to what we think they should
# be after the initial action.
#
# None of the actions on this tab trigger inference, and we want the
# user to be able to do them whilst other tabs have ongoing inference
# running. Waiting in the queue behind inference jobs would mean the UI
# can't fully respond until the inference tasks complete,
# hence queue=False on all of these.
set_gallery_columns_immediate = dict(
fn=None,
inputs=[image_columns],
# gradio blanks the UI on Chrome on Linux on gallery select if I don't put an output here
# gradio blanks the UI on Chrome on Linux on gallery select if
# I don't put an output here
outputs=[dev_null],
_js=js_set_columns_in_browser(0),
queue=False,
)
# setting columns after selecting a gallery item needs a real timeout length for the
# number of columns to actually be applied. Not really sure why, maybe something has
# to finish animating?
# setting columns after selecting a gallery item needs a real
# timeout length for the number of columns to actually be applied.
# Not really sure why, maybe something has to finish animating?
set_gallery_columns_delayed = dict(
set_gallery_columns_immediate, _js=js_set_columns_in_browser(250)
)
# clearing images when we need to completely change what's in the gallery avoids current
# images being shown replacing piecemeal and prevents weirdness and errors if the user
# selects an image during the replacement phase.
# clearing images when we need to completely change what's in the
# gallery avoids current images being shown replacing piecemeal and
# prevents weirdness and errors if the user selects an image during the
# replacement phase.
clear_gallery = dict(
fn=on_clear_gallery,
inputs=None,
@@ -360,6 +431,10 @@ with gr.Blocks() as outputgallery_web:
queue=False,
).then(**set_gallery_columns_immediate)
open_subdir.click(
on_open_subdir, inputs=[subdirectories], queue=False
).then(**set_gallery_columns_immediate)
refresh.click(**clear_gallery).then(
on_refresh,
[subdirectories],
@@ -398,6 +473,7 @@ with gr.Blocks() as outputgallery_web:
gallery_files,
gallery,
logo,
open_subdir,
],
queue=False,
).then(**set_gallery_columns_immediate)


@@ -6,13 +6,7 @@ from transformers import (
AutoModelForCausalLM,
)
from apps.stable_diffusion.web.ui.utils import available_devices
start_message = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""
from datetime import datetime as dt
def user(message, history):
@@ -24,22 +18,112 @@ sharkModel = 0
sharded_model = 0
vicuna_model = 0
start_message_vicuna = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
past_key_values = None
model_map = {
"llama2_7b": "meta-llama/Llama-2-7b-chat-hf",
"llama2_70b": "meta-llama/Llama-2-70b-chat-hf",
"codegen": "Salesforce/codegen25-7b-multi",
"vicuna1p3": "lmsys/vicuna-7b-v1.3",
"vicuna": "TheBloke/vicuna-7B-1.1-HF",
"StableLM": "stabilityai/stablelm-tuned-alpha-3b",
}
def chat(curr_system_message, history, model, device, precision):
print(f"In chat for {model}")
global sharded_model
global past_key_values
global vicuna_model
if "vicuna" in model:
from apps.language_models.src.pipelines.vicuna_pipeline import (
Vicuna,
# NOTE: Each `model_name` should have its own start message
start_message = {
"llama2_7b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
"content. Please ensure that your responses are socially unbiased and positive "
"in nature. If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
),
"llama2_70b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
"content. Please ensure that your responses are socially unbiased and positive "
"in nature. If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
),
"StableLM": (
"<|SYSTEM|># StableLM Tuned (Alpha version)"
"\n- StableLM is a helpful and harmless open-source AI language model "
"developed by StabilityAI."
"\n- StableLM is excited to be able to help the user, but will refuse "
"to do anything that could be considered harmful to the user."
"\n- StableLM is more than just an information source, StableLM is also "
"able to write poetry, short stories, and make jokes."
"\n- StableLM will refuse to participate in anything that "
"could harm a human."
),
"vicuna": (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's "
"questions.\n"
),
"vicuna1p3": (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's "
"questions.\n"
),
"codegen": "",
}
def create_prompt(model_name, history):
system_message = start_message[model_name]
if model_name in [
"StableLM",
"vicuna",
"vicuna1p3",
"llama2_7b",
"llama2_70b",
]:
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
else:
conversation = "".join(
["".join([item[0], item[1]]) for item in history]
)
curr_system_message = start_message_vicuna
msg = system_message + conversation
msg = msg.strip()
return msg
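The prompt built by `create_prompt` is just the model's system message followed by the flattened turn history. A self-contained sketch of the conversation formatting used for the `<|USER|>`/`<|ASSISTANT|>`-tagged models (the system message here is shortened for illustration):

```python
# Minimal reproduction of the prompt layout for StableLM/vicuna/llama2
# models: system message, then each (user, assistant) turn tagged inline.
def build_prompt(system_message: str, history: list[tuple[str, str]]) -> str:
    conversation = "".join(
        "<|USER|>" + user + "<|ASSISTANT|>" + assistant
        for user, assistant in history
    )
    return (system_message + conversation).strip()

prompt = build_prompt(
    "System: You are a helpful assistant.\n",
    [("Hi there", "Hello! How can I help?"), ("Tell a joke", "")],
)
print(prompt)
# System: You are a helpful assistant.
# <|USER|>Hi there<|ASSISTANT|>Hello! How can I help?<|USER|>Tell a joke<|ASSISTANT|>
```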
def set_vicuna_model(model):
global vicuna_model
vicuna_model = model
# TODO: Make chat reusable for UI and API
def chat(curr_system_message, history, model, device, precision, cli=True):
global past_key_values
global vicuna_model
model_name, model_path = list(map(str.strip, model.split("=>")))
if model_name in [
"vicuna",
"vicuna1p3",
"codegen",
"llama2_7b",
"llama2_70b",
]:
from apps.language_models.scripts.vicuna import (
UnshardedVicuna,
)
from apps.stable_diffusion.src import args
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
@@ -51,22 +135,19 @@ def chat(curr_system_message, history, model, device, precision):
device = "vulkan"
else:
print("unrecognized device")
vicuna_model = Vicuna(
"vicuna",
hf_model_path=model,
max_toks = 128 if model_name == "codegen" else 512
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
hf_auth_token=args.hf_auth_token,
device=device,
precision=precision,
max_num_tokens=max_toks,
)
messages = curr_system_message + "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
prompt = messages.strip()
print("prompt = ", prompt)
prompt = create_prompt(model_name, history)
for partial_text in vicuna_model.generate(prompt):
for partial_text in vicuna_model.generate(prompt, cli=cli):
history[-1][1] = partial_text
yield history
@@ -81,46 +162,144 @@ def chat(curr_system_message, history, model, device, precision):
if sharkModel == 0:
# max_new_tokens=512
shark_slm = SharkStableLM(
"StableLM"
model_name
) # pass elements from UI as required
# Construct the input message string for the model by concatenating the current system message and conversation history
# Construct the input message string for the model by concatenating the
# current system message and conversation history
if len(curr_system_message.split()) > 160:
print("clearing context")
curr_system_message = start_message
messages = curr_system_message + "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
generate_kwargs = dict(prompt=messages)
prompt = create_prompt(model_name, history)
generate_kwargs = dict(prompt=prompt)
words_list = shark_slm.generate(**generate_kwargs)
partial_text = ""
for new_text in words_list:
# print(new_text)
print(new_text)
partial_text += new_text
history[-1][1] = partial_text
# Yield an empty string to cleanup the message textbox and the updated conversation history
# Yield an empty string to clean up the message textbox and the updated
# conversation history
yield history
return words_list
def llm_chat_api(InputData: dict):
print(f"Input keys : {InputData.keys()}")
# print(f"model : {InputData['model']}")
is_chat_completion_api = (
"messages" in InputData.keys()
) # else it is the legacy `completion` api
# For Debugging input data from API
# if is_chat_completion_api:
# print(f"message -> role : {InputData['messages'][0]['role']}")
# print(f"message -> content : {InputData['messages'][0]['content']}")
# else:
# print(f"prompt : {InputData['prompt']}")
# print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now
global vicuna_model
model_name = (
InputData["model"] if "model" in InputData.keys() else "codegen"
)
model_path = model_map[model_name]
device = "cpu-task"
precision = "fp16"
max_toks = (
None
if "max_tokens" not in InputData.keys()
else InputData["max_tokens"]
)
if max_toks is None:
max_toks = 128 if model_name == "codegen" else 512
# make it work for codegen first
from apps.language_models.scripts.vicuna import (
UnshardedVicuna,
)
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device = "vulkan"
else:
print("unrecognized device")
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
device=device,
precision=precision,
max_num_tokens=max_toks,
)
# TODO: add role dict for different models
if is_chat_completion_api:
# TODO: add functionality for multiple messages
prompt = create_prompt(
model_name, [(InputData["messages"][0]["content"], "")]
)
else:
prompt = InputData["prompt"]
print("prompt = ", prompt)
res = vicuna_model.generate(prompt)
res_op = None
for op in res:
res_op = op
if is_chat_completion_api:
choices = [
{
"index": 0,
"message": {
"role": "assistant",
"content": res_op,  # since we are yielding the result
},
"finish_reason": "stop", # or length
}
]
else:
choices = [
{
"text": res_op,
"index": 0,
"logprobs": None,
"finish_reason": "stop", # or length
}
]
end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
return {
"id": end_time,
"object": "chat.completion"
if is_chat_completion_api
else "text_completion",
"created": int(end_time),
"choices": choices,
}
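The response assembled by `llm_chat_api` mirrors the OpenAI completion schema. A hedged sketch of the shape a client can expect back (field values are illustrative, limited to the fields shown in the diff):

```python
from datetime import datetime as dt

# Shape of the response returned by llm_chat_api: chat-completion
# style when "messages" was supplied, legacy completion style otherwise.
def make_response(text: str, is_chat_completion_api: bool = True) -> dict:
    end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
    if is_chat_completion_api:
        choices = [{
            "index": 0,
            "message": {"role": "assistant", "content": text},
            "finish_reason": "stop",  # or "length"
        }]
    else:
        choices = [{
            "text": text,
            "index": 0,
            "logprobs": None,
            "finish_reason": "stop",  # or "length"
        }]
    return {
        "id": end_time,
        "object": "chat.completion"
        if is_chat_completion_api
        else "text_completion",
        "created": int(end_time),
        "choices": choices,
    }

resp = make_response("def add(a, b):\n    return a + b")
```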
with gr.Blocks(title="Chatbot") as stablelm_chat:
with gr.Row():
model_choices = list(
map(lambda x: f"{x[0]: <10} => {x[1]}", model_map.items())
)
model = gr.Dropdown(
label="Select Model",
value="TheBloke/vicuna-7B-1.1-HF",
choices=[
"stabilityai/stablelm-tuned-alpha-3b",
"TheBloke/vicuna-7B-1.1-HF",
],
value=model_choices[0],
choices=model_choices,
)
supported_devices = available_devices
enabled = len(supported_devices) > 0
# show cpu-task device first in list for chatbot
supported_devices = supported_devices[-1:] + supported_devices[:-1]
supported_devices = [x for x in supported_devices if "sync" not in x]
print(supported_devices)
device = gr.Dropdown(
label="Device",
value=supported_devices[0]
@@ -140,7 +319,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
],
visible=True,
)
chatbot = gr.Chatbot().style(height=500)
chatbot = gr.Chatbot(height=500)
with gr.Row():
with gr.Column():
msg = gr.Textbox(
@@ -148,7 +327,8 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
placeholder="Chat Message Box",
show_label=False,
interactive=enabled,
).style(container=False)
container=False,
)
with gr.Column():
with gr.Row():
submit = gr.Button("Submit", interactive=enabled)


@@ -61,6 +61,7 @@ def txt2img_inf(
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
repeatable_seeds: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -87,7 +88,8 @@ def txt2img_inf(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
"Please provide either a custom model or a HuggingFace "
"model ID; at least one must not be empty",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
@@ -176,12 +178,10 @@ def txt2img_inf(
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
for current_batch in range(batch_count):
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
@@ -190,26 +190,27 @@ def txt2img_inf(
width,
steps,
guidance_scale,
img_seed,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output = get_generation_text_info(
seeds[: current_batch + 1], device
)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
save_output_img(out_imgs[0], seeds[current_batch])
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Text-to-Image", i + 1, batch_count, batch_size
"Text-to-Image", current_batch + 1, batch_count, batch_size
)
return generated_imgs, text_output, ""
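The diff replaces per-iteration seed sanitizing with a precomputed `utils.batch_seeds(seed, batch_count, repeatable_seeds)` list. That helper's implementation isn't shown in this diff; a plausible, hypothetical sketch of the behavior (repeatable seeds derived deterministically from the base seed, otherwise fresh random seeds) might look like:

```python
import random

# Hypothetical stand-in for utils.batch_seeds as used in the diff:
# one seed per batch item, either derived deterministically from the
# base seed (repeatable) or freshly randomized. The uint32 range
# mirrors typical sanitize_seed behavior; the real rules may differ.
MAX_SEED = 2**32 - 1

def sanitize_seed(seed: int) -> int:
    # a negative seed (e.g. -1) means "pick a random seed"
    return random.randint(0, MAX_SEED) if seed < 0 else seed % (MAX_SEED + 1)

def batch_seeds(seed: int, batch_count: int, repeatable: bool) -> list[int]:
    first = sanitize_seed(seed)
    if repeatable:
        # deterministic: consecutive seeds from the same base
        return [(first + i) % (MAX_SEED + 1) for i in range(batch_count)]
    # non-repeatable: only the first seed is user-controlled
    return [first] + [sanitize_seed(-1) for _ in range(batch_count - 1)]

print(batch_seeds(42, 3, True))
# → [42, 43, 44]
```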
@@ -238,7 +239,9 @@ def txt2img_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}.'
)
res = txt2img_inf(
InputData["prompt"],
@@ -264,6 +267,7 @@ def txt2img_api(
lora_weights="None",
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
)
# Convert Generator to Subscriptable
@@ -286,15 +290,25 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
width=150,
height=50,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
with gr.Column(scale=10):
with gr.Row():
# janky fix for overflowing text
t2i_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
t2i_model_info = (
f"Custom Model Path: {t2i_model_info}"
)
txt2img_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
label="Models",
info=t2i_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
@@ -305,13 +319,21 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
)
txt2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the dropdown on the left and enter model ID here",
placeholder="Select 'None' in the dropdown "
"on the left and enter model ID here.",
value="",
label="HuggingFace Model ID or Civitai model download URL",
label="HuggingFace Model ID or Civitai model "
"download URL.",
lines=3,
)
# janky fix for overflowing text
t2i_vae_info = (
str(get_custom_model_path("vae"))
).replace("\\", "\n\\")
t2i_vae_info = f"VAE Path: {t2i_vae_info}"
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
label="VAE Models",
info=t2i_vae_info,
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
@@ -332,26 +354,35 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
lines=2,
elem_id="negative_prompt_box",
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
t2i_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
t2i_lora_info = f"LoRA Path: {t2i_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
label="Standalone LoRA Weights",
info=t2i_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
@@ -364,7 +395,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
value=args.scheduler,
choices=scheduler_list,
)
with gr.Group():
with gr.Column():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=args.write_metadata_to_png,
@@ -409,16 +440,18 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
visible=False,
)
with gr.Row():
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Column(scale=3):
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
@@ -443,7 +476,10 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="Batch Size",
interactive=True,
)
stop_batch = gr.Button("Stop Batch")
repeatable_seeds = gr.Checkbox(
args.repeatable_seeds,
label="Repeatable Seeds",
)
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
@@ -455,17 +491,15 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
choices=available_devices,
)
with gr.Row():
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
examples=prompt_examples,
@@ -480,9 +514,12 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="Generated images",
show_label=False,
elem_id="gallery",
).style(columns=[2], object_fit="contain")
columns=[2],
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at {get_generated_imgs_path()}",
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
@@ -522,6 +559,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
lora_weights,
lora_hf_id,
ondemand,
repeatable_seeds,
],
outputs=[txt2img_gallery, std_output, txt2img_status],
show_progress="minimal" if args.progress_bar else "none",


@@ -57,6 +57,7 @@ def upscaler_inf(
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
repeatable_seeds: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -88,7 +89,8 @@ def upscaler_inf(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
@@ -175,12 +177,10 @@ def upscaler_inf(
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
extra_info = {"NOISE LEVEL": noise_level}
for current_batch in range(batch_count):
if current_batch > 0:
img_seed = utils.sanitize_seed(-1)
low_res_img = image
high_res_img = Image.new("RGB", (height * 4, width * 4))
@@ -197,7 +197,7 @@ def upscaler_inf(
steps,
noise_level,
guidance_scale,
img_seed,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
@@ -212,27 +212,40 @@ def upscaler_inf(
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += (
f"\nmodel_id={args.hf_model_id}, " f"ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, " f"device={device}"
text_output += (
f"\nsteps={steps}, "
f"noise_level={noise_level}, "
f"guidance_scale={guidance_scale}, "
f"seed={seeds[:current_batch + 1]}"
)
text_output += (
f"\ninput size={height}x{width}, "
f"output size={height*4}x{width*4}, "
f"batch_count={batch_count}, "
f"batch_size={batch_size}, "
f"max_length={args.max_length}\n"
)
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(high_res_img, img_seed, extra_info)
save_output_img(high_res_img, seeds[current_batch], extra_info)
generated_imgs.append(high_res_img)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log, status_label(
yield generated_imgs, text_output, status_label(
"Upscaler", current_batch + 1, batch_count, batch_size
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={steps}, noise_level={noise_level}, guidance_scale={guidance_scale}, seed={seeds}"
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
yield generated_imgs, text_output, ""
@@ -270,7 +283,9 @@ def upscaler_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = upscaler_inf(
@@ -299,6 +314,7 @@ def upscaler_api(
lora_weights="None",
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
)
# Converts generator type to subscriptable
res = next(res)
@@ -320,13 +336,23 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
width=150,
height=50,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
# janky fix for overflowing text
upscaler_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
upscaler_model_info = (
f"Custom Model Path: {upscaler_model_info}"
)
upscaler_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
label=f"Models",
info=upscaler_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
@@ -339,13 +365,23 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
)
upscaler_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: SG161222/Realistic_Vision_V1.3, "
"https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID or Civitai model download URL",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
)
# janky fix for overflowing text
upscaler_vae_info = (
str(get_custom_model_path("vae"))
).replace("\\", "\n\\")
upscaler_vae_info = f"VAE Path: {upscaler_vae_info}"
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
label=f"Custom VAE Models",
info=upscaler_vae_info,
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
@@ -357,31 +393,42 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
lines=2,
elem_id="negative_prompt_box",
)
upscaler_init_image = gr.Image(
label="Input Image", type="pil"
).style(height=300)
label="Input Image",
type="pil",
height=300,
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
upscaler_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
upscaler_lora_info = f"LoRA Path: {upscaler_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
label=f"Standalone LoRA Weights",
info=upscaler_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
@@ -472,6 +519,11 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
label="Batch Count",
interactive=True,
)
repeatable_seeds = gr.Checkbox(
args.repeatable_seeds,
label="Repeatable Seeds",
)
with gr.Row():
batch_size = gr.Slider(
1,
4,
@@ -481,7 +533,6 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
@@ -493,16 +544,15 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
choices=available_devices,
)
with gr.Row():
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -510,9 +560,12 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
label="Generated images",
show_label=False,
elem_id="gallery",
).style(columns=[2], object_fit="contain")
columns=[2],
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at {get_generated_imgs_path()}",
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
@@ -552,6 +605,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
lora_weights,
lora_hf_id,
ondemand,
repeatable_seeds,
],
outputs=[upscaler_gallery, std_output, upscaler_status],
show_progress="minimal" if args.progress_bar else "none",
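The upscaler diff above replaces the per-batch `sanitize_seed(-1)` loop with a single `utils.batch_seeds(seed, batch_count, repeatable_seeds)` call driven by the new "Repeatable Seeds" checkbox. The real `utils` implementation is not shown in this diff; a minimal sketch of what such a helper could look like, assuming `-1` (or any out-of-range value) means "pick a random seed" and repeatable mode derives a deterministic sequence from the first seed:

```python
import random

MAX_SEED = 2**32 - 1  # assumed seed range for illustration


def sanitize_seed(seed: int) -> int:
    # Treat -1 (or any out-of-range value) as "choose a random seed".
    if seed < 0 or seed > MAX_SEED:
        return random.randint(0, MAX_SEED)
    return seed


def batch_seeds(seed: int, batch_count: int, repeatable: bool) -> list[int]:
    first = sanitize_seed(seed)
    if repeatable:
        # Deterministic: consecutive seeds derived from the first one,
        # so the same input seed reproduces the whole batch.
        return [(first + i) % (MAX_SEED + 1) for i in range(batch_count)]
    # Non-repeatable: every batch after the first gets a fresh random
    # seed, matching the old sanitize_seed(-1)-per-batch behavior.
    return [first] + [sanitize_seed(-1) for _ in range(batch_count - 1)]
```

With `repeatable=False` this reproduces the previous behavior (first seed sanitized, the rest random); with `repeatable=True` the whole batch becomes a pure function of the input seed, which is what makes `seeds[current_batch]` safe to log and save per image.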


@@ -39,8 +39,16 @@ scheduler_list_cpu_only = [
"LMSDiscrete",
"KDPM2Discrete",
"DPMSolverMultistep",
"DPMSolverMultistep++",
"DPMSolverMultistepKarras",
"DPMSolverMultistepKarras++",
"EulerDiscrete",
"EulerAncestralDiscrete",
"DEISMultistep",
"KDPM2AncestralDiscrete",
"DPMSolverSinglestep",
"DDPM",
"HeunDiscrete",
]
scheduler_list = scheduler_list_cpu_only + [
"SharkEulerDiscrete",
@@ -50,6 +58,7 @@ predefined_models = [
"Linaqruf/anything-v3.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"xzuyn/PhotoMerge",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
@@ -58,6 +67,7 @@ predefined_models = [
predefined_paint_models = [
"runwayml/stable-diffusion-inpainting",
"stabilityai/stable-diffusion-2-inpainting",
"xzuyn/PhotoMerge-inpainting",
]
predefined_upscaler_models = [
"stabilityai/stable-diffusion-x4-upscaler",
@@ -79,7 +89,8 @@ def create_custom_models_folders():
else:
if not os.path.isdir(args.ckpt_dir):
sys.exit(
f"Invalid --ckpt_dir argument, {args.ckpt_dir} folder does not exists."
f"Invalid --ckpt_dir argument, "
f"{args.ckpt_dir} folder does not exist."
)
for root in dir:
get_custom_model_path(root).mkdir(parents=True, exist_ok=True)
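The scheduler names added above encode variants in their suffixes (`Karras`, `++`). A hypothetical sketch of how a UI label like `"DPMSolverMultistepKarras++"` could be decomposed into a base scheduler name plus diffusers-style options (`use_karras_sigmas`, `algorithm_type`) — the actual SHARK dispatch code is not part of this diff:

```python
def parse_scheduler_name(name: str) -> tuple[str, dict]:
    """Split a UI scheduler label into (base_name, options)."""
    options = {}
    if name.endswith("++"):
        name = name[:-2]
        # diffusers' DPMSolver schedulers expose this as algorithm_type.
        options["algorithm_type"] = "dpmsolver++"
    if name.endswith("Karras"):
        name = name[: -len("Karras")]
        options["use_karras_sigmas"] = True
    return name, options
```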


@@ -8,6 +8,9 @@ from .format import compact, humanize
def displayable_metadata(image_filename: str) -> dict:
if not os.path.isfile(image_filename):
return {"source": "missing", "parameters": {}}
pil_image = Image.open(image_filename)
# we have PNG generation parameters (preferred, as it's what the txt2img dropzone reads,


@@ -91,6 +91,18 @@ def compact(metadata: dict) -> dict:
else:
result["Hires resize"] = f"{hires_y}x{hires_x}"
# remove VAE if it exists and is empty
if (result.keys() & {"VAE"}) and (
not result["VAE"] or result["VAE"] == "None"
):
result.pop("VAE")
# remove LoRA if it exists and is empty
if (result.keys() & {"LoRA"}) and (
not result["LoRA"] or result["LoRA"] == "None"
):
result.pop("LoRA")
return result
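The two nearly identical VAE/LoRA blocks above follow one pattern: drop a key when its value is missing, empty, or the literal string `"None"`. A generic sketch of that pattern (helper name hypothetical, not part of the diff):

```python
def drop_empty_keys(result: dict, keys=("VAE", "LoRA")) -> dict:
    # Remove each key whose value is falsy or the literal string "None".
    for key in keys:
        if key in result and (not result[key] or result[key] == "None"):
            result.pop(key)
    return result
```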


@@ -0,0 +1,88 @@
ARG IMAGE_NAME
FROM ${IMAGE_NAME}:12.2.0-runtime-ubuntu22.04 as base
ENV NV_CUDA_LIB_VERSION "12.2.0-1"
FROM base as base-amd64
ENV NV_CUDA_CUDART_DEV_VERSION 12.2.53-1
ENV NV_NVML_DEV_VERSION 12.2.81-1
ENV NV_LIBCUSPARSE_DEV_VERSION 12.1.1.53-1
ENV NV_LIBNPP_DEV_VERSION 12.1.1.14-1
ENV NV_LIBNPP_DEV_PACKAGE libnpp-dev-12-2=${NV_LIBNPP_DEV_VERSION}
ENV NV_LIBCUBLAS_DEV_VERSION 12.2.1.16-1
ENV NV_LIBCUBLAS_DEV_PACKAGE_NAME libcublas-dev-12-2
ENV NV_LIBCUBLAS_DEV_PACKAGE ${NV_LIBCUBLAS_DEV_PACKAGE_NAME}=${NV_LIBCUBLAS_DEV_VERSION}
ENV NV_CUDA_NSIGHT_COMPUTE_VERSION 12.2.0-1
ENV NV_CUDA_NSIGHT_COMPUTE_DEV_PACKAGE cuda-nsight-compute-12-2=${NV_CUDA_NSIGHT_COMPUTE_VERSION}
ENV NV_NVPROF_VERSION 12.2.60-1
ENV NV_NVPROF_DEV_PACKAGE cuda-nvprof-12-2=${NV_NVPROF_VERSION}
FROM base as base-arm64
ENV NV_CUDA_CUDART_DEV_VERSION 12.2.53-1
ENV NV_NVML_DEV_VERSION 12.2.81-1
ENV NV_LIBCUSPARSE_DEV_VERSION 12.1.1.53-1
ENV NV_LIBNPP_DEV_VERSION 12.1.1.14-1
ENV NV_LIBNPP_DEV_PACKAGE libnpp-dev-12-2=${NV_LIBNPP_DEV_VERSION}
ENV NV_LIBCUBLAS_DEV_PACKAGE_NAME libcublas-dev-12-2
ENV NV_LIBCUBLAS_DEV_VERSION 12.2.1.16-1
ENV NV_LIBCUBLAS_DEV_PACKAGE ${NV_LIBCUBLAS_DEV_PACKAGE_NAME}=${NV_LIBCUBLAS_DEV_VERSION}
ENV NV_CUDA_NSIGHT_COMPUTE_VERSION 12.2.0-1
ENV NV_CUDA_NSIGHT_COMPUTE_DEV_PACKAGE cuda-nsight-compute-12-2=${NV_CUDA_NSIGHT_COMPUTE_VERSION}
FROM base-${TARGETARCH}
ARG TARGETARCH
LABEL maintainer "SHARK<stdin@nod.com>"
# Register the ROCM package repository, and install rocm-dev package
ARG ROCM_VERSION=5.6
ARG AMDGPU_VERSION=5.6
ARG APT_PREF
RUN echo "$APT_PREF" > /etc/apt/preferences.d/rocm-pin-600
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends ca-certificates curl libnuma-dev gnupg \
&& curl -sL https://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - \
&& printf "deb [arch=amd64] https://repo.radeon.com/rocm/apt/$ROCM_VERSION/ jammy main" | tee /etc/apt/sources.list.d/rocm.list \
&& printf "deb [arch=amd64] https://repo.radeon.com/amdgpu/$AMDGPU_VERSION/ubuntu jammy main" | tee /etc/apt/sources.list.d/amdgpu.list \
&& apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
sudo \
libelf1 \
kmod \
file \
python3 \
python3-pip \
rocm-dev \
rocm-libs \
rocm-hip-libraries \
build-essential && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
RUN groupadd -g 109 render
RUN apt-get update && apt-get install -y --no-install-recommends \
cuda-cudart-dev-12-2=${NV_CUDA_CUDART_DEV_VERSION} \
cuda-command-line-tools-12-2=${NV_CUDA_LIB_VERSION} \
cuda-minimal-build-12-2=${NV_CUDA_LIB_VERSION} \
cuda-libraries-dev-12-2=${NV_CUDA_LIB_VERSION} \
cuda-nvml-dev-12-2=${NV_NVML_DEV_VERSION} \
${NV_NVPROF_DEV_PACKAGE} \
${NV_LIBNPP_DEV_PACKAGE} \
libcusparse-dev-12-2=${NV_LIBCUSPARSE_DEV_VERSION} \
${NV_LIBCUBLAS_DEV_PACKAGE} \
${NV_CUDA_NSIGHT_COMPUTE_DEV_PACKAGE} \
&& rm -rf /var/lib/apt/lists/*
RUN apt-get install -y rocm-hip-libraries
# Keep apt from auto upgrading the cublas and nccl packages. See https://gitlab.com/nvidia/container-images/cuda/-/issues/88
RUN apt-mark hold ${NV_LIBCUBLAS_DEV_PACKAGE_NAME}
ENV LIBRARY_PATH /usr/local/cuda/lib64/stubs


@@ -0,0 +1,41 @@
On your host, install your Nvidia or AMD GPU drivers.
**HOST Setup**
*Ubuntu 23.04 Nvidia*
```
sudo ubuntu-drivers install
```
Install [docker](https://docs.docker.com/engine/install/ubuntu/) and follow the post-install steps to run it as a non-root [user](https://docs.docker.com/engine/install/linux-postinstall/).
Install the Nvidia [Container Toolkit and register it](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html). On Ubuntu 23.04 systems follow [this workaround](https://github.com/NVIDIA/nvidia-container-toolkit/issues/72#issuecomment-1584574298).
Build the docker image with:
```
docker build . -f Dockerfile-ubuntu-22.04 -t shark/dev-22.04:5.6 --build-arg=ROCM_VERSION=5.6 --build-arg=AMDGPU_VERSION=5.6 --build-arg=APT_PREF="Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600" --build-arg=IMAGE_NAME=nvidia/cuda --build-arg=TARGETARCH=amd64
```
Run with:
*CPU*
```
docker run -it docker.io/shark/dev-22.04:5.6
```
*Nvidia GPU*
```
docker run --rm -it --gpus all docker.io/shark/dev-22.04:5.6
```
*AMD GPUs*
```
docker run --device /dev/kfd --device /dev/dri docker.io/shark/dev-22.04:5.6
```
More AMD instructions are [here](https://docs.amd.com/en/latest/deploy/docker.html)


@@ -24,7 +24,9 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=100)
width=150,
height=100,
)
datasets, images, ds_w_prompts = get_datasets(args.gs_url)
prompt_data = dict()
@@ -37,7 +39,7 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
with gr.Row(elem_id="ui_body"):
# TODO: add ability to search image by typing
with gr.Column(scale=1, min_width=600):
image = gr.Image(type="filepath").style(height=512)
image = gr.Image(type="filepath", height=512)
with gr.Column(scale=1, min_width=600):
prompts = gr.Dropdown(


@@ -14,4 +14,4 @@ build-backend = "setuptools.build_meta"
[tool.black]
line-length = 79
include = '\.pyi?$'
exclude = "apps/language_models/scripts/vicuna.py"


@@ -17,9 +17,11 @@ parameterized
# Add transformers, diffusers and scipy since it most commonly used
transformers
diffusers
#accelerate is now required for diffusers import from ckpt.
accelerate
scipy
ftfy
gradio==3.34.0
gradio
altair
omegaconf
safetensors
@@ -29,7 +31,13 @@ pytorch_lightning # for runwayml models
tk
pywebview
sentencepiece
py-cpuinfo
tiktoken # for codegen
joblib # for langchain
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile
pyinstaller
# vicuna quantization
brevitas @ git+https://github.com/Xilinx/brevitas.git@dev


@@ -39,7 +39,7 @@ setup(
install_requires=[
"numpy",
"PyYAML",
"torch-mlir==20230620.875",
"torch-mlir",
]
+ backend_deps,
)


@@ -89,7 +89,7 @@ else {python -m venv .\shark.venv\}
python -m pip install --upgrade pip
pip install wheel
pip install -r requirements.txt
pip install --pre torch-mlir==20230620.875 torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --pre torch-mlir torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
Write-Host "Building SHARK..."
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html


@@ -88,7 +88,7 @@ if [ "$torch_mlir_bin" = true ]; then
echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
$PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
else
$PYTHON -m pip install --pre torch-mlir==20230620.875 -f https://llvm.github.io/torch-mlir/package-index/
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
if [ $? -eq 0 ];then
echo "Successfully Installed torch-mlir"
else


@@ -15,7 +15,7 @@
import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx
from torch.nn.utils import _stateless
from torch.nn.utils import stateless
from torch import fx
import tempfile


@@ -1,5 +1,5 @@
import torch
from torch.nn.utils import _stateless
from torch.nn.utils import stateless
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from shark.shark_trainer import SharkTrainer
@@ -33,7 +33,7 @@ inp = (torch.randint(2, (1, 128)),)
def forward(params, buffers, args):
params_and_buffers = {**params, **buffers}
_stateless.functional_call(
stateless.functional_call(
mod, params_and_buffers, args, {}
).sum().backward()
optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
@@ -44,5 +44,5 @@ def forward(params, buffers, args):
shark_module = SharkTrainer(mod, inp)
shark_module.compile(forward)
print(shark_module.train())
shark_module.train(num_iters=2)
print("training done")


@@ -43,8 +43,14 @@ def get_iree_device_args(device, extra_args=[]):
data_tiling_flag = ["--iree-flow-enable-data-tiling"]
u_kernel_flag = ["--iree-llvmcpu-enable-microkernels"]
stack_size_flag = ["--iree-llvmcpu-stack-allocation-limit=256000"]
return get_iree_cpu_args() + data_tiling_flag + u_kernel_flag
return (
get_iree_cpu_args()
+ data_tiling_flag
+ u_kernel_flag
+ stack_size_flag
)
if device_uri[0] == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
@@ -58,9 +64,7 @@ def get_iree_device_args(device, extra_args=[]):
if device_uri[0] == "metal":
from shark.iree_utils.metal_utils import get_iree_metal_args
return get_iree_metal_args(
device_num=device_num, extra_args=extra_args
)
return get_iree_metal_args(extra_args=extra_args)
if device_uri[0] == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args


@@ -57,15 +57,7 @@ def get_metal_target_triple(device_name):
Returns:
str or None: target triple or None if no match found for given name
"""
# Apple Targets
if all(x in device_name for x in ("Apple", "M1")):
triple = "m1-moltenvk-macos"
elif all(x in device_name for x in ("Apple", "M2")):
triple = "m1-moltenvk-macos"
else:
triple = None
return triple
return "macos"
def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
@@ -81,7 +73,7 @@ def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
triple = get_metal_target_triple(metal_device)
if triple is not None:
print(
f"Found metal device {metal_device}. Using metal target triple {triple}"
f"Found metal device {metal_device}. Using metal target platform {triple}"
)
return f"-iree-metal-target-platform={triple}"
print(
@@ -105,12 +97,12 @@ def get_iree_metal_args(device_num=0, extra_args=[]):
break
if metal_triple_flag is None:
metal_triple_flag = get_metal_triple_flag(
device_num=device_num, extra_args=extra_args
)
metal_triple_flag = get_metal_triple_flag(extra_args=extra_args)
if metal_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(metal_triple_flag)
vulkan_target_env = get_vulkan_target_env_flag(
"-iree-vulkan-target-triple=m1-moltenvk-macos"
)
res_metal_flag.append(vulkan_target_env)
return res_metal_flag


@@ -353,7 +353,7 @@ def add_upcast(fx_g):
fx_g.graph.lint()
def transform_fx(fx_g):
def transform_fx(fx_g, quantized=False):
import torch
kwargs_dict = {
@@ -366,6 +366,19 @@ def transform_fx(fx_g):
}
for node in fx_g.graph.nodes:
if node.op == "call_function":
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
if quantized:
continue
if node.target in [
torch.ops.aten.arange,
torch.ops.aten.empty,
@@ -427,17 +440,6 @@ def transform_fx(fx_g):
new_node.args = (node,)
new_node.kwargs = {"dtype": torch.float16}
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
# Required for cuda debugging.
# for node in fx_g.graph.nodes:
# if node.op == "call_function":
@@ -486,6 +488,7 @@ def flatten_training_input(inputs):
return tuple(flattened_input)
# TODO: get rid of is_f16 by using precision
# Applies fx conversion to the model and imports the mlir.
def import_with_fx(
model,
@@ -500,10 +503,28 @@ def import_with_fx(
mlir_type="linalg",
is_dynamic=False,
tracing_required=False,
precision="fp32",
):
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
from brevitas_examples.llm.llm_quant.export import (
block_quant_layer_level_manager,
)
from brevitas_examples.llm.llm_quant.export import (
brevitas_layer_export_mode,
)
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
LinearWeightBlockQuantHandlerFwd,
)
from brevitas_examples.llm.llm_quant.export import replace_call_fn_target
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
matmul_rhs_group_quant_placeholder,
)
from brevitas.backport.fx.experimental.proxy_tensor import (
make_fx as brevitas_make_fx,
)
golden_values = None
if debug:
@@ -511,26 +532,97 @@ def import_with_fx(
golden_values = model(*inputs)
except:
golden_values = None
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
# TODO: Control the decompositions.
fx_g = make_fx(
model,
decomposition_table=get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
torch.ops.aten.native_layer_norm,
torch.ops.aten.masked_fill.Tensor,
torch.ops.aten.masked_fill.Scalar,
]
),
)(*inputs)
decomps_list = [
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
torch.ops.aten.native_layer_norm,
torch.ops.aten.masked_fill.Tensor,
torch.ops.aten.masked_fill.Scalar,
]
if precision in ["int4", "int8"]:
export_context_manager = brevitas_layer_export_mode
export_class = block_quant_layer_level_manager(
export_handlers=[LinearWeightBlockQuantHandlerFwd]
)
with export_context_manager(model, export_class):
fx_g = brevitas_make_fx(
model,
decomposition_table=get_decompositions(decomps_list),
)(*inputs)
transform_fx(fx_g, quantized=True)
replace_call_fn_target(
fx_g,
src=matmul_rhs_group_quant_placeholder,
target=torch.ops.brevitas.matmul_rhs_group_quant,
)
fx_g.recompile()
removed_none_indexes = _remove_nones(fx_g)
was_unwrapped = _unwrap_single_tuple_return(fx_g)
else:
fx_g = make_fx(
model,
decomposition_table=get_decompositions(decomps_list),
)(*inputs)
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
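Stripped of the fx plumbing, the `_remove_nones` helper introduced above amounts to: delete the `None` entries of the output tuple, remembering which positions were dropped, scanning from the end so earlier indexes stay valid while popping. A plain-Python sketch of that core logic (hypothetical standalone helper, for illustration):

```python
def remove_nones(outputs: list) -> tuple[list, list[int]]:
    # Walk backwards so popping an element never shifts the
    # indexes we have not visited yet.
    removed_indexes = []
    for i in range(len(outputs) - 1, -1, -1):
        if outputs[i] is None:
            removed_indexes.append(i)
            outputs.pop(i)
    removed_indexes.sort()
    return outputs, removed_indexes
```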


@@ -25,18 +25,18 @@ google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"https://github.com/nod-ai/SHARK/issues/344","macos"
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,True,True,"https://github.com/nod-ai/SHARK/issues/388, https://github.com/nod-ai/SHARK/issues/1487","macos"
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/388, https://github.com/nod-ai/SHARK/issues/1487","macos"
nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,True,"https://github.com/nod-ai/SHARK/issues/343,https://github.com/nod-ai/SHARK/issues/1487","macos"
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,False,"","macos"
resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,False,True,"",""
squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
efficientnet-v2-s,stablehlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
efficientnet_b0,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
efficientnet_b7,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
gpt2,stablehlo,tf,1e-2,1e-3,default,None,True,False,False,"","macos"
resnet50,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"https://github.com/nod-ai/SHARK/issues/344","macos"
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,True,True,True,"https://github.com/nod-ai/SHARK/issues/388, https://github.com/nod-ai/SHARK/issues/1487","macos"
nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,True,"https://github.com/nod-ai/SHARK/issues/343,https://github.com/nod-ai/SHARK/issues/1487","macos"
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,True,False,False,"","macos"
resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,False,"","macos"
resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,False,True,"",""
squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,True,False,False,"","macos"
efficientnet-v2-s,stablehlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
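The rows above are entries from SHARK's model test configuration CSV. As a minimal sketch of how such a table can be queried (the column meanings are my assumption, since no header row is visible here), the rightmost quoted field appears to name a platform on which the test is skipped:

```python
import csv
import io

# Two rows copied from the table above; the trailing field is assumed to mark
# a platform on which the model's test is excluded.
csv_text = (
    'efficientnet_b0,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""\n'
    'gpt2,stablehlo,tf,1e-2,1e-3,default,None,True,False,False,"","macos"\n'
)

rows = list(csv.reader(io.StringIO(csv_text)))
# Model names (first column) excluded on macOS (last column).
macos_excluded = [row[0] for row in rows if row[-1] == "macos"]
print(macos_excluded)  # → ['gpt2']
```

The `csv` module handles the quoted empty fields (`""`) correctly, so the last column compares cleanly against the bare string `macos`.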


@@ -0,0 +1,188 @@
import collections
import json
import time
import os

from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from transformers import AutoTokenizer, OPTForCausalLM

from shark_opt_wrapper import OPTForCausalLMModel

MODEL_NAME = "facebook/opt-1.3b"
OPT_MODELNAME = "opt-1.3b"
OPT_FS_NAME = "opt_1-3b"
MAX_SEQUENCE_LENGTH = 512
DEVICE = "cpu"

PROMPTS = [
"What is the meaning of life?",
"Tell me something you don't know.",
"What does Xilinx do?",
"What is the mass of earth?",
"What is a poem?",
"What is recursion?",
"Tell me a one line joke.",
"Who is Gilgamesh?",
"Tell me something about cryptocurrency.",
"How did it all begin?",
]

ModelWrapper = collections.namedtuple("ModelWrapper", ["model", "tokenizer"])


def create_vmfb_module(model_name, tokenizer, device):
opt_base_model = OPTForCausalLM.from_pretrained("facebook/" + model_name)
opt_base_model.eval()
opt_model = OPTForCausalLMModel(opt_base_model)
encoded_inputs = tokenizer(
"What is the meaning of life?",
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
return_tensors="pt",
)
inputs = (
encoded_inputs["input_ids"],
encoded_inputs["attention_mask"],
)
# np.save("model_inputs_0.npy", inputs[0])
# np.save("model_inputs_1.npy", inputs[1])
mlir_path = f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
if os.path.isfile(mlir_path):
with open(mlir_path, "r") as f:
model_mlir = f.read()
print(f"Loaded .mlir from {mlir_path}")
else:
(model_mlir, func_name) = import_with_fx(
model=opt_model,
inputs=inputs,
is_f16=False,
model_name=OPT_FS_NAME,
return_str=True,
)
with open(mlir_path, "w") as f:
f.write(model_mlir)
print(f"Saved mlir at {mlir_path}")
shark_module = SharkInference(
model_mlir,
device=device,
mlir_dialect="tm_tensor",
is_benchmark=False,
)
    vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}_tiled_ukernels"
shark_module.save_module(module_name=vmfb_name)
vmfb_path = vmfb_name + ".vmfb"
    return vmfb_path


def load_shark_model() -> ModelWrapper:
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{DEVICE}_tiled_ukernels.vmfb"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
if not os.path.isfile(vmfb_name):
print(f"vmfb not found. compiling and saving to {vmfb_name}")
create_vmfb_module(OPT_MODELNAME, tokenizer, DEVICE)
shark_module = SharkInference(mlir_module=None, device="cpu-task")
shark_module.load_module(vmfb_name)
    return ModelWrapper(model=shark_module, tokenizer=tokenizer)


def run_shark_model(model_wrapper: ModelWrapper, tokens):
    # Generate logits output of OPT model.
    return model_wrapper.model("forward", tokens)


def run_shark():
    model_wrapper = load_shark_model()
    prompt = "What is the meaning of life?"
    # The compiled module expects (input_ids, attention_mask) tensors,
    # not a raw prompt string, so tokenize first.
    encoded = model_wrapper.tokenizer(
        prompt, padding="max_length", truncation=True,
        max_length=MAX_SEQUENCE_LENGTH, return_tensors="pt",
    )
    logits = run_shark_model(
        model_wrapper, (encoded["input_ids"], encoded["attention_mask"])
    )
    # Print output logits to validate vs. pytorch + base transformers
    print(logits[0])
def load_huggingface_model() -> ModelWrapper:
return ModelWrapper(
model=OPTForCausalLM.from_pretrained(MODEL_NAME),
tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME),
    )


def run_huggingface_model(model_wrapper: ModelWrapper, tokens):
    return model_wrapper.model.forward(
        tokens.input_ids, tokens.attention_mask, return_dict=False
    )


def run_huggingface():
    model_wrapper = load_huggingface_model()
    prompt = "What is the meaning of life?"
    # Tokenize the prompt; the model expects input_ids/attention_mask.
    tokens = model_wrapper.tokenizer(prompt, return_tensors="pt")
    logits = run_huggingface_model(model_wrapper, tokens)
    print(logits[0])
def save_json(data, filename):
with open(filename, "w") as file:
        json.dump(data, file)


def collect_huggingface_logits():
t0 = time.time()
model_wrapper = load_huggingface_model()
print("--- Took {} seconds to load Huggingface.".format(time.time() - t0))
results = []
tokenized_prompts = []
for prompt in PROMPTS:
tokens = model_wrapper.tokenizer(
prompt,
padding="max_length",
max_length=MAX_SEQUENCE_LENGTH,
truncation=True,
return_tensors="pt",
)
tokenized_prompts.append(tokens)
t0 = time.time()
for idx, tokens in enumerate(tokenized_prompts):
print("prompt: {}".format(PROMPTS[idx]))
logits = run_huggingface_model(model_wrapper, tokens)
results.append([PROMPTS[idx], logits[0].tolist()])
print("--- Took {} seconds to run Huggingface.".format(time.time() - t0))
    save_json(results, "/tmp/huggingface.json")


def collect_shark_logits():
t0 = time.time()
model_wrapper = load_shark_model()
print("--- Took {} seconds to load Shark.".format(time.time() - t0))
results = []
tokenized_prompts = []
for prompt in PROMPTS:
tokens = model_wrapper.tokenizer(
prompt,
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
return_tensors="pt",
)
inputs = (
tokens["input_ids"],
tokens["attention_mask"],
)
tokenized_prompts.append(inputs)
t0 = time.time()
for idx, tokens in enumerate(tokenized_prompts):
print("prompt: {}".format(PROMPTS[idx]))
logits = run_shark_model(model_wrapper, tokens)
lst = [e.tolist() for e in logits]
results.append([PROMPTS[idx], lst])
print("--- Took {} seconds to run Shark.".format(time.time() - t0))
    save_json(results, "/tmp/shark.json")


if __name__ == "__main__":
    collect_shark_logits()
    collect_huggingface_logits()
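Since both collectors write their logits to JSON, a small follow-up script can check how closely the SHARK logits track the Hugging Face ones. This is a sketch, not part of the file above: it assumes both dumps follow the `[prompt, logits]` layout produced by `save_json`, and that the first SHARK output tensor has the same shape as the Hugging Face logits.

```python
import json

import numpy as np


def compare_logit_dumps(shark_path="/tmp/shark.json",
                        hf_path="/tmp/huggingface.json",
                        rtol=1e-2, atol=1e-3):
    # Each file holds a list of [prompt, logits] pairs, as written by save_json.
    with open(shark_path) as f:
        shark_results = json.load(f)
    with open(hf_path) as f:
        hf_results = json.load(f)
    report = {}
    for (prompt, shark_logits), (_, hf_logits) in zip(shark_results, hf_results):
        # SHARK stored a list of output tensors; compare the first one
        # against the Hugging Face logits (shapes assumed to match).
        a = np.asarray(shark_logits[0], dtype=np.float32)
        b = np.asarray(hf_logits, dtype=np.float32)
        report[prompt] = bool(np.allclose(a, b, rtol=rtol, atol=atol))
    return report
```

The tolerances default to the `1e-2`/`1e-3` pair used throughout the test-config table above, but both are parameters so stricter checks are easy to run.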