Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-07 21:13:54 -05:00)
Remove SHARK 1.0 implementations (#2042)
Any reimplementation of these features should be tracked in https://github.com/nod-ai/SHARK/issues/1931. These implementations are preserved in the SHARK-1.0 branch: https://github.com/nod-ai/SHARK/tree/SHARK-1.0
@@ -1,16 +0,0 @@

## CodeGen Setup using SHARK-server

### Setup Server

- Clone SHARK and set up the venv (a sketch of these steps follows below).
- Host the server with `python apps/stable_diffusion/web/index.py --api --server_port=<PORT>`.
- The default server address is `http://0.0.0.0:8080`.
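The commands below are a minimal sketch of that server setup, assuming a plain `venv` plus a `requirements.txt` install; the actual checkout may ship its own setup script, so prefer the repository's instructions where they differ.

```shell
# Assumed setup flow -- adjust to the repo's current instructions.
git clone https://github.com/nod-ai/SHARK.git
cd SHARK
python -m venv shark.venv && source shark.venv/bin/activate
pip install -r requirements.txt

# Host the server (command taken from the steps above).
python apps/stable_diffusion/web/index.py --api --server_port=8080
```

With the server up, any fauxpilot-compatible client can point at `http://<IP>:<PORT>`.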
### Setup Client

1. fauxpilot-vscode (VSCode extension):
   - Code for the extension is available [here](https://github.com/Venthe/vscode-fauxpilot).
   - Prerequisites: [`nodejs` and `npm`](https://nodejs.org/en/download) are needed to compile and run the extension.
   - Compile and run the extension in VSCode (press F5); this opens a new VSCode window with the extension running.
   - Open the VSCode settings, search for "fauxpilot", and set `server : http://<IP>:<PORT>`, `Model : codegen`, and `Max Lines : 30`.

2. Other clients (REST API via curl, OpenAI Python bindings) work as shown [here](https://github.com/fauxpilot/fauxpilot/blob/main/documentation/client.md); a minimal curl sketch follows below.
   - Using the GitHub Copilot VSCode extension with SHARK-server needs more work to be functional.
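As a rough illustration of the non-VSCode route in item 2, the request below assumes SHARK-server mirrors fauxpilot's OpenAI-compatible completions route (`/v1/engines/codegen/completions`); the exact path and payload fields should be checked against the linked client documentation.

```shell
# Hypothetical endpoint path -- verify against the fauxpilot client docs.
curl -s http://<IP>:<PORT>/v1/engines/codegen/completions \
  -H "Content-Type: application/json" \
  -d '{"prompt": "def fibonacci(n):", "max_tokens": 32, "temperature": 0.1}'
```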
@@ -1,18 +0,0 @@

# Langchain

## How to run the model

1.) Install all the dependencies by running:

```shell
pip install -r apps/language_models/langchain/langchain_requirements.txt
sudo apt-get install -y libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
```

2.) Create a folder named `user_path` in the `apps/language_models/langchain/` directory.

Now you are ready to use the model.

3.) To run the model in CLI mode, run the following command (a combined sketch of steps 2 and 3 follows after this block):

```shell
python apps/language_models/langchain/gen.py --cli=True
```
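The commands below are a combined sketch of steps 2 and 3: create `user_path`, drop a document into it, and start the CLI. The sample file name is hypothetical, and the assumption that `user_path` holds the documents to be queried should be verified against the script's options.

```shell
mkdir -p apps/language_models/langchain/user_path
# hypothetical sample document to query
cp ~/docs/notes.pdf apps/language_models/langchain/user_path/
python apps/language_models/langchain/gen.py --cli=True
```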
@@ -1,186 +0,0 @@
import copy
import torch

from evaluate_params import eval_func_param_names
from gen import Langchain
from prompter import non_hf_types
from utils import clear_torch_cache, NullContext, get_kwargs


def run_cli(  # for local function:
    base_model=None,
    lora_weights=None,
    inference_server=None,
    debug=None,
    chat_context=None,
    examples=None,
    memory_restriction_level=None,
    # for get_model:
    score_model=None,
    load_8bit=None,
    load_4bit=None,
    load_half=None,
    load_gptq=None,
    use_safetensors=None,
    infer_devices=None,
    tokenizer_base_model=None,
    gpu_id=None,
    local_files_only=None,
    resume_download=None,
    use_auth_token=None,
    trust_remote_code=None,
    offload_folder=None,
    compile_model=None,
    # for some evaluate args
    stream_output=None,
    prompt_type=None,
    prompt_dict=None,
    temperature=None,
    top_p=None,
    top_k=None,
    num_beams=None,
    max_new_tokens=None,
    min_new_tokens=None,
    early_stopping=None,
    max_time=None,
    repetition_penalty=None,
    num_return_sequences=None,
    do_sample=None,
    chat=None,
    langchain_mode=None,
    langchain_action=None,
    document_choice=None,
    top_k_docs=None,
    chunk=None,
    chunk_size=None,
    # for evaluate kwargs
    src_lang=None,
    tgt_lang=None,
    concurrency_count=None,
    save_dir=None,
    sanitize_bot_response=None,
    model_state0=None,
    max_max_new_tokens=None,
    is_public=None,
    max_max_time=None,
    raise_generate_gpu_exceptions=None,
    load_db_if_exists=None,
    dbs=None,
    user_path=None,
    detect_user_path_changes_every_query=None,
    use_openai_embedding=None,
    use_openai_model=None,
    hf_embedding_model=None,
    db_type=None,
    n_jobs=None,
    first_para=None,
    text_limit=None,
    verbose=None,
    cli=None,
    reverse_docs=None,
    use_cache=None,
    auto_reduce_chunks=None,
    max_chunks=None,
    model_lock=None,
    force_langchain_evaluate=None,
    model_state_none=None,
    # unique to this function:
    cli_loop=None,
):
    Langchain.check_locals(**locals())

    score_model = ""  # FIXME: for now, so the user doesn't have to pass it
    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
    device = "cpu" if n_gpus == 0 else "cuda"
    context_class = NullContext if n_gpus > 1 or n_gpus == 0 else torch.device

    with context_class(device):
        from functools import partial

        # get score model
        smodel, stokenizer, sdevice = Langchain.get_score_model(
            reward_type=True,
            **get_kwargs(
                Langchain.get_score_model,
                exclude_names=["reward_type"],
                **locals()
            )
        )

        model, tokenizer, device = Langchain.get_model(
            reward_type=False,
            **get_kwargs(
                Langchain.get_model, exclude_names=["reward_type"], **locals()
            )
        )
        model_dict = dict(
            base_model=base_model,
            tokenizer_base_model=tokenizer_base_model,
            lora_weights=lora_weights,
            inference_server=inference_server,
            prompt_type=prompt_type,
            prompt_dict=prompt_dict,
        )
        model_state = dict(model=model, tokenizer=tokenizer, device=device)
        model_state.update(model_dict)
        my_db_state = [None]
        fun = partial(
            Langchain.evaluate,
            model_state,
            my_db_state,
            **get_kwargs(
                Langchain.evaluate,
                exclude_names=["model_state", "my_db_state"]
                + eval_func_param_names,
                **locals()
            )
        )

        example1 = examples[-1]  # pick reference example
        all_generations = []
        while True:
            clear_torch_cache()
            instruction = input("\nEnter an instruction: ")
            if instruction == "exit":
                break

            eval_vars = copy.deepcopy(example1)
            eval_vars[eval_func_param_names.index("instruction")] = eval_vars[
                eval_func_param_names.index("instruction_nochat")
            ] = instruction
            eval_vars[eval_func_param_names.index("iinput")] = eval_vars[
                eval_func_param_names.index("iinput_nochat")
            ] = ""  # no input yet
            eval_vars[
                eval_func_param_names.index("context")
            ] = ""  # no context yet

            # grab other parameters, like langchain_mode
            for k in eval_func_param_names:
                if k in locals():
                    eval_vars[eval_func_param_names.index(k)] = locals()[k]

            gener = fun(*tuple(eval_vars))
            outr = ""
            res_old = ""
            for gen_output in gener:
                res = gen_output["response"]
                extra = gen_output["sources"]
                if base_model not in non_hf_types or base_model in ["llama"]:
                    if not stream_output:
                        print(res)
                    else:
                        # gradio-style streaming yields the full output on each
                        # generation, so print only the newly added characters
                        diff = res[len(res_old) :]
                        print(diff, end="", flush=True)
                        res_old = res
                    outr = res  # don't accumulate
                else:
                    outr += res  # just is one thing
                if extra:
                    # show sources at the end, after the model has streamed
                    # the rest of the response to stdout
                    print(extra, flush=True)
            all_generations.append(outr + "\n")
            if not cli_loop:
                break
        return all_generations
File diff suppressed because it is too large
@@ -1,103 +0,0 @@
from enum import Enum


class PromptType(Enum):
    custom = -1
    plain = 0
    instruct = 1
    quality = 2
    human_bot = 3
    dai_faq = 4
    summarize = 5
    simple_instruct = 6
    instruct_vicuna = 7
    instruct_with_end = 8
    human_bot_orig = 9
    prompt_answer = 10
    open_assistant = 11
    wizard_lm = 12
    wizard_mega = 13
    instruct_vicuna2 = 14
    instruct_vicuna3 = 15
    wizard2 = 16
    wizard3 = 17
    instruct_simple = 18
    wizard_vicuna = 19
    openai = 20
    openai_chat = 21
    gptj = 22
    prompt_answer_openllama = 23
    vicuna11 = 24
    mptinstruct = 25
    mptchat = 26
    falcon = 27


class DocumentChoices(Enum):
    All_Relevant = 0
    All_Relevant_Only_Sources = 1
    Only_All_Sources = 2
    Just_LLM = 3


non_query_commands = [
    DocumentChoices.All_Relevant_Only_Sources.name,
    DocumentChoices.Only_All_Sources.name,
]


class LangChainMode(Enum):
    """LangChain mode"""

    DISABLED = "Disabled"
    CHAT_LLM = "ChatLLM"
    LLM = "LLM"
    ALL = "All"
    WIKI = "wiki"
    WIKI_FULL = "wiki_full"
    USER_DATA = "UserData"
    MY_DATA = "MyData"
    GITHUB_H2OGPT = "github h2oGPT"
    H2O_DAI_DOCS = "DriverlessAI docs"


class LangChainAction(Enum):
    """LangChain action"""

    QUERY = "Query"
    # WIP:
    # SUMMARIZE_MAP = "Summarize_map_reduce"
    SUMMARIZE_MAP = "Summarize"
    SUMMARIZE_ALL = "Summarize_all"
    SUMMARIZE_REFINE = "Summarize_refine"


no_server_str = no_lora_str = no_model_str = "[None/Remove]"

# from site-packages/langchain/llms/openai.py
# but needed since ChatOpenAI doesn't have this information
model_token_mapping = {
    "gpt-4": 8192,
    "gpt-4-0314": 8192,
    "gpt-4-32k": 32768,
    "gpt-4-32k-0314": 32768,
    "gpt-3.5-turbo": 4096,
    "gpt-3.5-turbo-16k": 16 * 1024,
    "gpt-3.5-turbo-0301": 4096,
    "text-ada-001": 2049,
    "ada": 2049,
    "text-babbage-001": 2040,
    "babbage": 2049,
    "text-curie-001": 2049,
    "curie": 2049,
    "davinci": 2049,
    "text-davinci-003": 4097,
    "text-davinci-002": 4097,
    "code-davinci-002": 8001,
    "code-davinci-001": 8001,
    "code-cushman-002": 2048,
    "code-cushman-001": 2048,
}

source_prefix = "Sources [Score | Link]:"
source_postfix = "End Sources<p>"
@@ -1,53 +0,0 @@
no_default_param_names = [
    "instruction",
    "iinput",
    "context",
    "instruction_nochat",
    "iinput_nochat",
]

gen_hyper = [
    "temperature",
    "top_p",
    "top_k",
    "num_beams",
    "max_new_tokens",
    "min_new_tokens",
    "early_stopping",
    "max_time",
    "repetition_penalty",
    "num_return_sequences",
    "do_sample",
]

eval_func_param_names = (
    [
        "instruction",
        "iinput",
        "context",
        "stream_output",
        "prompt_type",
        "prompt_dict",
    ]
    + gen_hyper
    + [
        "chat",
        "instruction_nochat",
        "iinput_nochat",
        "langchain_mode",
        "langchain_action",
        "top_k_docs",
        "chunk",
        "chunk_size",
        "document_choice",
    ]
)

# form evaluate defaults for submit_nochat_api
eval_func_param_names_defaults = eval_func_param_names.copy()
for k in no_default_param_names:
    if k in eval_func_param_names_defaults:
        eval_func_param_names_defaults.remove(k)


eval_extra_columns = ["prompt", "response", "score"]
@@ -1,846 +0,0 @@
|
||||
from __future__ import annotations
|
||||
from typing import (
|
||||
Any,
|
||||
Mapping,
|
||||
Optional,
|
||||
Dict,
|
||||
List,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
Protocol,
|
||||
)
|
||||
import inspect
|
||||
import json
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
from abc import ABC, abstractmethod
|
||||
import langchain
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.question_answering import stuff_prompt
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManager,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
|
||||
from langchain.input import get_colored_text
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import LLMResult, PromptValue
|
||||
from pydantic import Extra, Field, root_validator, validator
|
||||
|
||||
|
||||
def _get_verbosity() -> bool:
|
||||
return langchain.verbose
|
||||
|
||||
|
||||
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
|
||||
"""Format a document into a string based on a prompt template."""
|
||||
base_info = {"page_content": doc.page_content}
|
||||
base_info.update(doc.metadata)
|
||||
missing_metadata = set(prompt.input_variables).difference(base_info)
|
||||
if len(missing_metadata) > 0:
|
||||
required_metadata = [
|
||||
iv for iv in prompt.input_variables if iv != "page_content"
|
||||
]
|
||||
raise ValueError(
|
||||
f"Document prompt requires documents to have metadata variables: "
|
||||
f"{required_metadata}. Received document with missing metadata: "
|
||||
f"{list(missing_metadata)}."
|
||||
)
|
||||
document_info = {k: base_info[k] for k in prompt.input_variables}
|
||||
return prompt.format(**document_info)
|
||||
|
||||
|
||||
class Chain(Serializable, ABC):
|
||||
"""Base interface that all chains should implement."""
|
||||
|
||||
memory: Optional[BaseMemory] = None
|
||||
callbacks: Callbacks = Field(default=None, exclude=True)
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(
|
||||
default=None, exclude=True
|
||||
)
|
||||
verbose: bool = Field(
|
||||
default_factory=_get_verbosity
|
||||
) # Whether to print the response text
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
raise NotImplementedError("Saving not supported for this chain type.")
|
||||
|
||||
@root_validator()
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
"""Raise deprecation warning if callback_manager is used."""
|
||||
if values.get("callback_manager") is not None:
|
||||
warnings.warn(
|
||||
"callback_manager is deprecated. Please use callbacks instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
values["callbacks"] = values.pop("callback_manager", None)
|
||||
return values
|
||||
|
||||
@validator("verbose", pre=True, always=True)
|
||||
def set_verbose(cls, verbose: Optional[bool]) -> bool:
|
||||
"""If verbose is None, set it.
|
||||
|
||||
This allows users to pass in None as verbose to access the global setting.
|
||||
"""
|
||||
if verbose is None:
|
||||
return _get_verbosity()
|
||||
else:
|
||||
return verbose
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Input keys this chain expects."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Output keys this chain expects."""
|
||||
|
||||
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||
"""Check that all inputs are present."""
|
||||
missing_keys = set(self.input_keys).difference(inputs)
|
||||
if missing_keys:
|
||||
raise ValueError(f"Missing some input keys: {missing_keys}")
|
||||
|
||||
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
|
||||
missing_keys = set(self.output_keys).difference(outputs)
|
||||
if missing_keys:
|
||||
raise ValueError(f"Missing some output keys: {missing_keys}")
|
||||
|
||||
@abstractmethod
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and return the output."""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
return_only_outputs: bool = False,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and add to output if desired.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of inputs, or single input if chain expects
|
||||
only one param.
|
||||
return_only_outputs: boolean for whether to return only outputs in the
|
||||
response. If True, only new keys generated by this chain will be
|
||||
returned. If False, both input keys and new keys generated by this
|
||||
chain will be returned. Defaults to False.
|
||||
callbacks: Callbacks to use for this chain run. If not provided, will
|
||||
use the callbacks provided to the chain.
|
||||
include_run_info: Whether to include run info in the response. Defaults
|
||||
to False.
|
||||
"""
|
||||
input_docs = inputs["input_documents"]
|
||||
missing_keys = set(self.input_keys).difference(inputs)
|
||||
if missing_keys:
|
||||
raise ValueError(f"Missing some input keys: {missing_keys}")
|
||||
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose, tags, self.tags
|
||||
)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
inputs,
|
||||
)
|
||||
|
||||
if "is_first" in inputs.keys() and not inputs["is_first"]:
|
||||
run_manager_ = run_manager
|
||||
input_list = [inputs]
|
||||
stop = None
|
||||
prompts = []
|
||||
for inputs in input_list:
|
||||
selected_inputs = {
|
||||
k: inputs[k] for k in self.prompt.input_variables
|
||||
}
|
||||
prompt = self.prompt.format_prompt(**selected_inputs)
|
||||
_colored_text = get_colored_text(prompt.to_string(), "green")
|
||||
_text = "Prompt after formatting:\n" + _colored_text
|
||||
if run_manager_:
|
||||
run_manager_.on_text(_text, end="\n", verbose=self.verbose)
|
||||
if "stop" in inputs and inputs["stop"] != stop:
|
||||
raise ValueError(
|
||||
"If `stop` is present in any inputs, should be present in all."
|
||||
)
|
||||
prompts.append(prompt)
|
||||
|
||||
prompt_strings = [p.to_string() for p in prompts]
|
||||
prompts = prompt_strings
|
||||
callbacks = run_manager_.get_child() if run_manager_ else None
|
||||
tags = None
|
||||
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
# If string is passed in directly no errors will be raised but outputs will
|
||||
# not make sense.
|
||||
if not isinstance(prompts, list):
|
||||
raise ValueError(
|
||||
"Argument 'prompts' is expected to be of type List[str], received"
|
||||
f" argument of type {type(prompts)}."
|
||||
)
|
||||
params = self.llm.dict()
|
||||
params["stop"] = stop
|
||||
options = {"stop": stop}
|
||||
disregard_cache = self.llm.cache is not None and not self.llm.cache
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks,
|
||||
self.llm.callbacks,
|
||||
self.llm.verbose,
|
||||
tags,
|
||||
self.llm.tags,
|
||||
)
|
||||
if langchain.llm_cache is None or disregard_cache:
|
||||
# This happens when langchain.cache is None, but self.cache is True
|
||||
if self.llm.cache is not None and self.cache:
|
||||
raise ValueError(
|
||||
"Asked to cache, but no cache found at `langchain.cache`."
|
||||
)
|
||||
run_manager_ = callback_manager.on_llm_start(
|
||||
dumpd(self),
|
||||
prompts,
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
)
|
||||
|
||||
generations = []
|
||||
for prompt in prompts:
|
||||
inputs_ = prompt
|
||||
num_workers = None
|
||||
batch_size = None
|
||||
|
||||
if num_workers is None:
|
||||
if self.llm.pipeline._num_workers is None:
|
||||
num_workers = 0
|
||||
else:
|
||||
num_workers = self.llm.pipeline._num_workers
|
||||
if batch_size is None:
|
||||
if self.llm.pipeline._batch_size is None:
|
||||
batch_size = 1
|
||||
else:
|
||||
batch_size = self.llm.pipeline._batch_size
|
||||
|
||||
preprocess_params = {}
|
||||
generate_kwargs = {}
|
||||
preprocess_params.update(generate_kwargs)
|
||||
forward_params = generate_kwargs
|
||||
postprocess_params = {}
|
||||
# Fuse __init__ params and __call__ params without modifying the __init__ ones.
|
||||
preprocess_params = {
|
||||
**self.llm.pipeline._preprocess_params,
|
||||
**preprocess_params,
|
||||
}
|
||||
forward_params = {
|
||||
**self.llm.pipeline._forward_params,
|
||||
**forward_params,
|
||||
}
|
||||
postprocess_params = {
|
||||
**self.llm.pipeline._postprocess_params,
|
||||
**postprocess_params,
|
||||
}
|
||||
|
||||
self.llm.pipeline.call_count += 1
|
||||
if (
|
||||
self.llm.pipeline.call_count > 10
|
||||
and self.llm.pipeline.framework == "pt"
|
||||
and self.llm.pipeline.device.type == "cuda"
|
||||
):
|
||||
warnings.warn(
|
||||
"You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a"
|
||||
" dataset",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
model_inputs = self.llm.pipeline.preprocess(
|
||||
inputs_, **preprocess_params
|
||||
)
|
||||
model_outputs = self.llm.pipeline.forward(
|
||||
model_inputs, **forward_params
|
||||
)
|
||||
model_outputs["process"] = False
|
||||
return model_outputs
|
||||
output = LLMResult(generations=generations)
|
||||
run_manager_.on_llm_end(output)
|
||||
if run_manager_:
|
||||
output.run = RunInfo(run_id=run_manager_.run_id)
|
||||
response = output
|
||||
|
||||
outputs = [
|
||||
# Get the text of the top generated string.
|
||||
{self.output_key: generation[0].text}
|
||||
for generation in response.generations
|
||||
][0]
|
||||
run_manager.on_chain_end(outputs)
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs
|
||||
)
|
||||
if include_run_info:
|
||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||
return final_outputs
|
||||
else:
|
||||
_run_manager = (
|
||||
run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
)
|
||||
docs = inputs[self.input_key]
|
||||
# Other keys are assumed to be needed for LLM prediction
|
||||
other_keys = {
|
||||
k: v for k, v in inputs.items() if k != self.input_key
|
||||
}
|
||||
doc_strings = [
|
||||
format_document(doc, self.document_prompt) for doc in docs
|
||||
]
|
||||
# Join the documents together to put them in the prompt.
|
||||
inputs = {
|
||||
k: v
|
||||
for k, v in other_keys.items()
|
||||
if k in self.llm_chain.prompt.input_variables
|
||||
}
|
||||
inputs[self.document_variable_name] = self.document_separator.join(
|
||||
doc_strings
|
||||
)
|
||||
inputs["is_first"] = False
|
||||
inputs["input_documents"] = input_docs
|
||||
|
||||
# Call predict on the LLM.
|
||||
output = self.llm_chain(inputs, callbacks=_run_manager.get_child())
|
||||
if "process" in output.keys() and not output["process"]:
|
||||
return output
|
||||
output = output[self.llm_chain.output_key]
|
||||
extra_return_dict = {}
|
||||
extra_return_dict[self.output_key] = output
|
||||
outputs = extra_return_dict
|
||||
run_manager.on_chain_end(outputs)
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs
|
||||
)
|
||||
if include_run_info:
|
||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||
return final_outputs
|
||||
|
||||
def prep_outputs(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
outputs: Dict[str, str],
|
||||
return_only_outputs: bool = False,
|
||||
) -> Dict[str, str]:
|
||||
"""Validate and prep outputs."""
|
||||
self._validate_outputs(outputs)
|
||||
if self.memory is not None:
|
||||
self.memory.save_context(inputs, outputs)
|
||||
if return_only_outputs:
|
||||
return outputs
|
||||
else:
|
||||
return {**inputs, **outputs}
|
||||
|
||||
def prep_inputs(
|
||||
self, inputs: Union[Dict[str, Any], Any]
|
||||
) -> Dict[str, str]:
|
||||
"""Validate and prep inputs."""
|
||||
if not isinstance(inputs, dict):
|
||||
_input_keys = set(self.input_keys)
|
||||
if self.memory is not None:
|
||||
# If there are multiple input keys, but some get set by memory so that
|
||||
# only one is not set, we can still figure out which key it is.
|
||||
_input_keys = _input_keys.difference(
|
||||
self.memory.memory_variables
|
||||
)
|
||||
if len(_input_keys) != 1:
|
||||
raise ValueError(
|
||||
f"A single string input was passed in, but this chain expects "
|
||||
f"multiple inputs ({_input_keys}). When a chain expects "
|
||||
f"multiple inputs, please call it by passing in a dictionary, "
|
||||
"eg `chain({'foo': 1, 'bar': 2})`"
|
||||
)
|
||||
inputs = {list(_input_keys)[0]: inputs}
|
||||
if self.memory is not None:
|
||||
external_context = self.memory.load_memory_variables(inputs)
|
||||
inputs = dict(inputs, **external_context)
|
||||
self._validate_inputs(inputs)
|
||||
return inputs
|
||||
|
||||
def apply(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Call the chain on all inputs in the list."""
|
||||
return [self(inputs, callbacks=callbacks) for inputs in input_list]
|
||||
|
||||
def run(
|
||||
self,
|
||||
*args: Any,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run the chain as text in, text out or multiple variables, text out."""
|
||||
if len(self.output_keys) != 1:
|
||||
raise ValueError(
|
||||
f"`run` not supported when there is not exactly "
|
||||
f"one output key. Got {self.output_keys}."
|
||||
)
|
||||
|
||||
if args and not kwargs:
|
||||
if len(args) != 1:
|
||||
raise ValueError(
|
||||
"`run` supports only one positional argument."
|
||||
)
|
||||
return self(args[0], callbacks=callbacks, tags=tags)[
|
||||
self.output_keys[0]
|
||||
]
|
||||
|
||||
if kwargs and not args:
|
||||
return self(kwargs, callbacks=callbacks, tags=tags)[
|
||||
self.output_keys[0]
|
||||
]
|
||||
|
||||
if not kwargs and not args:
|
||||
raise ValueError(
|
||||
"`run` supported with either positional arguments or keyword arguments,"
|
||||
" but none were provided."
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"`run` supported with either positional arguments or keyword arguments"
|
||||
f" but not both. Got args: {args} and kwargs: {kwargs}."
|
||||
)
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of chain."""
|
||||
if self.memory is not None:
|
||||
raise ValueError("Saving of memory is not yet supported.")
|
||||
_dict = super().dict()
|
||||
_dict["_type"] = self._chain_type
|
||||
return _dict
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
"""Save the chain.
|
||||
|
||||
Args:
|
||||
file_path: Path to file to save the chain to.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
chain.save(file_path="path/chain.yaml")
|
||||
"""
|
||||
# Convert file to Path object.
|
||||
if isinstance(file_path, str):
|
||||
save_path = Path(file_path)
|
||||
else:
|
||||
save_path = file_path
|
||||
|
||||
directory_path = save_path.parent
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Fetch dictionary to save
|
||||
chain_dict = self.dict()
|
||||
|
||||
if save_path.suffix == ".json":
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(chain_dict, f, indent=4)
|
||||
elif save_path.suffix == ".yaml":
|
||||
with open(file_path, "w") as f:
|
||||
yaml.dump(chain_dict, f, default_flow_style=False)
|
||||
else:
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
||||
|
||||
class BaseCombineDocumentsChain(Chain, ABC):
|
||||
"""Base interface for chains combining documents."""
|
||||
|
||||
input_key: str = "input_documents" #: :meta private:
|
||||
output_key: str = "output_text" #: :meta private:
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return output key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def prompt_length(
|
||||
self, docs: List[Document], **kwargs: Any
|
||||
) -> Optional[int]:
|
||||
"""Return the prompt length given the documents passed in.
|
||||
|
||||
Returns None if the method does not depend on the prompt length.
|
||||
"""
|
||||
return None
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, List[Document]],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = (
|
||||
run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
)
|
||||
docs = inputs[self.input_key]
|
||||
# Other keys are assumed to be needed for LLM prediction
|
||||
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
|
||||
doc_strings = [
|
||||
format_document(doc, self.document_prompt) for doc in docs
|
||||
]
|
||||
# Join the documents together to put them in the prompt.
|
||||
inputs = {
|
||||
k: v
|
||||
for k, v in other_keys.items()
|
||||
if k in self.llm_chain.prompt.input_variables
|
||||
}
|
||||
inputs[self.document_variable_name] = self.document_separator.join(
|
||||
doc_strings
|
||||
)
|
||||
|
||||
# Call predict on the LLM.
|
||||
output, extra_return_dict = (
|
||||
self.llm_chain(inputs, callbacks=_run_manager.get_child())[
|
||||
self.llm_chain.output_key
|
||||
],
|
||||
{},
|
||||
)
|
||||
|
||||
extra_return_dict[self.output_key] = output
|
||||
return extra_return_dict
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Generation(Serializable):
|
||||
"""Output of a single generation."""
|
||||
|
||||
text: str
|
||||
"""Generated text output."""
|
||||
|
||||
generation_info: Optional[Dict[str, Any]] = None
|
||||
"""Raw generation info response from the provider"""
|
||||
"""May include things like reason for finishing (e.g. in OpenAI)"""
|
||||
# TODO: add log probs
|
||||
|
||||
|
||||
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
|
||||
|
||||
|
||||
class LLMChain(Chain):
|
||||
"""Chain to run queries against LLMs.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import LLMChain, OpenAI, PromptTemplate
|
||||
prompt_template = "Tell me a {adjective} joke"
|
||||
prompt = PromptTemplate(
|
||||
input_variables=["adjective"], template=prompt_template
|
||||
)
|
||||
llm = LLMChain(llm=OpenAI(), prompt=prompt)
|
||||
"""
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
prompt: BasePromptTemplate
|
||||
"""Prompt object to use."""
|
||||
llm: BaseLanguageModel
|
||||
output_key: str = "text" #: :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Will be whatever keys the prompt expects.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return self.prompt.input_variables
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Will always return text key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
prompts, stop = self.prep_prompts([inputs], run_manager=run_manager)
|
||||
response = self.llm.generate_prompt(
|
||||
prompts,
|
||||
stop,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
)
|
||||
return self.create_outputs(response)[0]
|
||||
|
||||
def prep_prompts(
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Tuple[List[PromptValue], Optional[List[str]]]:
|
||||
"""Prepare prompts from inputs."""
|
||||
stop = None
|
||||
if "stop" in input_list[0]:
|
||||
stop = input_list[0]["stop"]
|
||||
prompts = []
|
||||
for inputs in input_list:
|
||||
selected_inputs = {
|
||||
k: inputs[k] for k in self.prompt.input_variables
|
||||
}
|
||||
prompt = self.prompt.format_prompt(**selected_inputs)
|
||||
_colored_text = get_colored_text(prompt.to_string(), "green")
|
||||
_text = "Prompt after formatting:\n" + _colored_text
|
||||
if run_manager:
|
||||
run_manager.on_text(_text, end="\n", verbose=self.verbose)
|
||||
if "stop" in inputs and inputs["stop"] != stop:
|
||||
raise ValueError(
|
||||
"If `stop` is present in any inputs, should be present in all."
|
||||
)
|
||||
prompts.append(prompt)
|
||||
return prompts, stop
|
||||
|
||||
def apply(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Utilize the LLM generate method for speed gains."""
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose
|
||||
)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
{"input_list": input_list},
|
||||
)
|
||||
try:
|
||||
response = self.generate(input_list, run_manager=run_manager)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise e
|
||||
outputs = self.create_outputs(response)
|
||||
run_manager.on_chain_end({"outputs": outputs})
|
||||
return outputs
|
||||
|
||||
def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
|
||||
"""Create outputs from response."""
|
||||
return [
|
||||
# Get the text of the top generated string.
|
||||
{self.output_key: generation[0].text}
|
||||
for generation in response.generations
|
||||
]
|
||||
|
||||
def predict_and_parse(
|
||||
self, callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> Union[str, List[str], Dict[str, Any]]:
|
||||
"""Call predict and then parse the results."""
|
||||
result = self.predict(callbacks=callbacks, **kwargs)
|
||||
if self.prompt.output_parser is not None:
|
||||
return self.prompt.output_parser.parse(result)
|
||||
else:
|
||||
return result
|
||||
|
||||
def apply_and_parse(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
|
||||
"""Call apply and then parse the results."""
|
||||
result = self.apply(input_list, callbacks=callbacks)
|
||||
return self._parse_result(result)
|
||||
|
||||
def _parse_result(
|
||||
self, result: List[Dict[str, str]]
|
||||
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
|
||||
if self.prompt.output_parser is not None:
|
||||
return [
|
||||
self.prompt.output_parser.parse(res[self.output_key])
|
||||
for res in result
|
||||
]
|
||||
else:
|
||||
return result
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "llm_chain"
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, llm: BaseLanguageModel, template: str) -> LLMChain:
|
||||
"""Create LLMChain from LLM and template."""
|
||||
prompt_template = PromptTemplate.from_template(template)
|
||||
return cls(llm=llm, prompt=prompt_template)
|
||||
|
||||
|
||||
def _get_default_document_prompt() -> PromptTemplate:
|
||||
return PromptTemplate(
|
||||
input_variables=["page_content"], template="{page_content}"
|
||||
)
|
||||
|
||||
|
||||
class StuffDocumentsChain(BaseCombineDocumentsChain):
|
||||
"""Chain that combines documents by stuffing into context."""
|
||||
|
||||
llm_chain: LLMChain
|
||||
"""LLM wrapper to use after formatting documents."""
|
||||
document_prompt: BasePromptTemplate = Field(
|
||||
default_factory=_get_default_document_prompt
|
||||
)
|
||||
"""Prompt to use to format each document."""
|
||||
document_variable_name: str
|
||||
"""The variable name in the llm_chain to put the documents in.
|
||||
If only one variable in the llm_chain, this need not be provided."""
|
||||
document_separator: str = "\n\n"
|
||||
"""The string with which to join the formatted documents"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def get_default_document_variable_name(cls, values: Dict) -> Dict:
|
||||
"""Get default document variable name, if not provided."""
|
||||
llm_chain_variables = values["llm_chain"].prompt.input_variables
|
||||
if "document_variable_name" not in values:
|
||||
if len(llm_chain_variables) == 1:
|
||||
values["document_variable_name"] = llm_chain_variables[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"document_variable_name must be provided if there are "
|
||||
"multiple llm_chain_variables"
|
||||
)
|
||||
else:
|
||||
if values["document_variable_name"] not in llm_chain_variables:
|
||||
raise ValueError(
|
||||
f"document_variable_name {values['document_variable_name']} was "
|
||||
f"not found in llm_chain input_variables: {llm_chain_variables}"
|
||||
)
|
||||
return values
|
||||
|
||||
def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
|
||||
# Format each document according to the prompt
|
||||
doc_strings = [
|
||||
format_document(doc, self.document_prompt) for doc in docs
|
||||
]
|
||||
# Join the documents together to put them in the prompt.
|
||||
inputs = {
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if k in self.llm_chain.prompt.input_variables
|
||||
}
|
||||
inputs[self.document_variable_name] = self.document_separator.join(
|
||||
doc_strings
|
||||
)
|
||||
return inputs
|
||||
|
||||
def prompt_length(
|
||||
self, docs: List[Document], **kwargs: Any
|
||||
) -> Optional[int]:
|
||||
"""Get the prompt length by formatting the prompt."""
|
||||
inputs = self._get_inputs(docs, **kwargs)
|
||||
prompt = self.llm_chain.prompt.format(**inputs)
|
||||
return self.llm_chain.llm.get_num_tokens(prompt)
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "stuff_documents_chain"
|
||||
|
||||
|
||||
class LoadingCallable(Protocol):
|
||||
"""Interface for loading the combine documents chain."""
|
||||
|
||||
def __call__(
|
||||
self, llm: BaseLanguageModel, **kwargs: Any
|
||||
) -> BaseCombineDocumentsChain:
|
||||
"""Callable to load the combine documents chain."""
|
||||
|
||||
|
||||
def _load_stuff_chain(
|
||||
llm: BaseLanguageModel,
|
||||
prompt: Optional[BasePromptTemplate] = None,
|
||||
document_variable_name: str = "context",
|
||||
verbose: Optional[bool] = None,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> StuffDocumentsChain:
|
||||
_prompt = prompt or stuff_prompt.PROMPT_SELECTOR.get_prompt(llm)
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=_prompt,
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
# TODO: document prompt
|
||||
return StuffDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
document_variable_name=document_variable_name,
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def load_qa_chain(
|
||||
llm: BaseLanguageModel,
|
||||
chain_type: str = "stuff",
|
||||
verbose: Optional[bool] = None,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseCombineDocumentsChain:
|
||||
"""Load question answering chain.
|
||||
|
||||
Args:
|
||||
llm: Language Model to use in the chain.
|
||||
chain_type: Type of document combining chain to use. Should be one of "stuff",
|
||||
"map_reduce", "map_rerank", and "refine".
|
||||
verbose: Whether chains should be run in verbose mode or not. Note that this
|
||||
applies to all chains that make up the final chain.
|
||||
callback_manager: Callback manager to use for the chain.
|
||||
|
||||
Returns:
|
||||
A chain to use for question answering.
|
||||
"""
|
||||
loader_mapping: Mapping[str, LoadingCallable] = {
|
||||
"stuff": _load_stuff_chain,
|
||||
}
|
||||
if chain_type not in loader_mapping:
|
||||
raise ValueError(
|
||||
f"Got unsupported chain type: {chain_type}. "
|
||||
f"Should be one of {loader_mapping.keys()}"
|
||||
)
|
||||
return loader_mapping[chain_type](
|
||||
llm, verbose=verbose, callback_manager=callback_manager, **kwargs
|
||||
)
|
||||
File diff suppressed because it is too large
@@ -1,380 +0,0 @@
|
||||
import inspect
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Dict, Any, Optional, List
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from pydantic import root_validator
|
||||
from langchain.llms import gpt4all
|
||||
from dotenv import dotenv_values
|
||||
|
||||
from utils import FakeTokenizer
|
||||
|
||||
|
||||
def get_model_tokenizer_gpt4all(base_model, **kwargs):
|
||||
# defaults (some of these are generation parameters, so need to be passed in at generation time)
|
||||
model_kwargs = dict(
|
||||
n_threads=os.cpu_count() // 2,
|
||||
temp=kwargs.get("temperature", 0.2),
|
||||
top_p=kwargs.get("top_p", 0.75),
|
||||
top_k=kwargs.get("top_k", 40),
|
||||
n_ctx=2048 - 256,
|
||||
)
|
||||
env_gpt4all_file = ".env_gpt4all"
|
||||
model_kwargs.update(dotenv_values(env_gpt4all_file))
|
||||
# make int or float if can to satisfy types for class
|
||||
for k, v in model_kwargs.items():
|
||||
try:
|
||||
if float(v) == int(v):
|
||||
model_kwargs[k] = int(v)
|
||||
else:
|
||||
model_kwargs[k] = float(v)
|
||||
except:
|
||||
pass
|
||||
|
||||
if base_model == "llama":
|
||||
if "model_path_llama" not in model_kwargs:
|
||||
raise ValueError("No model_path_llama in %s" % env_gpt4all_file)
|
||||
model_path = model_kwargs.pop("model_path_llama")
|
||||
# FIXME: GPT4All version of llama doesn't handle new quantization, so use llama_cpp_python
|
||||
from llama_cpp import Llama
|
||||
|
||||
# llama sets some things at init model time, not generation time
|
||||
func_names = list(inspect.signature(Llama.__init__).parameters)
|
||||
model_kwargs = {
|
||||
k: v for k, v in model_kwargs.items() if k in func_names
|
||||
}
|
||||
model_kwargs["n_ctx"] = int(model_kwargs["n_ctx"])
|
||||
model = Llama(model_path=model_path, **model_kwargs)
|
||||
elif base_model == "gpt4all_llama":
|
||||
if (
|
||||
"model_name_gpt4all_llama" not in model_kwargs
|
||||
and "model_path_gpt4all_llama" not in model_kwargs
|
||||
):
|
||||
raise ValueError(
|
||||
"No model_name_gpt4all_llama or model_path_gpt4all_llama in %s"
|
||||
% env_gpt4all_file
|
||||
)
|
||||
model_name = model_kwargs.pop("model_name_gpt4all_llama")
|
||||
model_type = "llama"
|
||||
from gpt4all import GPT4All as GPT4AllModel
|
||||
|
||||
model = GPT4AllModel(model_name=model_name, model_type=model_type)
|
||||
elif base_model == "gptj":
|
||||
if (
|
||||
"model_name_gptj" not in model_kwargs
|
||||
and "model_path_gptj" not in model_kwargs
|
||||
):
|
||||
raise ValueError(
"No model_name_gptj or model_path_gptj in %s"
|
||||
% env_gpt4all_file
|
||||
)
|
||||
model_name = model_kwargs.pop("model_name_gptj")
|
||||
model_type = "gptj"
|
||||
from gpt4all import GPT4All as GPT4AllModel
|
||||
|
||||
model = GPT4AllModel(model_name=model_name, model_type=model_type)
|
||||
else:
|
||||
raise ValueError("No such base_model %s" % base_model)
|
||||
return model, FakeTokenizer(), "cpu"
|
||||
|
||||
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
|
||||
class H2OStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
# streaming to std already occurs without this
|
||||
# sys.stdout.write(token)
|
||||
# sys.stdout.flush()
|
||||
pass
|
||||
|
||||
|
||||
def get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=[]):
|
||||
# default from class
|
||||
model_kwargs = {
|
||||
k: v.default
|
||||
for k, v in dict(inspect.signature(cls).parameters).items()
|
||||
if k not in exclude_list
|
||||
}
|
||||
# from our defaults
|
||||
model_kwargs.update(default_kwargs)
|
||||
# from user defaults
|
||||
model_kwargs.update(env_kwargs)
|
||||
# ensure only valid keys
|
||||
func_names = list(inspect.signature(cls).parameters)
|
||||
model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names}
|
||||
return model_kwargs
|
||||
|
||||
|
||||
def get_llm_gpt4all(
|
||||
model_name,
|
||||
model=None,
|
||||
max_new_tokens=256,
|
||||
temperature=0.1,
|
||||
repetition_penalty=1.0,
|
||||
top_k=40,
|
||||
top_p=0.7,
|
||||
streaming=False,
|
||||
callbacks=None,
|
||||
prompter=None,
|
||||
verbose=False,
|
||||
):
|
||||
assert prompter is not None
|
||||
env_gpt4all_file = ".env_gpt4all"
|
||||
env_kwargs = dotenv_values(env_gpt4all_file)
|
||||
n_ctx = env_kwargs.pop("n_ctx", 2048 - max_new_tokens)
|
||||
default_kwargs = dict(
|
||||
context_erase=0.5,
|
||||
n_batch=1,
|
||||
n_ctx=n_ctx,
|
||||
n_predict=max_new_tokens,
|
||||
repeat_last_n=64 if repetition_penalty != 1.0 else 0,
|
||||
repeat_penalty=repetition_penalty,
|
||||
temp=temperature,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
use_mlock=True,
|
||||
verbose=verbose,
|
||||
)
|
||||
if model_name == "llama":
|
||||
cls = H2OLlamaCpp
|
||||
model_path = (
|
||||
env_kwargs.pop("model_path_llama") if model is None else model
|
||||
)
|
||||
model_kwargs = get_model_kwargs(
|
||||
env_kwargs, default_kwargs, cls, exclude_list=["lc_kwargs"]
|
||||
)
|
||||
model_kwargs.update(
|
||||
dict(
|
||||
model_path=model_path,
|
||||
callbacks=callbacks,
|
||||
streaming=streaming,
|
||||
prompter=prompter,
|
||||
)
|
||||
)
|
||||
llm = cls(**model_kwargs)
|
||||
llm.client.verbose = verbose
|
||||
elif model_name == "gpt4all_llama":
|
||||
cls = H2OGPT4All
|
||||
model_path = (
|
||||
env_kwargs.pop("model_path_gpt4all_llama")
|
||||
if model is None
|
||||
else model
|
||||
)
|
||||
model_kwargs = get_model_kwargs(
|
||||
env_kwargs, default_kwargs, cls, exclude_list=["lc_kwargs"]
|
||||
)
|
||||
model_kwargs.update(
|
||||
dict(
|
||||
model=model_path,
|
||||
backend="llama",
|
||||
callbacks=callbacks,
|
||||
streaming=streaming,
|
||||
prompter=prompter,
|
||||
)
|
||||
)
|
||||
llm = cls(**model_kwargs)
|
||||
elif model_name == "gptj":
|
||||
cls = H2OGPT4All
|
||||
model_path = (
|
||||
env_kwargs.pop("model_path_gptj") if model is None else model
|
||||
)
|
||||
model_kwargs = get_model_kwargs(
|
||||
env_kwargs, default_kwargs, cls, exclude_list=["lc_kwargs"]
|
||||
)
|
||||
model_kwargs.update(
|
||||
dict(
|
||||
model=model_path,
|
||||
backend="gptj",
|
||||
callbacks=callbacks,
|
||||
streaming=streaming,
|
||||
prompter=prompter,
|
||||
)
|
||||
)
|
||||
llm = cls(**model_kwargs)
|
||||
else:
|
||||
raise RuntimeError("No such model_name %s" % model_name)
|
||||
return llm
|
||||
|
||||
|
||||
class H2OGPT4All(gpt4all.GPT4All):
|
||||
model: Any
|
||||
prompter: Any
|
||||
"""Path to the pre-trained GPT4All model file."""
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that the python package exists in the environment."""
|
||||
try:
|
||||
if isinstance(values["model"], str):
|
||||
from gpt4all import GPT4All as GPT4AllModel
|
||||
|
||||
full_path = values["model"]
|
||||
model_path, delimiter, model_name = full_path.rpartition("/")
|
||||
model_path += delimiter
|
||||
|
||||
values["client"] = GPT4AllModel(
|
||||
model_name=model_name,
|
||||
model_path=model_path or None,
|
||||
model_type=values["backend"],
|
||||
allow_download=False,
|
||||
)
|
||||
if values["n_threads"] is not None:
|
||||
# set n_threads
|
||||
values["client"].model.set_thread_count(
|
||||
values["n_threads"]
|
||||
)
|
||||
else:
|
||||
values["client"] = values["model"]
|
||||
try:
|
||||
values["backend"] = values["client"].model_type
|
||||
except AttributeError:
|
||||
# The below is for compatibility with GPT4All Python bindings <= 0.2.3.
|
||||
values["backend"] = values["client"].model.model_type
|
||||
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import gpt4all python package. "
|
||||
"Please install it with `pip install gpt4all`."
|
||||
)
|
||||
return values
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
# Roughly 4 chars per token if natural language
|
||||
prompt = prompt[-self.n_ctx * 4 :]
|
||||
|
||||
# use instruct prompting
|
||||
data_point = dict(context="", instruction=prompt, input="")
|
||||
prompt = self.prompter.generate_prompt(data_point)
|
||||
|
||||
verbose = False
|
||||
if verbose:
|
||||
print("_call prompt: %s" % prompt, flush=True)
|
||||
# FIXME: GPT4ALl doesn't support yield during generate, so cannot support streaming except via itself to stdout
|
||||
return super()._call(prompt, stop=stop, run_manager=run_manager)
|
||||
|
||||
|
||||
from langchain.llms import LlamaCpp
|
||||
|
||||
|
||||
class H2OLlamaCpp(LlamaCpp):
|
||||
model_path: Any
|
||||
prompter: Any
|
||||
"""Path to the pre-trained GPT4All model file."""
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that llama-cpp-python library is installed."""
|
||||
if isinstance(values["model_path"], str):
|
||||
model_path = values["model_path"]
|
||||
model_param_names = [
|
||||
"lora_path",
|
||||
"lora_base",
|
||||
"n_ctx",
|
||||
"n_parts",
|
||||
"seed",
|
||||
"f16_kv",
|
||||
"logits_all",
|
||||
"vocab_only",
|
||||
"use_mlock",
|
||||
"n_threads",
|
||||
"n_batch",
|
||||
"use_mmap",
|
||||
"last_n_tokens_size",
|
||||
]
|
||||
model_params = {k: values[k] for k in model_param_names}
|
||||
# For backwards compatibility, only include if non-null.
|
||||
if values["n_gpu_layers"] is not None:
|
||||
model_params["n_gpu_layers"] = values["n_gpu_layers"]
|
||||
|
||||
try:
|
||||
from llama_cpp import Llama
|
||||
|
||||
values["client"] = Llama(model_path, **model_params)
|
||||
except ImportError:
|
||||
raise ModuleNotFoundError(
|
||||
"Could not import llama-cpp-python library. "
|
||||
"Please install the llama-cpp-python library to "
|
||||
"use this embedding model: pip install llama-cpp-python"
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Could not load Llama model from path: {model_path}. "
|
||||
f"Received error {e}"
|
||||
)
|
||||
else:
|
||||
values["client"] = values["model_path"]
|
||||
return values
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
verbose = False
|
||||
# tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate
|
||||
# still have to avoid crazy sizes, else hit llama_tokenize: too many tokens -- might still hit, not fatal
|
||||
prompt = prompt[-self.n_ctx * 4 :]
|
||||
prompt_tokens = self.client.tokenize(b" " + prompt.encode("utf-8"))
|
||||
num_prompt_tokens = len(prompt_tokens)
|
||||
if num_prompt_tokens > self.n_ctx:
|
||||
# conservative by using int()
|
||||
chars_per_token = int(len(prompt) / num_prompt_tokens)
|
||||
prompt = prompt[-self.n_ctx * chars_per_token :]
|
||||
if verbose:
|
||||
print(
"reducing tokens, assuming average of %s chars/token"
% chars_per_token,
|
||||
flush=True,
|
||||
)
|
||||
prompt_tokens2 = self.client.tokenize(
|
||||
b" " + prompt.encode("utf-8")
|
||||
)
|
||||
num_prompt_tokens2 = len(prompt_tokens2)
|
||||
print(
|
||||
"reduced tokens from %d -> %d"
|
||||
% (num_prompt_tokens, num_prompt_tokens2),
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# use instruct prompting
|
||||
data_point = dict(context="", instruction=prompt, input="")
|
||||
prompt = self.prompter.generate_prompt(data_point)
|
||||
|
||||
if verbose:
|
||||
print("_call prompt: %s" % prompt, flush=True)
|
||||
|
||||
if self.streaming:
|
||||
text_callback = None
|
||||
if run_manager:
|
||||
text_callback = partial(
|
||||
run_manager.on_llm_new_token, verbose=self.verbose
|
||||
)
|
||||
# parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter
|
||||
if text_callback:
|
||||
text_callback(prompt)
|
||||
text = ""
|
||||
for token in self.stream(
|
||||
prompt=prompt, stop=stop, run_manager=run_manager
|
||||
):
|
||||
text_chunk = token["choices"][0]["text"]
|
||||
# self.stream already calls text_callback
|
||||
# if text_callback:
|
||||
# text_callback(text_chunk)
|
||||
text += text_chunk
|
||||
return text
|
||||
else:
|
||||
params = self._get_parameters(stop)
|
||||
params = {**params, **kwargs}
|
||||
result = self.client(prompt=prompt, **params)
|
||||
return result["choices"][0]["text"]
|
||||
File diff suppressed because it is too large
@@ -1,93 +0,0 @@
import traceback
from typing import Callable
import os

from gradio_client.client import Job

os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"

from gradio_client import Client


class GradioClient(Client):
    """
    Gradio client wrapper that automatically refreshes itself when it
    detects that the gradio server has changed.
    """

    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs
        super().__init__(*args, **kwargs)
        self.server_hash = self.get_server_hash()

    def get_server_hash(self):
        """
        Get the server hash via the parent class, without triggering any
        refresh action.

        Returns: git hash of the gradio server.
        """
        return super().submit(api_name="/system_hash").result()

    def refresh_client_if_should(self):
        # get current hash in order to update api_name -> fn_index map in case gradio server changed
        # FIXME: Could add cli api as hash
        server_hash = self.get_server_hash()
        if self.server_hash != server_hash:
            self.refresh_client()
            self.server_hash = server_hash
        else:
            self.reset_session()

    def refresh_client(self):
        """
        Ensure every client call is independent, and that the map between
        api_name and fn_index is updated in case the server changed
        (e.g. restarted with new code).
        """
        # need session hash to be new every time, to avoid "generator already executing"
        self.reset_session()

        client = Client(*self.args, **self.kwargs)
        for k, v in client.__dict__.items():
            setattr(self, k, v)

    def submit(
        self,
        *args,
        api_name: str | None = None,
        fn_index: int | None = None,
        result_callbacks: Callable | list[Callable] | None = None,
    ) -> Job:
        # Note: predict calls submit
        try:
            self.refresh_client_if_should()
            job = super().submit(*args, api_name=api_name, fn_index=fn_index)
        except Exception as e:
            print("Hit e=%s" % str(e), flush=True)
            # force reconfig in case only that
            self.refresh_client()
            job = super().submit(*args, api_name=api_name, fn_index=fn_index)

        # see if it immediately failed
        e = job.future._exception
        if e is not None:
            print(
                "GR job failed: %s %s"
                % (str(e), "".join(traceback.format_tb(e.__traceback__))),
                flush=True,
            )
            # force reconfig in case only that
            self.refresh_client()
            job = super().submit(*args, api_name=api_name, fn_index=fn_index)
            e2 = job.future._exception
            if e2 is not None:
                print(
                    "GR job failed again: %s\n%s"
                    % (
                        str(e2),
                        "".join(traceback.format_tb(e2.__traceback__)),
                    ),
                    flush=True,
                )

        return job
@@ -1,765 +0,0 @@
|
||||
import os
|
||||
from apps.stable_diffusion.src.utils.utils import _compile_module
|
||||
from io import BytesIO
|
||||
import torch_mlir
|
||||
|
||||
from stopping import get_stopping
|
||||
from prompter import Prompter, PromptType
|
||||
|
||||
from transformers import TextGenerationPipeline
|
||||
from transformers.pipelines.text_generation import ReturnType
|
||||
from transformers.generation import (
|
||||
GenerationConfig,
|
||||
LogitsProcessorList,
|
||||
StoppingCriteriaList,
|
||||
)
|
||||
import copy
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
import gc
|
||||
from pathlib import Path
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
# Brevitas
|
||||
from typing import List, Tuple
|
||||
from brevitas_examples.common.generative.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
# fmt: off
def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
    if len(lhs) == 3 and len(rhs) == 2:
        return [lhs[0], lhs[1], rhs[0]]
    elif len(lhs) == 2 and len(rhs) == 2:
        return [lhs[0], rhs[0]]
    else:
        raise ValueError("Input shapes not supported.")


def quant〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
    # output dtype is the dtype of the lhs float input
    lhs_rank, lhs_dtype = lhs_rank_dtype
    return lhs_dtype


def quant〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
    return


brevitas_matmul_rhs_group_quant_library = [
    quant〇matmul_rhs_group_quant〡shape,
    quant〇matmul_rhs_group_quant〡dtype,
    quant〇matmul_rhs_group_quant〡has_value_semantics]
# fmt: on
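As a quick illustration (not part of the original file), the shape rule above behaves like a matmul against a transposed right-hand side; the concrete sizes below are made up:

```python
# Hypothetical shapes: (1, 400, 512) activations, (1024, 512) 4-bit weights
# grouped every 128 columns; the scale/zero-point shapes do not affect the result.
out_shape = quant〇matmul_rhs_group_quant〡shape(
    [1, 400, 512], [1024, 512], [1024, 4], [1024, 4], 4, 128
)
assert out_shape == [1, 400, 1024]
```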
|
||||
global_device = "cuda"
|
||||
global_precision = "fp16"
|
||||
|
||||
if not args.run_docuchat_web:
|
||||
args.device = global_device
|
||||
args.precision = global_precision
|
||||
tensor_device = "cpu" if args.device == "cpu" else "cuda"
|
||||
|
||||
|
||||
class H2OGPTModel(torch.nn.Module):
|
||||
def __init__(self, device, precision):
|
||||
super().__init__()
|
||||
torch_dtype = (
|
||||
torch.float32
|
||||
if precision == "fp32" or device == "cpu"
|
||||
else torch.float16
|
||||
)
|
||||
device_map = {"": "cpu"} if device == "cpu" else {"": 0}
|
||||
model_kwargs = {
|
||||
"local_files_only": False,
|
||||
"torch_dtype": torch_dtype,
|
||||
"resume_download": True,
|
||||
"use_auth_token": False,
|
||||
"trust_remote_code": True,
|
||||
"offload_folder": "offline_folder",
|
||||
"device_map": device_map,
|
||||
}
|
||||
config = AutoConfig.from_pretrained(
|
||||
"h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
|
||||
use_auth_token=False,
|
||||
trust_remote_code=True,
|
||||
offload_folder="offline_folder",
|
||||
)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
"h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
|
||||
config=config,
|
||||
**model_kwargs,
|
||||
)
|
||||
if precision in ["int4", "int8"]:
|
||||
print("Applying weight quantization..")
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
self.model.transformer.h,
|
||||
dtype=torch.float32,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float_scale",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=128,
|
||||
quantize_weight_zero_point=False,
|
||||
)
|
||||
print("Weight quantization applied.")
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
input_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": None,
|
||||
"use_cache": True,
|
||||
}
|
||||
output = self.model(
|
||||
**input_dict,
|
||||
return_dict=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
)
|
||||
return output.logits[:, -1, :]
|
||||
|
||||
|
||||
class H2OGPTSHARKModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
model_name = "h2ogpt_falcon_7b"
|
||||
extended_model_name = (
|
||||
model_name + "_" + args.precision + "_" + args.device
|
||||
)
|
||||
vmfb_path = Path(extended_model_name + ".vmfb")
|
||||
mlir_path = Path(model_name + "_" + args.precision + ".mlir")
|
||||
shark_module = None
|
||||
|
||||
need_to_compile = False
|
||||
if not vmfb_path.exists():
|
||||
need_to_compile = True
|
||||
# Downloading VMFB from shark_tank
|
||||
print("Trying to download pre-compiled vmfb from shark tank.")
|
||||
download_public_file(
|
||||
"gs://shark_tank/langchain/" + str(vmfb_path),
|
||||
vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if vmfb_path.exists():
|
||||
print(
|
||||
"Pre-compiled vmfb downloaded from shark tank successfully."
|
||||
)
|
||||
need_to_compile = False
|
||||
|
||||
if need_to_compile:
|
||||
if not mlir_path.exists():
|
||||
print("Trying to download pre-generated mlir from shark tank.")
|
||||
# Downloading MLIR from shark_tank
|
||||
download_public_file(
|
||||
"gs://shark_tank/langchain/" + str(mlir_path),
|
||||
mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if mlir_path.exists():
|
||||
with open(mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
# Generating the mlir
|
||||
bytecode = self.get_bytecode(tensor_device, args.precision)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode,
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
print(f"[DEBUG] generating vmfb.")
|
||||
shark_module = _compile_module(
|
||||
shark_module, extended_model_name, []
|
||||
)
|
||||
print("Saved newly generated vmfb.")
|
||||
|
||||
if shark_module is None:
|
||||
if vmfb_path.exists():
|
||||
print("Compiled vmfb found. Loading it from: ", vmfb_path)
|
||||
shark_module = SharkInference(
|
||||
None, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.load_module(str(vmfb_path))
|
||||
print("Compiled vmfb loaded successfully.")
|
||||
else:
|
||||
raise ValueError("Unable to download/generate a vmfb.")
|
||||
|
||||
self.model = shark_module
|
||||
|
||||
def get_bytecode(self, device, precision):
|
||||
h2ogpt_model = H2OGPTModel(device, precision)
|
||||
|
||||
compilation_input_ids = torch.randint(
|
||||
low=1, high=10000, size=(1, 400)
|
||||
).to(device=device)
|
||||
compilation_attention_mask = torch.ones(1, 400, dtype=torch.int64).to(
|
||||
device=device
|
||||
)
|
||||
|
||||
h2ogptCompileInput = (
|
||||
compilation_input_ids,
|
||||
compilation_attention_mask,
|
||||
)
|
||||
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
ts_graph = import_with_fx(
|
||||
h2ogpt_model,
|
||||
h2ogptCompileInput,
|
||||
is_f16=False,
|
||||
precision=precision,
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
del h2ogpt_model
|
||||
del self.src_model
|
||||
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
if precision in ["int4", "int8"]:
|
||||
from torch_mlir.compiler_utils import (
|
||||
run_pipeline_with_repro_report,
|
||||
)
|
||||
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*h2ogptCompileInput],
|
||||
output_type=torch_mlir.OutputType.TORCH,
|
||||
backend_legal_ops=["quant.matmul_rhs_group_quant"],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
module,
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
else:
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*h2ogptCompileInput],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
del ts_graph
|
||||
|
||||
print(f"[DEBUG] converting to bytecode")
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
del module
|
||||
|
||||
bytecode = save_mlir(
|
||||
bytecode,
|
||||
model_name=f"h2ogpt_{precision}",
|
||||
frontend="torch",
|
||||
)
|
||||
return bytecode
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
result = torch.from_numpy(
|
||||
self.model(
|
||||
"forward",
|
||||
(input_ids.to(device="cpu"), attention_mask.to(device="cpu")),
|
||||
)
|
||||
).to(device=tensor_device)
|
||||
return result
|
||||
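A short sketch of how the compiled module above is driven; the tensor sizes mirror the 1x400 shapes used at compile time, and the snippet is illustrative rather than taken from the original file:

```python
# Sketch: inputs are moved to CPU inside forward(); the result is the
# last-position logits produced by H2OGPTModel.forward.
shark_lm = H2OGPTSHARKModel()
input_ids = torch.randint(low=1, high=10000, size=(1, 400))
attention_mask = torch.ones(1, 400, dtype=torch.int64)
next_token_logits = shark_lm.forward(input_ids, attention_mask)  # (1, vocab_size)
```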
|
||||
|
||||
def decode_tokens(tokenizer, res_tokens):
|
||||
for i in range(len(res_tokens)):
|
||||
if type(res_tokens[i]) != int:
|
||||
res_tokens[i] = int(res_tokens[i][0])
|
||||
|
||||
res_str = tokenizer.decode(res_tokens, skip_special_tokens=True)
|
||||
return res_str
|
||||
|
||||
|
||||
def generate_token(h2ogpt_shark_model, model, tokenizer, **generate_kwargs):
|
||||
del generate_kwargs["max_time"]
|
||||
generate_kwargs["input_ids"] = generate_kwargs["input_ids"].to(
|
||||
device=tensor_device
|
||||
)
|
||||
generate_kwargs["attention_mask"] = generate_kwargs["attention_mask"].to(
|
||||
device=tensor_device
|
||||
)
|
||||
truncated_input_ids = []
|
||||
stopping_criteria = generate_kwargs["stopping_criteria"]
|
||||
|
||||
generation_config_ = GenerationConfig.from_model_config(model.config)
|
||||
generation_config = copy.deepcopy(generation_config_)
|
||||
model_kwargs = generation_config.update(**generate_kwargs)
|
||||
|
||||
logits_processor = LogitsProcessorList()
|
||||
stopping_criteria = (
|
||||
stopping_criteria
|
||||
if stopping_criteria is not None
|
||||
else StoppingCriteriaList()
|
||||
)
|
||||
|
||||
eos_token_id = generation_config.eos_token_id
|
||||
generation_config.pad_token_id = eos_token_id
|
||||
|
||||
(
|
||||
inputs_tensor,
|
||||
model_input_name,
|
||||
model_kwargs,
|
||||
) = model._prepare_model_inputs(
|
||||
None, generation_config.bos_token_id, model_kwargs
|
||||
)
|
||||
|
||||
model_kwargs["output_attentions"] = generation_config.output_attentions
|
||||
model_kwargs[
|
||||
"output_hidden_states"
|
||||
] = generation_config.output_hidden_states
|
||||
model_kwargs["use_cache"] = generation_config.use_cache
|
||||
|
||||
input_ids = (
|
||||
inputs_tensor
|
||||
if model_input_name == "input_ids"
|
||||
else model_kwargs.pop("input_ids")
|
||||
)
|
||||
|
||||
input_ids_seq_length = input_ids.shape[-1]
|
||||
|
||||
generation_config.max_length = (
|
||||
generation_config.max_new_tokens + input_ids_seq_length
|
||||
)
|
||||
|
||||
logits_processor = model._get_logits_processor(
|
||||
generation_config=generation_config,
|
||||
input_ids_seq_length=input_ids_seq_length,
|
||||
encoder_input_ids=inputs_tensor,
|
||||
prefix_allowed_tokens_fn=None,
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
|
||||
stopping_criteria = model._get_stopping_criteria(
|
||||
generation_config=generation_config,
|
||||
stopping_criteria=stopping_criteria,
|
||||
)
|
||||
|
||||
logits_warper = model._get_logits_warper(generation_config)
|
||||
|
||||
(
|
||||
input_ids,
|
||||
model_kwargs,
|
||||
) = model._expand_inputs_for_generation(
|
||||
input_ids=input_ids,
|
||||
expand_size=generation_config.num_return_sequences, # 1
|
||||
is_encoder_decoder=model.config.is_encoder_decoder, # False
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
eos_token_id_tensor = (
|
||||
torch.tensor(eos_token_id).to(device=tensor_device)
|
||||
if eos_token_id is not None
|
||||
else None
|
||||
)
|
||||
|
||||
pad_token_id = generation_config.pad_token_id
|
||||
eos_token_id = eos_token_id
|
||||
|
||||
output_scores = generation_config.output_scores # False
|
||||
return_dict_in_generate = (
|
||||
generation_config.return_dict_in_generate # False
|
||||
)
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
scores = () if (return_dict_in_generate and output_scores) else None
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
unfinished_sequences = torch.ones(
|
||||
input_ids.shape[0],
|
||||
dtype=torch.long,
|
||||
device=input_ids.device,
|
||||
)
|
||||
|
||||
timesRan = 0
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
print("\n")
|
||||
|
||||
res_tokens = []
|
||||
while True:
|
||||
model_inputs = model.prepare_inputs_for_generation(
|
||||
input_ids, **model_kwargs
|
||||
)
|
||||
|
||||
outputs = h2ogpt_shark_model.forward(
|
||||
model_inputs["input_ids"], model_inputs["attention_mask"]
|
||||
)
|
||||
|
||||
if args.precision == "fp16":
|
||||
outputs = outputs.to(dtype=torch.float32)
|
||||
next_token_logits = outputs
|
||||
|
||||
# pre-process distribution
|
||||
next_token_scores = logits_processor(input_ids, next_token_logits)
|
||||
next_token_scores = logits_warper(input_ids, next_token_scores)
|
||||
|
||||
# sample
|
||||
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
|
||||
|
||||
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
|
||||
# finished sentences should have their next token be a padding token
|
||||
if eos_token_id is not None:
|
||||
if pad_token_id is None:
|
||||
raise ValueError(
|
||||
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
|
||||
)
|
||||
next_token = next_token * unfinished_sequences + pad_token_id * (
|
||||
1 - unfinished_sequences
|
||||
)
|
||||
|
||||
input_ids = torch.cat([input_ids, next_token[:, None]], dim=-1)
|
||||
|
||||
model_kwargs["past_key_values"] = None
|
||||
if "attention_mask" in model_kwargs:
|
||||
attention_mask = model_kwargs["attention_mask"]
|
||||
model_kwargs["attention_mask"] = torch.cat(
|
||||
[
|
||||
attention_mask,
|
||||
attention_mask.new_ones((attention_mask.shape[0], 1)),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
truncated_input_ids.append(input_ids[:, 0])
|
||||
input_ids = input_ids[:, 1:]
|
||||
model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, 1:]
|
||||
|
||||
new_word = tokenizer.decode(
|
||||
next_token.cpu().numpy(),
|
||||
add_special_tokens=False,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
|
||||
res_tokens.append(next_token)
|
||||
if new_word == "<0x0A>":
|
||||
print("\n", end="", flush=True)
|
||||
else:
|
||||
print(f"{new_word}", end=" ", flush=True)
|
||||
|
||||
part_str = decode_tokens(tokenizer, res_tokens)
|
||||
yield part_str
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if eos_token_id_tensor is not None:
|
||||
unfinished_sequences = unfinished_sequences.mul(
|
||||
next_token.tile(eos_token_id_tensor.shape[0], 1)
|
||||
.ne(eos_token_id_tensor.unsqueeze(1))
|
||||
.prod(dim=0)
|
||||
)
|
||||
# stop when each sentence is finished
|
||||
if unfinished_sequences.max() == 0 or stopping_criteria(
|
||||
input_ids, scores
|
||||
):
|
||||
break
|
||||
timesRan = timesRan + 1
|
||||
|
||||
end = time.time()
|
||||
print(
|
||||
"\n\nTime taken is {:.2f} seconds/token\n".format(
|
||||
(end - start) / timesRan
|
||||
)
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
res_str = decode_tokens(tokenizer, res_tokens)
|
||||
yield res_str
|
||||
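Since `generate_token` is a generator that yields progressively longer decoded strings, a caller would consume it roughly as follows (a sketch; `model`, `tokenizer`, and `generate_kwargs` are assumed to come from the surrounding pipeline):

```python
# Sketch: stream partial detokenized output as it is produced.
for partial_text in generate_token(
    h2ogpt_shark_model, model, tokenizer, **generate_kwargs
):
    last_text = partial_text  # the final yield is the full decoded response
```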
|
||||
|
||||
def pad_or_truncate_inputs(
    input_ids, attention_mask, max_padding_length=400, do_truncation=False
):
    inp_shape = input_ids.shape
    if inp_shape[1] < max_padding_length:
        # do padding
        num_add_token = max_padding_length - inp_shape[1]
        padded_input_ids = torch.cat(
            [
                torch.tensor([[11] * num_add_token]).to(device=tensor_device),
                input_ids,
            ],
            dim=1,
        )
        padded_attention_mask = torch.cat(
            [
                torch.tensor([[0] * num_add_token]).to(device=tensor_device),
                attention_mask,
            ],
            dim=1,
        )
        return padded_input_ids, padded_attention_mask
    elif inp_shape[1] > max_padding_length or do_truncation:
        # do truncation
        num_remove_token = inp_shape[1] - max_padding_length
        truncated_input_ids = input_ids[:, num_remove_token:]
        truncated_attention_mask = attention_mask[:, num_remove_token:]
        return truncated_input_ids, truncated_attention_mask
    else:
        return input_ids, attention_mask


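For clarity, a small sketch of the padding branch above (made-up tensor sizes; not from the original file):

```python
# Sketch: a 3-token prompt is left-padded up to max_padding_length with
# token id 11 and mask 0, matching the compiled 1x400 input shape.
ids = torch.tensor([[101, 102, 103]]).to(device=tensor_device)
mask = torch.ones(1, 3, dtype=torch.int64).to(device=tensor_device)
padded_ids, padded_mask = pad_or_truncate_inputs(ids, mask, max_padding_length=400)
assert padded_ids.shape == (1, 400) and padded_mask.shape == (1, 400)
```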
class H2OTextGenerationPipeline(TextGenerationPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
debug=False,
|
||||
chat=False,
|
||||
stream_output=False,
|
||||
sanitize_bot_response=False,
|
||||
use_prompter=True,
|
||||
prompter=None,
|
||||
prompt_type=None,
|
||||
prompt_dict=None,
|
||||
max_input_tokens=2048 - 256,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
HF-like pipeline, but handle instruction prompting and stopping (for some models)
|
||||
:param args:
|
||||
:param debug:
|
||||
:param chat:
|
||||
:param stream_output:
|
||||
:param sanitize_bot_response:
|
||||
:param use_prompter: Whether to use prompter. If pass prompt_type, will make prompter
|
||||
:param prompter: prompter, can pass if have already
|
||||
:param prompt_type: prompt_type, e.g. human_bot. See prompt_type to model mapping in from prompter.py.
|
||||
If use_prompter, then will make prompter and use it.
|
||||
:param prompt_dict: dict of get_prompt(, return_dict=True) for prompt_type=custom
|
||||
:param max_input_tokens:
|
||||
:param kwargs:
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.prompt_text = None
|
||||
self.use_prompter = use_prompter
|
||||
self.prompt_type = prompt_type
|
||||
self.prompt_dict = prompt_dict
|
||||
self.prompter = prompter
|
||||
if self.use_prompter:
|
||||
if self.prompter is not None:
|
||||
assert self.prompter.prompt_type is not None
|
||||
else:
|
||||
self.prompter = Prompter(
|
||||
self.prompt_type,
|
||||
self.prompt_dict,
|
||||
debug=debug,
|
||||
chat=chat,
|
||||
stream_output=stream_output,
|
||||
)
|
||||
self.human = self.prompter.humanstr
|
||||
self.bot = self.prompter.botstr
|
||||
self.can_stop = True
|
||||
else:
|
||||
self.prompter = None
|
||||
self.human = None
|
||||
self.bot = None
|
||||
self.can_stop = False
|
||||
self.sanitize_bot_response = sanitize_bot_response
|
||||
self.max_input_tokens = (
|
||||
max_input_tokens # not for generate, so ok that not kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def limit_prompt(prompt_text, tokenizer, max_prompt_length=None):
|
||||
verbose = bool(int(os.getenv("VERBOSE_PIPELINE", "0")))
|
||||
|
||||
if hasattr(tokenizer, "model_max_length"):
|
||||
# model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
|
||||
model_max_length = tokenizer.model_max_length
|
||||
if max_prompt_length is not None:
|
||||
model_max_length = min(model_max_length, max_prompt_length)
|
||||
# cut at some upper likely limit to avoid excessive tokenization etc
|
||||
# upper bound of 10 chars/token, e.g. special chars sometimes are long
|
||||
if len(prompt_text) > model_max_length * 10:
|
||||
len0 = len(prompt_text)
|
||||
prompt_text = prompt_text[-model_max_length * 10 :]
|
||||
if verbose:
|
||||
print(
|
||||
"Cut of input: %s -> %s" % (len0, len(prompt_text)),
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
# unknown
|
||||
model_max_length = None
|
||||
|
||||
num_prompt_tokens = None
|
||||
if model_max_length is not None:
|
||||
# can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
|
||||
# For https://github.com/h2oai/h2ogpt/issues/192
|
||||
for trial in range(0, 3):
|
||||
prompt_tokens = tokenizer(prompt_text)["input_ids"]
|
||||
num_prompt_tokens = len(prompt_tokens)
|
||||
if num_prompt_tokens > model_max_length:
|
||||
# conservative by using int()
|
||||
chars_per_token = int(len(prompt_text) / num_prompt_tokens)
|
||||
# keep tail, where question is if using langchain
|
||||
prompt_text = prompt_text[
|
||||
-model_max_length * chars_per_token :
|
||||
]
|
||||
if verbose:
|
||||
print(
|
||||
"reducing %s tokens, assuming average of %s chars/token for %s characters"
|
||||
% (
|
||||
num_prompt_tokens,
|
||||
chars_per_token,
|
||||
len(prompt_text),
|
||||
),
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
if verbose:
|
||||
print(
|
||||
"using %s tokens with %s chars"
|
||||
% (num_prompt_tokens, len(prompt_text)),
|
||||
flush=True,
|
||||
)
|
||||
break
|
||||
|
||||
return prompt_text, num_prompt_tokens
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
prompt_text,
|
||||
prefix="",
|
||||
handle_long_generation=None,
|
||||
**generate_kwargs,
|
||||
):
|
||||
(
|
||||
prompt_text,
|
||||
num_prompt_tokens,
|
||||
) = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)
|
||||
|
||||
data_point = dict(context="", instruction=prompt_text, input="")
|
||||
if self.prompter is not None:
|
||||
prompt_text = self.prompter.generate_prompt(data_point)
|
||||
self.prompt_text = prompt_text
|
||||
if handle_long_generation is None:
|
||||
# forces truncation of inputs to avoid critical failure
|
||||
handle_long_generation = None # disable with new approaches
|
||||
return super().preprocess(
|
||||
prompt_text,
|
||||
prefix=prefix,
|
||||
handle_long_generation=handle_long_generation,
|
||||
**generate_kwargs,
|
||||
)
|
||||
|
||||
def postprocess(
|
||||
self,
|
||||
model_outputs,
|
||||
return_type=ReturnType.FULL_TEXT,
|
||||
clean_up_tokenization_spaces=True,
|
||||
):
|
||||
records = super().postprocess(
|
||||
model_outputs,
|
||||
return_type=return_type,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
)
|
||||
for rec in records:
|
||||
if self.use_prompter:
|
||||
outputs = rec["generated_text"]
|
||||
outputs = self.prompter.get_response(
|
||||
outputs,
|
||||
prompt=self.prompt_text,
|
||||
sanitize_bot_response=self.sanitize_bot_response,
|
||||
)
|
||||
elif self.bot and self.human:
|
||||
outputs = (
|
||||
rec["generated_text"]
|
||||
.split(self.bot)[1]
|
||||
.split(self.human)[0]
|
||||
)
|
||||
else:
|
||||
outputs = rec["generated_text"]
|
||||
rec["generated_text"] = outputs
|
||||
print(
|
||||
"prompt: %s\noutputs: %s\n\n" % (self.prompt_text, outputs),
|
||||
flush=True,
|
||||
)
|
||||
return records
|
||||
|
||||
def _forward(self, model_inputs, **generate_kwargs):
|
||||
if self.can_stop:
|
||||
stopping_criteria = get_stopping(
|
||||
self.prompt_type,
|
||||
self.prompt_dict,
|
||||
self.tokenizer,
|
||||
self.device,
|
||||
human=self.human,
|
||||
bot=self.bot,
|
||||
model_max_length=self.tokenizer.model_max_length,
|
||||
)
|
||||
generate_kwargs["stopping_criteria"] = stopping_criteria
|
||||
# return super()._forward(model_inputs, **generate_kwargs)
|
||||
return self.__forward(model_inputs, **generate_kwargs)
|
||||
|
||||
# FIXME: Copy-paste of original _forward, but removed copy.deepcopy()
|
||||
# FIXME: https://github.com/h2oai/h2ogpt/issues/172
|
||||
def __forward(self, model_inputs, **generate_kwargs):
|
||||
input_ids = model_inputs["input_ids"]
|
||||
attention_mask = model_inputs.get("attention_mask", None)
|
||||
# Allow empty prompts
|
||||
if input_ids.shape[1] == 0:
|
||||
input_ids = None
|
||||
attention_mask = None
|
||||
in_b = 1
|
||||
else:
|
||||
in_b = input_ids.shape[0]
|
||||
prompt_text = model_inputs.pop("prompt_text")
|
||||
|
||||
## If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
|
||||
## generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
|
||||
# generate_kwargs = copy.deepcopy(generate_kwargs)
|
||||
prefix_length = generate_kwargs.pop("prefix_length", 0)
|
||||
if prefix_length > 0:
|
||||
has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
|
||||
"generation_config" in generate_kwargs
|
||||
and generate_kwargs["generation_config"].max_new_tokens
|
||||
is not None
|
||||
)
|
||||
if not has_max_new_tokens:
|
||||
generate_kwargs["max_length"] = (
|
||||
generate_kwargs.get("max_length")
|
||||
or self.model.config.max_length
|
||||
)
|
||||
generate_kwargs["max_length"] += prefix_length
|
||||
has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
|
||||
"generation_config" in generate_kwargs
|
||||
and generate_kwargs["generation_config"].min_new_tokens
|
||||
is not None
|
||||
)
|
||||
if not has_min_new_tokens and "min_length" in generate_kwargs:
|
||||
generate_kwargs["min_length"] += prefix_length
|
||||
|
||||
# BS x SL
|
||||
# pad or truncate the input_ids and attention_mask
|
||||
max_padding_length = 400
|
||||
input_ids, attention_mask = pad_or_truncate_inputs(
|
||||
input_ids, attention_mask, max_padding_length=max_padding_length
|
||||
)
|
||||
|
||||
return_dict = {
|
||||
"model": self.model,
|
||||
"tokenizer": self.tokenizer,
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
return_dict = {**return_dict, **generate_kwargs}
|
||||
return return_dict
|
||||
@@ -1,247 +0,0 @@
|
||||
"""
|
||||
Based upon ImageCaptionLoader in LangChain version: langchain/document_loaders/image_captions.py
|
||||
But accepts preloaded model to avoid slowness in use and CUDA forking issues
|
||||
|
||||
Loader that loads image captions
|
||||
By default, the loader utilizes the pre-trained BLIP image captioning model.
|
||||
https://huggingface.co/Salesforce/blip-image-captioning-base
|
||||
|
||||
"""
|
||||
from typing import List, Union, Any, Tuple
|
||||
|
||||
import requests
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders import ImageCaptionLoader
|
||||
|
||||
from utils import get_device, NullContext
|
||||
|
||||
import pkg_resources
|
||||
|
||||
try:
|
||||
assert pkg_resources.get_distribution("bitsandbytes") is not None
|
||||
have_bitsandbytes = True
|
||||
except (pkg_resources.DistributionNotFound, AssertionError):
|
||||
have_bitsandbytes = False
|
||||
|
||||
|
||||
class H2OImageCaptionLoader(ImageCaptionLoader):
|
||||
"""Loader that loads the captions of an image"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path_images: Union[str, List[str]] = None,
|
||||
blip_processor: str = None,
|
||||
blip_model: str = None,
|
||||
caption_gpu=True,
|
||||
load_in_8bit=True,
|
||||
# True doesn't seem to work, even though https://huggingface.co/Salesforce/blip2-flan-t5-xxl#in-8-bit-precision-int8
|
||||
load_half=False,
|
||||
load_gptq="",
|
||||
use_safetensors=False,
|
||||
min_new_tokens=20,
|
||||
max_tokens=50,
|
||||
):
|
||||
if blip_processor is None or blip_model is None:
|
||||
blip_processor = "Salesforce/blip-image-captioning-base"
|
||||
blip_model = "Salesforce/blip-image-captioning-base"
|
||||
|
||||
super().__init__(path_images, blip_processor, blip_model)
|
||||
self.blip_processor = blip_processor
|
||||
self.blip_model = blip_model
|
||||
self.processor = None
|
||||
self.model = None
|
||||
self.caption_gpu = caption_gpu
|
||||
self.context_class = NullContext
|
||||
self.device = "cpu"
|
||||
self.load_in_8bit = (
|
||||
load_in_8bit and have_bitsandbytes
|
||||
) # only for blip2
|
||||
self.load_half = load_half
|
||||
self.load_gptq = load_gptq
|
||||
self.use_safetensors = use_safetensors
|
||||
self.gpu_id = "auto"
|
||||
# default prompt
|
||||
self.prompt = "image of"
|
||||
self.min_new_tokens = min_new_tokens
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
def set_context(self):
|
||||
if get_device() == "cuda" and self.caption_gpu:
|
||||
import torch
|
||||
|
||||
n_gpus = (
|
||||
torch.cuda.device_count() if torch.cuda.is_available() else 0
|
||||
)
|
||||
if n_gpus > 0:
|
||||
self.context_class = torch.device
|
||||
self.device = "cuda"
|
||||
|
||||
def load_model(self):
|
||||
try:
|
||||
import transformers
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"`transformers` package not found, please install with "
|
||||
"`pip install transformers`."
|
||||
)
|
||||
self.set_context()
|
||||
if self.caption_gpu:
|
||||
if self.gpu_id == "auto":
|
||||
# blip2 has issues with multi-GPU. Error says need to somehow set language model in device map
|
||||
# device_map = 'auto'
|
||||
device_map = {"": 0}
|
||||
else:
|
||||
if self.device == "cuda":
|
||||
device_map = {"": self.gpu_id}
|
||||
else:
|
||||
device_map = {"": "cpu"}
|
||||
else:
|
||||
device_map = {"": "cpu"}
|
||||
import torch
|
||||
|
||||
with torch.no_grad():
|
||||
with self.context_class(self.device):
|
||||
context_class_cast = (
|
||||
NullContext if self.device == "cpu" else torch.autocast
|
||||
)
|
||||
with context_class_cast(self.device):
|
||||
if "blip2" in self.blip_processor.lower():
|
||||
from transformers import (
|
||||
Blip2Processor,
|
||||
Blip2ForConditionalGeneration,
|
||||
)
|
||||
|
||||
if self.load_half and not self.load_in_8bit:
|
||||
self.processor = Blip2Processor.from_pretrained(
|
||||
self.blip_processor, device_map=device_map
|
||||
).half()
|
||||
self.model = (
|
||||
Blip2ForConditionalGeneration.from_pretrained(
|
||||
self.blip_model, device_map=device_map
|
||||
).half()
|
||||
)
|
||||
else:
|
||||
self.processor = Blip2Processor.from_pretrained(
|
||||
self.blip_processor,
|
||||
load_in_8bit=self.load_in_8bit,
|
||||
device_map=device_map,
|
||||
)
|
||||
self.model = (
|
||||
Blip2ForConditionalGeneration.from_pretrained(
|
||||
self.blip_model,
|
||||
load_in_8bit=self.load_in_8bit,
|
||||
device_map=device_map,
|
||||
)
|
||||
)
|
||||
else:
|
||||
from transformers import (
|
||||
BlipForConditionalGeneration,
|
||||
BlipProcessor,
|
||||
)
|
||||
|
||||
self.load_half = False # not supported
|
||||
if self.caption_gpu:
|
||||
if device_map == "auto":
|
||||
# Blip doesn't support device_map='auto'
|
||||
if self.device == "cuda":
|
||||
if self.gpu_id == "auto":
|
||||
device_map = {"": 0}
|
||||
else:
|
||||
device_map = {"": self.gpu_id}
|
||||
else:
|
||||
device_map = {"": "cpu"}
|
||||
else:
|
||||
device_map = {"": "cpu"}
|
||||
self.processor = BlipProcessor.from_pretrained(
|
||||
self.blip_processor, device_map=device_map
|
||||
)
|
||||
self.model = (
|
||||
BlipForConditionalGeneration.from_pretrained(
|
||||
self.blip_model, device_map=device_map
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
def set_image_paths(self, path_images: Union[str, List[str]]):
|
||||
"""
|
||||
Load from a list of image files
|
||||
"""
|
||||
if isinstance(path_images, str):
|
||||
self.image_paths = [path_images]
|
||||
else:
|
||||
self.image_paths = path_images
|
||||
|
||||
def load(self, prompt=None) -> List[Document]:
|
||||
if self.processor is None or self.model is None:
|
||||
self.load_model()
|
||||
results = []
|
||||
for path_image in self.image_paths:
|
||||
caption, metadata = self._get_captions_and_metadata(
|
||||
model=self.model,
|
||||
processor=self.processor,
|
||||
path_image=path_image,
|
||||
prompt=prompt,
|
||||
)
|
||||
doc = Document(page_content=caption, metadata=metadata)
|
||||
results.append(doc)
|
||||
|
||||
return results
|
||||
|
||||
def _get_captions_and_metadata(
|
||||
self, model: Any, processor: Any, path_image: str, prompt=None
|
||||
) -> Tuple[str, dict]:
|
||||
"""
|
||||
Helper function for getting the captions and metadata of an image
|
||||
"""
|
||||
if prompt is None:
|
||||
prompt = self.prompt
|
||||
try:
|
||||
from PIL import Image
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"`PIL` package not found, please install with `pip install pillow`"
|
||||
)
|
||||
|
||||
try:
|
||||
if path_image.startswith("http://") or path_image.startswith(
|
||||
"https://"
|
||||
):
|
||||
image = Image.open(
|
||||
requests.get(path_image, stream=True).raw
|
||||
).convert("RGB")
|
||||
else:
|
||||
image = Image.open(path_image).convert("RGB")
|
||||
except Exception:
|
||||
raise ValueError(f"Could not get image data for {path_image}")
|
||||
|
||||
import torch
|
||||
|
||||
with torch.no_grad():
|
||||
with self.context_class(self.device):
|
||||
context_class_cast = (
|
||||
NullContext if self.device == "cpu" else torch.autocast
|
||||
)
|
||||
with context_class_cast(self.device):
|
||||
if self.load_half:
|
||||
inputs = processor(
|
||||
image, prompt, return_tensors="pt"
|
||||
).half()
|
||||
else:
|
||||
inputs = processor(image, prompt, return_tensors="pt")
|
||||
min_length = len(prompt) // 4 + self.min_new_tokens
|
||||
self.max_tokens = max(self.max_tokens, min_length)
|
||||
output = model.generate(
|
||||
**inputs,
|
||||
min_length=min_length,
|
||||
max_length=self.max_tokens,
|
||||
)
|
||||
|
||||
caption: str = processor.decode(
|
||||
output[0], skip_special_tokens=True
|
||||
)
|
||||
prompti = caption.find(prompt)
|
||||
if prompti >= 0:
|
||||
caption = caption[prompti + len(prompt) :]
|
||||
metadata: dict = {"image_path": path_image}
|
||||
|
||||
return caption, metadata
|
||||
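A usage sketch for the loader above; the image path is a placeholder and the example is not part of the original file:

```python
# Sketch: caption a local image with the preloaded BLIP model on CPU.
loader = H2OImageCaptionLoader(caption_gpu=False).load_model()
loader.set_image_paths(["example.jpg"])  # placeholder path
docs = loader.load(prompt="image of")
print(docs[0].page_content, docs[0].metadata)
```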
@@ -1,120 +0,0 @@
|
||||
# for generate (gradio server) and finetune
|
||||
datasets==2.13.0
|
||||
sentencepiece==0.1.99
|
||||
huggingface_hub==0.16.4
|
||||
appdirs==1.4.4
|
||||
fire==0.5.0
|
||||
docutils==0.20.1
|
||||
evaluate==0.4.0
|
||||
rouge_score==0.1.2
|
||||
sacrebleu==2.3.1
|
||||
scikit-learn==1.2.2
|
||||
alt-profanity-check==1.2.2
|
||||
better-profanity==0.7.0
|
||||
numpy==1.24.3
|
||||
pandas==2.0.2
|
||||
matplotlib==3.7.1
|
||||
loralib==0.1.1
|
||||
bitsandbytes==0.39.0
|
||||
accelerate==0.20.3
|
||||
peft==0.4.0
|
||||
# 4.31.0+ breaks load_in_8bit=True (https://github.com/huggingface/transformers/issues/25026)
|
||||
transformers==4.30.2
|
||||
tokenizers==0.13.3
|
||||
APScheduler==3.10.1
|
||||
|
||||
# optional for generate
|
||||
pynvml==11.5.0
|
||||
psutil==5.9.5
|
||||
boto3==1.26.101
|
||||
botocore==1.29.101
|
||||
|
||||
# optional for finetune
|
||||
tensorboard==2.13.0
|
||||
neptune==1.2.0
|
||||
|
||||
# for gradio client
|
||||
gradio_client==0.2.10
|
||||
beautifulsoup4==4.12.2
|
||||
markdown==3.4.3
|
||||
|
||||
# data and testing
|
||||
pytest==7.2.2
|
||||
pytest-xdist==3.2.1
|
||||
nltk==3.8.1
|
||||
textstat==0.7.3
|
||||
# pandoc==2.3
|
||||
pypandoc==1.11; sys_platform == "darwin" and platform_machine == "arm64"
|
||||
pypandoc_binary==1.11; platform_machine == "x86_64"
|
||||
pypandoc_binary==1.11; sys_platform == "win32"
|
||||
openpyxl==3.1.2
|
||||
lm_dataformat==0.0.20
|
||||
bioc==2.0
|
||||
|
||||
# falcon
|
||||
einops==0.6.1
|
||||
instructorembedding==1.0.1
|
||||
|
||||
# for gpt4all .env file, but avoid worrying about imports
|
||||
python-dotenv==1.0.0
|
||||
|
||||
text-generation==0.6.0
|
||||
# for tokenization when don't have HF tokenizer
|
||||
tiktoken==0.4.0
|
||||
# optional: for OpenAI endpoint or embeddings (requires key)
|
||||
openai==0.27.8
|
||||
|
||||
# optional for chat with PDF
|
||||
langchain==0.0.329
|
||||
pypdf==3.17.0
|
||||
# avoid textract, requires old six
|
||||
#textract==1.6.5
|
||||
|
||||
# for HF embeddings
|
||||
sentence_transformers==2.2.2
|
||||
|
||||
# local vector db
|
||||
chromadb==0.3.25
|
||||
# server vector db
|
||||
#pymilvus==2.2.8
|
||||
|
||||
# weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
|
||||
# unstructured==0.8.1
|
||||
|
||||
# strong support for images
|
||||
# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
|
||||
unstructured[local-inference]==0.7.4
|
||||
#pdf2image==1.16.3
|
||||
#pytesseract==0.3.10
|
||||
pillow
|
||||
|
||||
pdfminer.six==20221105
|
||||
urllib3
|
||||
requests_file
|
||||
|
||||
#pdf2image==1.16.3
|
||||
#pytesseract==0.3.10
|
||||
tabulate==0.9.0
|
||||
# FYI pandoc already part of requirements.txt
|
||||
|
||||
# JSONLoader, but makes some trouble for some users
|
||||
# jq==1.4.1
|
||||
|
||||
# to check licenses
|
||||
# Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
|
||||
pip-licenses==4.3.0
|
||||
|
||||
# weaviate vector db
|
||||
weaviate-client==3.22.1
|
||||
|
||||
gpt4all==1.0.5
|
||||
llama-cpp-python==0.1.73
|
||||
|
||||
arxiv==1.4.8
|
||||
pymupdf==1.22.5 # AGPL license
|
||||
# extract-msg==0.41.1 # GPL3
|
||||
|
||||
# sometimes unstructured fails, these work in those cases. See https://github.com/h2oai/h2ogpt/issues/320
|
||||
playwright==1.36.0
|
||||
# requires Chrome binary to be in path
|
||||
selenium==4.10.0
|
||||
@@ -1,124 +0,0 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import transformers
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
||||
from flash_attn.bert_padding import unpad_input, pad_input
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[
|
||||
torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]
|
||||
]:
|
||||
"""Input shape: Batch x Time x Channel
|
||||
attention_mask: [bsz, q_len]
|
||||
"""
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = (
|
||||
self.q_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
key_states = (
|
||||
self.k_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
value_states = (
|
||||
self.v_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
# [bsz, q_len, nh, hd]
|
||||
# [bsz, nh, q_len, hd]
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
assert past_key_value is None, "past_key_value is not supported"
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
# [bsz, nh, t, hd]
|
||||
assert not output_attentions, "output_attentions is not supported"
|
||||
assert not use_cache, "use_cache is not supported"
|
||||
|
||||
# Flash attention codes from
|
||||
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
|
||||
|
||||
# transform the data into the format required by flash attention
|
||||
qkv = torch.stack(
|
||||
[query_states, key_states, value_states], dim=2
|
||||
) # [bsz, nh, 3, q_len, hd]
|
||||
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||
# the attention_mask should be the same as the key_padding_mask
|
||||
key_padding_mask = attention_mask
|
||||
|
||||
if key_padding_mask is None:
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
max_s = q_len
|
||||
cu_q_lens = torch.arange(
|
||||
0,
|
||||
(bsz + 1) * q_len,
|
||||
step=q_len,
|
||||
dtype=torch.int32,
|
||||
device=qkv.device,
|
||||
)
|
||||
output = flash_attn_unpadded_qkvpacked_func(
|
||||
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
else:
|
||||
nheads = qkv.shape[-2]
|
||||
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
||||
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
||||
x_unpad = rearrange(
|
||||
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
|
||||
)
|
||||
output_unpad = flash_attn_unpadded_qkvpacked_func(
|
||||
x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||
)
|
||||
output = rearrange(
|
||||
pad_input(
|
||||
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
|
||||
indices,
|
||||
bsz,
|
||||
q_len,
|
||||
),
|
||||
"b s (h d) -> b s h d",
|
||||
h=nheads,
|
||||
)
|
||||
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    # [bsz, seq_len]
    return attention_mask


def replace_llama_attn_with_flash_attn():
    print(
        "Replacing original LLaMa attention with flash attention", flush=True
    )
    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask
    )
    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
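A sketch of how this monkey-patch would be applied; the checkpoint name is a placeholder, not something the original file specifies:

```python
# Sketch: patch LLaMA attention before the model is instantiated so that
# LlamaAttention.forward picks up the flash-attention implementation above.
from transformers import LlamaForCausalLM

replace_llama_attn_with_flash_attn()
model = LlamaForCausalLM.from_pretrained("huggyllama/llama-7b")  # placeholder checkpoint
```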
@@ -1,109 +0,0 @@
|
||||
import functools
|
||||
|
||||
|
||||
def get_loaders(model_name, reward_type, llama_type=None, load_gptq=""):
|
||||
# NOTE: Some models need specific new prompt_type
|
||||
# E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
|
||||
if load_gptq:
|
||||
from transformers import AutoTokenizer
|
||||
from auto_gptq import AutoGPTQForCausalLM
|
||||
|
||||
use_triton = False
|
||||
functools.partial(
|
||||
AutoGPTQForCausalLM.from_quantized,
|
||||
quantize_config=None,
|
||||
use_triton=use_triton,
|
||||
)
|
||||
return AutoGPTQForCausalLM.from_quantized, AutoTokenizer
|
||||
if llama_type is None:
|
||||
llama_type = "llama" in model_name.lower()
|
||||
if llama_type:
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
return LlamaForCausalLM.from_pretrained, LlamaTokenizer
|
||||
elif "distilgpt2" in model_name.lower():
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
return AutoModelForCausalLM.from_pretrained, AutoTokenizer
|
||||
elif "gpt2" in model_name.lower():
|
||||
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
||||
|
||||
return GPT2LMHeadModel.from_pretrained, GPT2Tokenizer
|
||||
elif "mbart-" in model_name.lower():
|
||||
from transformers import (
|
||||
MBartForConditionalGeneration,
|
||||
MBart50TokenizerFast,
|
||||
)
|
||||
|
||||
return (
|
||||
MBartForConditionalGeneration.from_pretrained,
|
||||
MBart50TokenizerFast,
|
||||
)
|
||||
elif (
|
||||
"t5" == model_name.lower()
|
||||
or "t5-" in model_name.lower()
|
||||
or "flan-" in model_name.lower()
|
||||
):
|
||||
from transformers import AutoTokenizer, T5ForConditionalGeneration
|
||||
|
||||
return T5ForConditionalGeneration.from_pretrained, AutoTokenizer
|
||||
elif "bigbird" in model_name:
|
||||
from transformers import (
|
||||
BigBirdPegasusForConditionalGeneration,
|
||||
AutoTokenizer,
|
||||
)
|
||||
|
||||
return (
|
||||
BigBirdPegasusForConditionalGeneration.from_pretrained,
|
||||
AutoTokenizer,
|
||||
)
|
||||
elif (
|
||||
"bart-large-cnn-samsum" in model_name
|
||||
or "flan-t5-base-samsum" in model_name
|
||||
):
|
||||
from transformers import pipeline
|
||||
|
||||
return pipeline, "summarization"
|
||||
elif (
|
||||
reward_type
|
||||
or "OpenAssistant/reward-model".lower() in model_name.lower()
|
||||
):
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
)
|
||||
|
||||
return (
|
||||
AutoModelForSequenceClassification.from_pretrained,
|
||||
AutoTokenizer,
|
||||
)
|
||||
else:
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
model_loader = AutoModelForCausalLM
|
||||
tokenizer_loader = AutoTokenizer
|
||||
return model_loader.from_pretrained, tokenizer_loader
|
||||
|
||||
|
||||
def get_tokenizer(
    tokenizer_loader,
    tokenizer_base_model,
    local_files_only,
    resume_download,
    use_auth_token,
):
    tokenizer = tokenizer_loader.from_pretrained(
        tokenizer_base_model,
        local_files_only=local_files_only,
        resume_download=resume_download,
        use_auth_token=use_auth_token,
        padding_side="left",
    )

    tokenizer.pad_token_id = 0  # different from the eos token
    # when generating, we will use the logits of right-most token to predict the next token
    # so the padding should be on the left,
    # e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
    tokenizer.padding_side = "left"  # Allow batched inference

    return tokenizer
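Putting the two helpers together (a sketch; the model name is a placeholder and most kwargs are abbreviated):

```python
# Sketch: resolve loader callables for a model, then build a left-padded tokenizer.
model_loader, tokenizer_loader = get_loaders(
    model_name="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",  # placeholder
    reward_type=False,
)
tokenizer = get_tokenizer(
    tokenizer_loader,
    tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",  # placeholder
    local_files_only=False,
    resume_download=True,
    use_auth_token=False,
)
model = model_loader(
    "h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3", trust_remote_code=True
)
```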
@@ -1,203 +0,0 @@
|
||||
import os
|
||||
|
||||
from gpt_langchain import (
|
||||
path_to_docs,
|
||||
get_some_dbs_from_hf,
|
||||
all_db_zips,
|
||||
some_db_zips,
|
||||
create_or_update_db,
|
||||
)
|
||||
from utils import get_ngpus_vis
|
||||
|
||||
|
||||
def glob_to_db(
|
||||
user_path,
|
||||
chunk=True,
|
||||
chunk_size=512,
|
||||
verbose=False,
|
||||
fail_any_exception=False,
|
||||
n_jobs=-1,
|
||||
url=None,
|
||||
enable_captions=True,
|
||||
captions_model=None,
|
||||
caption_loader=None,
|
||||
enable_ocr=False,
|
||||
):
|
||||
sources1 = path_to_docs(
|
||||
user_path,
|
||||
verbose=verbose,
|
||||
fail_any_exception=fail_any_exception,
|
||||
n_jobs=n_jobs,
|
||||
chunk=chunk,
|
||||
chunk_size=chunk_size,
|
||||
url=url,
|
||||
enable_captions=enable_captions,
|
||||
captions_model=captions_model,
|
||||
caption_loader=caption_loader,
|
||||
enable_ocr=enable_ocr,
|
||||
)
|
||||
return sources1
|
||||
|
||||
|
||||
def make_db_main(
|
||||
use_openai_embedding: bool = False,
|
||||
hf_embedding_model: str = None,
|
||||
persist_directory: str = "db_dir_UserData",
|
||||
user_path: str = "user_path",
|
||||
url: str = None,
|
||||
add_if_exists: bool = True,
|
||||
collection_name: str = "UserData",
|
||||
verbose: bool = False,
|
||||
chunk: bool = True,
|
||||
chunk_size: int = 512,
|
||||
fail_any_exception: bool = False,
|
||||
download_all: bool = False,
|
||||
download_some: bool = False,
|
||||
download_one: str = None,
|
||||
download_dest: str = "./",
|
||||
n_jobs: int = -1,
|
||||
enable_captions: bool = True,
|
||||
captions_model: str = "Salesforce/blip-image-captioning-base",
|
||||
pre_load_caption_model: bool = False,
|
||||
caption_gpu: bool = True,
|
||||
enable_ocr: bool = False,
|
||||
db_type: str = "chroma",
|
||||
):
|
||||
"""
|
||||
# To make UserData db for generate.py, put pdfs, etc. into path user_path and run:
|
||||
python make_db.py
|
||||
|
||||
# once db is made, can use in generate.py like:
|
||||
|
||||
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b --langchain_mode=UserData
|
||||
|
||||
or zip-up the db_dir_UserData and share:
|
||||
|
||||
zip -r db_dir_UserData.zip db_dir_UserData
|
||||
|
||||
# To get all db files (except large wiki_full) do:
|
||||
python make_db.py --download_some=True
|
||||
|
||||
# To get a single db file from HF:
|
||||
python make_db.py --download_one=db_dir_DriverlessAI_docs.zip
|
||||
|
||||
:param use_openai_embedding: Whether to use OpenAI embedding
|
||||
:param hf_embedding_model: HF embedding model to use. Like generate.py, uses 'hkunlp/instructor-large' if have GPUs, else "sentence-transformers/all-MiniLM-L6-v2"
|
||||
:param persist_directory: where to persist db
|
||||
:param user_path: where to pull documents from (None means url is not None. If url is not None, this is ignored.)
|
||||
:param url: url to generate documents from (None means user_path is not None)
|
||||
:param add_if_exists: Add to db if already exists, but will not add duplicate sources
|
||||
:param collection_name: Collection name for new db if not adding
|
||||
:param verbose: whether to show verbose messages
|
||||
:param chunk: whether to chunk data
|
||||
:param chunk_size: chunk size for chunking
|
||||
:param fail_any_exception: whether to fail if any exception hit during ingestion of files
|
||||
:param download_all: whether to download all (including 23GB Wikipedia) example databases from h2o.ai HF
|
||||
:param download_some: whether to download some small example databases from h2o.ai HF
|
||||
:param download_one: whether to download one chosen example databases from h2o.ai HF
|
||||
:param download_dest: Destination for downloads
|
||||
:param n_jobs: Number of cores to use for ingesting multiple files
|
||||
:param enable_captions: Whether to enable captions on images
|
||||
:param captions_model: See generate.py
|
||||
:param pre_load_caption_model: See generate.py
|
||||
:param caption_gpu: Caption images on GPU if present
|
||||
:param enable_ocr: Whether to enable OCR on images
|
||||
:param db_type: Type of db to create. Currently only 'chroma' and 'weaviate' is supported.
|
||||
:return: None
|
||||
"""
|
||||
db = None
|
||||
|
||||
# match behavior of main() in generate.py for non-HF case
|
||||
n_gpus = get_ngpus_vis()
|
||||
if n_gpus == 0:
|
||||
if hf_embedding_model is None:
|
||||
# if no GPUs, use simpler embedding model to avoid cost in time
|
||||
hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
else:
|
||||
if hf_embedding_model is None:
|
||||
# if still None, then set default
|
||||
hf_embedding_model = "hkunlp/instructor-large"
|
||||
|
||||
if download_all:
|
||||
print("Downloading all (and unzipping): %s" % all_db_zips, flush=True)
|
||||
get_some_dbs_from_hf(download_dest, db_zips=all_db_zips)
|
||||
if verbose:
|
||||
print("DONE", flush=True)
|
||||
return db, collection_name
|
||||
elif download_some:
|
||||
print(
|
||||
"Downloading some (and unzipping): %s" % some_db_zips, flush=True
|
||||
)
|
||||
get_some_dbs_from_hf(download_dest, db_zips=some_db_zips)
|
||||
if verbose:
|
||||
print("DONE", flush=True)
|
||||
return db, collection_name
|
||||
elif download_one:
|
||||
print("Downloading %s (and unzipping)" % download_one, flush=True)
|
||||
get_some_dbs_from_hf(
|
||||
download_dest, db_zips=[[download_one, "", "Unknown License"]]
|
||||
)
|
||||
if verbose:
|
||||
print("DONE", flush=True)
|
||||
return db, collection_name
|
||||
|
||||
if enable_captions and pre_load_caption_model:
|
||||
# preload, else can be too slow or if on GPU have cuda context issues
|
||||
# Inside ingestion, this will disable parallel loading of multiple other kinds of docs
|
||||
# However, if have many images, all those images will be handled more quickly by preloaded model on GPU
|
||||
from image_captions import H2OImageCaptionLoader
|
||||
|
||||
caption_loader = H2OImageCaptionLoader(
|
||||
None,
|
||||
blip_model=captions_model,
|
||||
blip_processor=captions_model,
|
||||
caption_gpu=caption_gpu,
|
||||
).load_model()
|
||||
else:
|
||||
if enable_captions:
|
||||
caption_loader = "gpu" if caption_gpu else "cpu"
|
||||
else:
|
||||
caption_loader = False
|
||||
|
||||
if verbose:
|
||||
print("Getting sources", flush=True)
|
||||
assert (
|
||||
user_path is not None or url is not None
|
||||
), "Can't have both user_path and url as None"
|
||||
if not url:
|
||||
assert os.path.isdir(user_path), (
|
||||
"user_path=%s does not exist" % user_path
|
||||
)
|
||||
sources = glob_to_db(
|
||||
user_path,
|
||||
chunk=chunk,
|
||||
chunk_size=chunk_size,
|
||||
verbose=verbose,
|
||||
fail_any_exception=fail_any_exception,
|
||||
n_jobs=n_jobs,
|
||||
url=url,
|
||||
enable_captions=enable_captions,
|
||||
captions_model=captions_model,
|
||||
caption_loader=caption_loader,
|
||||
enable_ocr=enable_ocr,
|
||||
)
|
||||
exceptions = [x for x in sources if x.metadata.get("exception")]
|
||||
print("Exceptions: %s" % exceptions, flush=True)
|
||||
sources = [x for x in sources if "exception" not in x.metadata]
|
||||
|
||||
assert len(sources) > 0, "No sources found"
|
||||
db = create_or_update_db(
|
||||
db_type,
|
||||
persist_directory,
|
||||
collection_name,
|
||||
sources,
|
||||
use_openai_embedding,
|
||||
add_if_exists,
|
||||
verbose,
|
||||
hf_embedding_model,
|
||||
)
|
||||
|
||||
assert db is not None
|
||||
if verbose:
|
||||
print("DONE", flush=True)
|
||||
return db, collection_name
|
||||
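Besides the CLI shown in the docstring, the same entry point can be called from Python (a sketch; the directory names are the defaults above):

```python
# Sketch: ingest everything under ./user_path into a local Chroma collection.
db, collection_name = make_db_main(
    user_path="user_path",
    persist_directory="db_dir_UserData",
    collection_name="UserData",
    verbose=True,
)
```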
File diff suppressed because it is too large
@@ -1,403 +0,0 @@
|
||||
"""Load Data from a MediaWiki dump xml."""
|
||||
import ast
|
||||
import glob
|
||||
import pickle
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
import os
|
||||
import bz2
|
||||
import csv
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders import MWDumpLoader
|
||||
|
||||
# path where downloaded wiki files exist, to be processed
|
||||
root_path = "/data/jon/h2o-llm"
|
||||
|
||||
|
||||
def unescape(x):
|
||||
try:
|
||||
x = ast.literal_eval(x)
|
||||
except:
|
||||
try:
|
||||
x = x.encode("ascii", "ignore").decode("unicode_escape")
|
||||
except:
|
||||
pass
|
||||
return x
|
||||
|
||||
|
||||
def get_views():
|
||||
# views = pd.read_csv('wiki_page_views_more_1000month.csv')
|
||||
views = pd.read_csv("wiki_page_views_more_5000month.csv")
|
||||
views.index = views["title"]
|
||||
views = views["views"]
|
||||
views = views.to_dict()
|
||||
views = {str(unescape(str(k))): v for k, v in views.items()}
|
||||
views2 = {k.replace("_", " "): v for k, v in views.items()}
|
||||
# views has _ but pages has " "
|
||||
views.update(views2)
|
||||
return views
|
||||
|
||||
|
||||
class MWDumpDirectLoader(MWDumpLoader):
|
||||
def __init__(
|
||||
self,
|
||||
data: str,
|
||||
encoding: Optional[str] = "utf8",
|
||||
title_words_limit=None,
|
||||
use_views=True,
|
||||
verbose=True,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self.data = data
|
||||
self.encoding = encoding
|
||||
self.title_words_limit = title_words_limit
|
||||
self.verbose = verbose
|
||||
if use_views:
|
||||
# self.views = get_views()
|
||||
# faster to use global shared values
|
||||
self.views = global_views
|
||||
else:
|
||||
self.views = None
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""Load from file path."""
|
||||
import mwparserfromhell
|
||||
import mwxml
|
||||
|
||||
dump = mwxml.Dump.from_page_xml(self.data)
|
||||
|
||||
docs = []
|
||||
|
||||
for page in dump.pages:
|
||||
if self.views is not None and page.title not in self.views:
|
||||
if self.verbose:
|
||||
print("Skipped %s low views" % page.title, flush=True)
|
||||
continue
|
||||
for revision in page:
|
||||
if self.title_words_limit is not None:
|
||||
num_words = len(" ".join(page.title.split("_")).split(" "))
|
||||
if num_words > self.title_words_limit:
|
||||
if self.verbose:
|
||||
print("Skipped %s" % page.title, flush=True)
|
||||
continue
|
||||
if self.verbose:
|
||||
if self.views is not None:
|
||||
print(
|
||||
"Kept %s views: %s"
|
||||
% (page.title, self.views[page.title]),
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
print("Kept %s" % page.title, flush=True)
|
||||
|
||||
code = mwparserfromhell.parse(revision.text)
|
||||
text = code.strip_code(
|
||||
normalize=True, collapse=True, keep_template_params=False
|
||||
)
|
||||
title_url = str(page.title).replace(" ", "_")
|
||||
metadata = dict(
|
||||
title=page.title,
|
||||
source="https://en.wikipedia.org/wiki/" + title_url,
|
||||
id=page.id,
|
||||
redirect=page.redirect,
|
||||
views=self.views[page.title]
|
||||
if self.views is not None
|
||||
else -1,
|
||||
)
|
||||
metadata = {k: v for k, v in metadata.items() if v is not None}
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
def search_index(search_term, index_filename):
|
||||
byte_flag = False
|
||||
data_length = start_byte = 0
|
||||
index_file = open(index_filename, "r")
|
||||
csv_reader = csv.reader(index_file, delimiter=":")
|
||||
for line in csv_reader:
|
||||
if not byte_flag and search_term == line[2]:
|
||||
start_byte = int(line[0])
|
||||
byte_flag = True
|
||||
elif byte_flag and int(line[0]) != start_byte:
|
||||
data_length = int(line[0]) - start_byte
|
||||
break
|
||||
index_file.close()
|
||||
return start_byte, data_length
|
||||
|
||||
|
||||
def get_start_bytes(index_filename):
|
||||
index_file = open(index_filename, "r")
|
||||
csv_reader = csv.reader(index_file, delimiter=":")
|
||||
start_bytes = set()
|
||||
for line in csv_reader:
|
||||
start_bytes.add(int(line[0]))
|
||||
index_file.close()
|
||||
return sorted(start_bytes)
|
||||
|
||||
|
||||
def get_wiki_filenames():
|
||||
# requires
|
||||
# wget http://ftp.acc.umu.se/mirror/wikimedia.org/dumps/enwiki/20230401/enwiki-20230401-pages-articles-multistream-index.txt.bz2
|
||||
base_path = os.path.join(
|
||||
root_path, "enwiki-20230401-pages-articles-multistream"
|
||||
)
|
||||
index_file = "enwiki-20230401-pages-articles-multistream-index.txt"
|
||||
index_filename = os.path.join(base_path, index_file)
|
||||
wiki_filename = os.path.join(
|
||||
base_path, "enwiki-20230401-pages-articles-multistream.xml.bz2"
|
||||
)
|
||||
return index_filename, wiki_filename
|
||||
|
||||
|
||||
def get_documents_by_search_term(search_term):
|
||||
index_filename, wiki_filename = get_wiki_filenames()
|
||||
start_byte, data_length = search_index(search_term, index_filename)
|
||||
with open(wiki_filename, "rb") as wiki_file:
|
||||
wiki_file.seek(start_byte)
|
||||
data = bz2.BZ2Decompressor().decompress(wiki_file.read(data_length))
|
||||
|
||||
loader = MWDumpDirectLoader(data.decode())
|
||||
documents = loader.load()
|
||||
return documents
|
||||
|
||||
|
||||
def get_one_chunk(
|
||||
wiki_filename,
|
||||
start_byte,
|
||||
end_byte,
|
||||
return_file=True,
|
||||
title_words_limit=None,
|
||||
use_views=True,
|
||||
):
|
||||
data_length = end_byte - start_byte
|
||||
with open(wiki_filename, "rb") as wiki_file:
|
||||
wiki_file.seek(start_byte)
|
||||
data = bz2.BZ2Decompressor().decompress(wiki_file.read(data_length))
|
||||
|
||||
loader = MWDumpDirectLoader(
|
||||
data.decode(), title_words_limit=title_words_limit, use_views=use_views
|
||||
)
|
||||
documents1 = loader.load()
|
||||
if return_file:
|
||||
base_tmp = "temp_wiki"
|
||||
if not os.path.isdir(base_tmp):
|
||||
os.makedirs(base_tmp, exist_ok=True)
|
||||
filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.pickle")
|
||||
with open(filename, "wb") as f:
|
||||
pickle.dump(documents1, f)
|
||||
return filename
|
||||
return documents1
|
||||
|
||||
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
global_views = get_views()
|
||||
|
||||
|
||||
def get_all_documents(small_test=2, n_jobs=None, use_views=True):
|
||||
print("DO get all wiki docs: %s" % small_test, flush=True)
|
||||
index_filename, wiki_filename = get_wiki_filenames()
|
||||
start_bytes = get_start_bytes(index_filename)
|
||||
end_bytes = start_bytes[1:]
|
||||
start_bytes = start_bytes[:-1]
|
||||
|
||||
if small_test:
|
||||
start_bytes = start_bytes[:small_test]
|
||||
end_bytes = end_bytes[:small_test]
|
||||
if n_jobs is None:
|
||||
n_jobs = 5
|
||||
else:
|
||||
if n_jobs is None:
|
||||
n_jobs = os.cpu_count() // 4
|
||||
|
||||
# the default loky backend leads to namespace conflict problems
|
||||
return_file = True # large return from joblib hangs
|
||||
documents = Parallel(n_jobs=n_jobs, verbose=10, backend="multiprocessing")(
|
||||
delayed(get_one_chunk)(
|
||||
wiki_filename,
|
||||
start_byte,
|
||||
end_byte,
|
||||
return_file=return_file,
|
||||
use_views=use_views,
|
||||
)
|
||||
for start_byte, end_byte in zip(start_bytes, end_bytes)
|
||||
)
|
||||
if return_file:
|
||||
# in this case documents actually holds the temp file names
|
||||
files = documents.copy()
|
||||
documents = []
|
||||
for fil in files:
|
||||
with open(fil, "rb") as f:
|
||||
documents.extend(pickle.load(f))
|
||||
os.remove(fil)
|
||||
else:
|
||||
from functools import reduce
|
||||
from operator import concat
|
||||
|
||||
documents = reduce(concat, documents)
|
||||
assert isinstance(documents, list)
|
||||
|
||||
print("DONE get all wiki docs", flush=True)
|
||||
return documents
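# Illustrative, standalone sketch (not part of the original file) of the pattern
# used above: each worker pickles its result to disk and returns only the file
# name, because returning large objects directly from joblib workers can hang
# (see the return_file comment above). Names here are hypothetical.
def _demo_square_chunk(chunk):
    import os, pickle, tempfile

    result = [x * x for x in chunk]  # stand-in for heavy per-chunk work
    fd, path = tempfile.mkstemp(suffix=".tmp.pickle")
    with os.fdopen(fd, "wb") as f:
        pickle.dump(result, f)
    return path  # small return value, safe to send back through joblib


def _file_backed_parallel_demo():
    import os, pickle
    from joblib import Parallel, delayed

    paths = Parallel(n_jobs=2, backend="multiprocessing")(
        delayed(_demo_square_chunk)(list(range(i, i + 3))) for i in range(0, 9, 3)
    )
    merged = []
    for p in paths:
        with open(p, "rb") as f:
            merged.extend(pickle.load(f))
        os.remove(p)
    return merged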
|
||||
|
||||
|
||||
def test_by_search_term():
|
||||
search_term = "Apollo"
|
||||
assert len(get_documents_by_search_term(search_term)) == 100
|
||||
|
||||
search_term = "Abstract (law)"
|
||||
assert len(get_documents_by_search_term(search_term)) == 100
|
||||
|
||||
search_term = "Artificial languages"
|
||||
assert len(get_documents_by_search_term(search_term)) == 100
|
||||
|
||||
|
||||
def test_start_bytes():
|
||||
index_filename, wiki_filename = get_wiki_filenames()
|
||||
assert len(get_start_bytes(index_filename)) == 227850
|
||||
|
||||
|
||||
def test_get_all_documents():
|
||||
small_test = 20 # 227850
|
||||
n_jobs = os.cpu_count() // 4
|
||||
|
||||
assert (
|
||||
len(
|
||||
get_all_documents(
|
||||
small_test=small_test, n_jobs=n_jobs, use_views=False
|
||||
)
|
||||
)
|
||||
== small_test * 100
|
||||
)
|
||||
|
||||
assert (
|
||||
len(
|
||||
get_all_documents(
|
||||
small_test=small_test, n_jobs=n_jobs, use_views=True
|
||||
)
|
||||
)
|
||||
== 429
|
||||
)
|
||||
|
||||
|
||||
def get_one_pageviews(fil):
|
||||
df1 = pd.read_csv(
|
||||
fil,
|
||||
sep=" ",
|
||||
header=None,
|
||||
names=["region", "title", "views", "foo"],
|
||||
quoting=csv.QUOTE_NONE,
|
||||
)
|
||||
df1.index = df1["title"]
|
||||
df1 = df1[df1["region"] == "en"]
|
||||
df1 = df1.drop("region", axis=1)
|
||||
df1 = df1.drop("foo", axis=1)
|
||||
df1 = df1.drop("title", axis=1) # already index
|
||||
|
||||
base_tmp = "temp_wiki_pageviews"
|
||||
if not os.path.isdir(base_tmp):
|
||||
os.makedirs(base_tmp, exist_ok=True)
|
||||
filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.csv")
|
||||
df1.to_csv(filename, index=True)
|
||||
return filename
|
||||
|
||||
|
||||
def test_agg_pageviews(gen_files=False):
|
||||
if gen_files:
|
||||
path = os.path.join(
|
||||
root_path,
|
||||
"wiki_pageviews/dumps.wikimedia.org/other/pageviews/2023/2023-04",
|
||||
)
|
||||
files = glob.glob(os.path.join(path, "pageviews*.gz"))
|
||||
# files = files[:2] # test
|
||||
n_jobs = os.cpu_count() // 2
|
||||
csv_files = Parallel(
|
||||
n_jobs=n_jobs, verbose=10, backend="multiprocessing"
|
||||
)(delayed(get_one_pageviews)(fil) for fil in files)
|
||||
else:
|
||||
# to continue without redoing above
|
||||
csv_files = glob.glob(
|
||||
os.path.join(root_path, "temp_wiki_pageviews/*.csv")
|
||||
)
|
||||
|
||||
df_list = []
|
||||
for csv_file in csv_files:
|
||||
print(csv_file)
|
||||
df1 = pd.read_csv(csv_file)
|
||||
df_list.append(df1)
|
||||
df = pd.concat(df_list, axis=0)
|
||||
df = df.groupby("title")["views"].sum().reset_index()
|
||||
df.to_csv("wiki_page_views.csv", index=True)
|
||||
|
||||
|
||||
def test_reduce_pageview():
|
||||
filename = "wiki_page_views.csv"
|
||||
df = pd.read_csv(filename)
|
||||
df = df[df["views"] < 1e7]
|
||||
#
|
||||
plt.hist(df["views"], bins=100, log=True)
|
||||
views_avg = np.mean(df["views"])
|
||||
views_median = np.median(df["views"])
|
||||
plt.title("Views avg: %s median: %s" % (views_avg, views_median))
|
||||
plt.savefig(filename.replace(".csv", ".png"))
|
||||
plt.close()
|
||||
#
|
||||
views_limit = 5000
|
||||
df = df[df["views"] > views_limit]
|
||||
filename = "wiki_page_views_more_5000month.csv"
|
||||
df.to_csv(filename, index=True)
|
||||
#
|
||||
plt.hist(df["views"], bins=100, log=True)
|
||||
views_avg = np.mean(df["views"])
|
||||
views_median = np.median(df["views"])
|
||||
plt.title("Views avg: %s median: %s" % (views_avg, views_median))
|
||||
plt.savefig(filename.replace(".csv", ".png"))
|
||||
plt.close()
|
||||
|
||||
|
||||
@pytest.mark.skip("Only if doing full processing again, some manual steps")
|
||||
def test_do_wiki_full_all():
|
||||
# Install other requirements for wiki specific conversion:
|
||||
# pip install -r reqs_optional/requirements_optional_wikiprocessing.txt
|
||||
|
||||
# Use "Transmission" in Ubuntu to get wiki dump using torrent:
|
||||
# See: https://meta.wikimedia.org/wiki/Data_dump_torrents
|
||||
# E.g. magnet:?xt=urn:btih:b2c74af2b1531d0b63f1166d2011116f44a8fed0&dn=enwiki-20230401-pages-articles-multistream.xml.bz2&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337
|
||||
|
||||
# Get index
|
||||
os.system(
|
||||
"wget http://ftp.acc.umu.se/mirror/wikimedia.org/dumps/enwiki/20230401/enwiki-20230401-pages-articles-multistream-index.txt.bz2"
|
||||
)
|
||||
|
||||
# Test that LangChain can get docs from a subset of wiki, sampled directly out of the full wiki using the bzip2 multistream
|
||||
test_get_all_documents()
|
||||
|
||||
# Check can search wiki multistream
|
||||
test_by_search_term()
|
||||
|
||||
# Test can get all start bytes in index
|
||||
test_start_bytes()
|
||||
|
||||
# Get page views, e.g. for entire month of April 2023
|
||||
os.system(
|
||||
"wget -b -m -k -o wget.log -e robots=off https://dumps.wikimedia.org/other/pageviews/2023/2023-04/"
|
||||
)
|
||||
|
||||
# Aggregate page views from many files into single file
|
||||
test_agg_pageviews(gen_files=True)
|
||||
|
||||
# Reduce page views to some limit, so processing of full wiki is not too large
|
||||
test_reduce_pageview()
|
||||
|
||||
# Start generate.py with requesting wiki_full in prep. This will use page views as referenced in get_views.
|
||||
# Note: get_views must be called once as a global (see global_views) to avoid very slow processing
|
||||
# WARNING: Requires a lot of memory; used up to 300GB of system RAM at peak
|
||||
"""
|
||||
python generate.py --langchain_mode='wiki_full' --visible_langchain_modes="['wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']" &> lc_out.log
|
||||
"""
|
||||
@@ -1,121 +0,0 @@
|
||||
import torch
|
||||
from transformers import StoppingCriteria, StoppingCriteriaList
|
||||
|
||||
from enums import PromptType
|
||||
|
||||
|
||||
class StoppingCriteriaSub(StoppingCriteria):
|
||||
def __init__(
|
||||
self, stops=[], encounters=[], device="cuda", model_max_length=None
|
||||
):
|
||||
super().__init__()
|
||||
assert (
|
||||
len(stops) % len(encounters) == 0
|
||||
), "Number of stops and encounters must match"
|
||||
self.encounters = encounters
|
||||
self.stops = [stop.to(device) for stop in stops]
|
||||
self.num_stops = [0] * len(stops)
|
||||
self.model_max_length = model_max_length
|
||||
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
||||
) -> bool:
|
||||
for stopi, stop in enumerate(self.stops):
|
||||
if torch.all((stop == input_ids[0][-len(stop) :])).item():
|
||||
self.num_stops[stopi] += 1
|
||||
if (
|
||||
self.num_stops[stopi]
|
||||
>= self.encounters[stopi % len(self.encounters)]
|
||||
):
|
||||
# print("Stopped", flush=True)
|
||||
return True
|
||||
if (
|
||||
self.model_max_length is not None
|
||||
and input_ids[0].shape[0] >= self.model_max_length
|
||||
):
|
||||
# critical limit
|
||||
return True
|
||||
# print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
|
||||
# print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
|
||||
return False
|
||||
|
||||
|
||||
def get_stopping(
|
||||
prompt_type,
|
||||
prompt_dict,
|
||||
tokenizer,
|
||||
device,
|
||||
human="<human>:",
|
||||
bot="<bot>:",
|
||||
model_max_length=None,
|
||||
):
|
||||
# FIXME: prompt_dict unused currently
|
||||
if prompt_type in [
|
||||
PromptType.human_bot.name,
|
||||
PromptType.instruct_vicuna.name,
|
||||
PromptType.instruct_with_end.name,
|
||||
]:
|
||||
if prompt_type == PromptType.human_bot.name:
|
||||
# encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
|
||||
# stopping only starts once output is beyond prompt
|
||||
# 1 human is enough to trigger, but 2 bots are needed, because the very first one seen will be the bot marker we added
|
||||
stop_words = [human, bot, "\n" + human, "\n" + bot]
|
||||
encounters = [1, 2]
|
||||
elif prompt_type == PromptType.instruct_vicuna.name:
|
||||
# even the list below is not enough; these are generic strings with many possible encodings
|
||||
stop_words = [
|
||||
"### Human:",
|
||||
"""
|
||||
### Human:""",
|
||||
"""
|
||||
### Human:
|
||||
""",
|
||||
"### Assistant:",
|
||||
"""
|
||||
### Assistant:""",
|
||||
"""
|
||||
### Assistant:
|
||||
""",
|
||||
]
|
||||
encounters = [1, 2]
|
||||
else:
|
||||
# some instruct prompts end with this; stopping on it doesn't hurt since it is uncommon otherwise
|
||||
stop_words = ["### End"]
|
||||
encounters = [1]
|
||||
stop_words_ids = [
|
||||
tokenizer(stop_word, return_tensors="pt")["input_ids"].squeeze()
|
||||
for stop_word in stop_words
|
||||
]
|
||||
# handle single token case
|
||||
stop_words_ids = [
|
||||
x if len(x.shape) > 0 else torch.tensor([x])
|
||||
for x in stop_words_ids
|
||||
]
|
||||
stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
|
||||
# avoid padding in front of tokens
|
||||
if (
|
||||
tokenizer._pad_token
|
||||
): # use hidden variable to avoid an annoying property logger bug
|
||||
stop_words_ids = [
|
||||
x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x
|
||||
for x in stop_words_ids
|
||||
]
|
||||
# handle fake \n added
|
||||
stop_words_ids = [
|
||||
x[1:] if y[0] == "\n" else x
|
||||
for x, y in zip(stop_words_ids, stop_words)
|
||||
]
|
||||
# build stopper
|
||||
stopping_criteria = StoppingCriteriaList(
|
||||
[
|
||||
StoppingCriteriaSub(
|
||||
stops=stop_words_ids,
|
||||
encounters=encounters,
|
||||
device=device,
|
||||
model_max_length=model_max_length,
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
stopping_criteria = StoppingCriteriaList()
|
||||
return stopping_criteria
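# Illustrative sketch (not part of the original file): how the criteria built
# above are typically consumed. `model` and `tokenizer` are assumptions: an
# already-loaded HF causal LM and its tokenizer on a CUDA device.
def _example_generate(model, tokenizer):
    stopping_criteria = get_stopping(
        PromptType.human_bot.name,
        prompt_dict=None,
        tokenizer=tokenizer,
        device="cuda",
        model_max_length=2048,
    )
    inputs = tokenizer("<human>: hello\n<bot>:", return_tensors="pt").to("cuda")
    return model.generate(
        **inputs,
        max_new_tokens=64,
        stopping_criteria=stopping_criteria,
    )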
|
||||
File diff suppressed because it is too large
@@ -1,69 +0,0 @@
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
import time
|
||||
import queue
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
|
||||
class StreamingGradioCallbackHandler(BaseCallbackHandler):
|
||||
"""
|
||||
Similar to H2OTextIteratorStreamer, which is for the HF backend; this one is for the LangChain backend
|
||||
"""
|
||||
|
||||
def __init__(self, timeout: Optional[float] = None, block=True):
|
||||
super().__init__()
|
||||
self.text_queue = queue.SimpleQueue()
|
||||
self.stop_signal = None
|
||||
self.do_stop = False
|
||||
self.timeout = timeout
|
||||
self.block = block
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running. Clean the queue."""
|
||||
while not self.text_queue.empty():
|
||||
try:
|
||||
self.text_queue.get(block=False)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
self.text_queue.put(token)
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
self.text_queue.put(self.stop_signal)
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.text_queue.put(self.stop_signal)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
while True:
|
||||
try:
|
||||
value = (
|
||||
self.stop_signal
|
||||
) # value looks unused in PyCharm, but it is actually used
|
||||
if self.do_stop:
|
||||
print("hit stop", flush=True)
|
||||
# could raise or break; best to raise so the parent sees any exception from the thread
|
||||
raise StopIteration()
|
||||
# break
|
||||
value = self.text_queue.get(
|
||||
block=self.block, timeout=self.timeout
|
||||
)
|
||||
break
|
||||
except queue.Empty:
|
||||
time.sleep(0.01)
|
||||
if value == self.stop_signal:
|
||||
raise StopIteration()
|
||||
else:
|
||||
return value
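# Illustrative sketch (not part of the original file): streaming tokens from a
# LangChain LLM through the handler above. `llm` is an assumption here: any
# LangChain LLM constructed with streaming enabled and callback support.
from threading import Thread


def _example_stream(llm, prompt):
    handler = StreamingGradioCallbackHandler(timeout=60)
    worker = Thread(target=llm, args=(prompt,), kwargs=dict(callbacks=[handler]))
    worker.start()
    for token in handler:  # iteration stops when on_llm_end() posts stop_signal
        print(token, end="", flush=True)
    worker.join()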
|
||||
@@ -1,442 +0,0 @@
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
from argparse import RawTextHelpFormatter
|
||||
import re, gc
|
||||
|
||||
"""
|
||||
This script can be used as a standalone utility to convert IRs to dynamic + combine them.
|
||||
Following are the various ways this script can be used :-
|
||||
a. To convert a single Linalg IR to dynamic IR:
|
||||
--dynamic --first_ir_path=<PATH TO FIRST IR>
|
||||
b. To convert two Linalg IRs to dynamic IR:
|
||||
--dynamic --first_ir_path=<PATH TO SECOND IR> --first_ir_path=<PATH TO SECOND IR>
|
||||
c. To combine two Linalg IRs into one:
|
||||
--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
|
||||
d. To convert both IRs into dynamic as well as combine the IRs:
|
||||
--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
|
||||
|
||||
NOTE: For dynamic you'll also need to provide the following set of flags:-
|
||||
i. For First Llama : --dynamic_input_size (DEFAULT: 19)
|
||||
ii. For Second Llama: --model_name (DEFAULT: llama2_7b)
|
||||
--precision (DEFAULT: 'int4')
|
||||
You may use --save_dynamic to also save the dynamic IR in option d above.
|
||||
Otherwise, for options a. and b., the dynamic IR(s) will be saved by default.
|
||||
"""
|
||||
|
||||
|
||||
def combine_mlir_scripts(
|
||||
first_vicuna_mlir,
|
||||
second_vicuna_mlir,
|
||||
output_name,
|
||||
return_ir=True,
|
||||
):
|
||||
print(f"[DEBUG] combining first and second mlir")
|
||||
print(f"[DEBUG] output_name = {output_name}")
|
||||
maps1 = []
|
||||
maps2 = []
|
||||
constants = set()
|
||||
f1 = []
|
||||
f2 = []
|
||||
|
||||
print(f"[DEBUG] processing first vicuna mlir")
|
||||
first_vicuna_mlir = first_vicuna_mlir.splitlines()
|
||||
while first_vicuna_mlir:
|
||||
line = first_vicuna_mlir.pop(0)
|
||||
if re.search("#map\d*\s*=", line):
|
||||
maps1.append(line)
|
||||
elif re.search("arith.constant", line):
|
||||
constants.add(line)
|
||||
elif not re.search("module", line):
|
||||
line = re.sub("forward", "first_vicuna_forward", line)
|
||||
f1.append(line)
|
||||
f1 = f1[:-1]
|
||||
del first_vicuna_mlir
|
||||
gc.collect()
|
||||
|
||||
for i, map_line in enumerate(maps1):
|
||||
map_var = map_line.split(" ")[0]
|
||||
map_line = re.sub(f"{map_var}(?!\d)", map_var + "_0", map_line)
|
||||
maps1[i] = map_line
|
||||
f1 = [
|
||||
re.sub(f"{map_var}(?!\d)", map_var + "_0", func_line)
|
||||
for func_line in f1
|
||||
]
|
||||
|
||||
print(f"[DEBUG] processing second vicuna mlir")
|
||||
second_vicuna_mlir = second_vicuna_mlir.splitlines()
|
||||
while second_vicuna_mlir:
|
||||
line = second_vicuna_mlir.pop(0)
|
||||
if re.search("#map\d*\s*=", line):
|
||||
maps2.append(line)
|
||||
elif "global_seed" in line:
|
||||
continue
|
||||
elif re.search("arith.constant", line):
|
||||
constants.add(line)
|
||||
elif not re.search("module", line):
|
||||
line = re.sub("forward", "second_vicuna_forward", line)
|
||||
f2.append(line)
|
||||
f2 = f2[:-1]
|
||||
del second_vicuna_mlir
|
||||
gc.collect()
|
||||
|
||||
for i, map_line in enumerate(maps2):
|
||||
map_var = map_line.split(" ")[0]
|
||||
map_line = re.sub(f"{map_var}(?!\d)", map_var + "_1", map_line)
|
||||
maps2[i] = map_line
|
||||
f2 = [
|
||||
re.sub(f"{map_var}(?!\d)", map_var + "_1", func_line)
|
||||
for func_line in f2
|
||||
]
|
||||
|
||||
module_start = 'module attributes {torch.debug_module_name = "_lambda"} {'
|
||||
module_end = "}"
|
||||
|
||||
global_vars = []
|
||||
vnames = []
|
||||
global_var_loading1 = []
|
||||
global_var_loading2 = []
|
||||
|
||||
print(f"[DEBUG] processing constants")
|
||||
counter = 0
|
||||
constants = list(constants)
|
||||
while constants:
|
||||
constant = constants.pop(0)
|
||||
vname, vbody = constant.split("=")
|
||||
vname = re.sub("%", "", vname)
|
||||
vname = vname.strip()
|
||||
vbody = re.sub("arith.constant", "", vbody)
|
||||
vbody = vbody.strip()
|
||||
if len(vbody.split(":")) < 2:
|
||||
print(constant)
|
||||
vdtype = vbody.split(":")[-1].strip()
|
||||
fixed_vdtype = vdtype
|
||||
if "c1_i64" in vname:
|
||||
print(constant)
|
||||
counter += 1
|
||||
if counter == 2:
|
||||
counter = 0
|
||||
print("detected duplicate")
|
||||
continue
|
||||
vnames.append(vname)
|
||||
if "true" not in vname:
|
||||
global_vars.append(
|
||||
f"ml_program.global private @{vname}({vbody}) : {fixed_vdtype}"
|
||||
)
|
||||
global_var_loading1.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
|
||||
)
|
||||
global_var_loading2.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
|
||||
)
|
||||
else:
|
||||
global_vars.append(
|
||||
f"ml_program.global private @{vname}({vbody}) : i1"
|
||||
)
|
||||
global_var_loading1.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
|
||||
)
|
||||
global_var_loading2.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
|
||||
)
|
||||
|
||||
new_f1, new_f2 = [], []
|
||||
|
||||
print(f"[DEBUG] processing f1")
|
||||
for line in f1:
|
||||
if "func.func" in line:
|
||||
new_f1.append(line)
|
||||
for global_var in global_var_loading1:
|
||||
new_f1.append(global_var)
|
||||
else:
|
||||
new_f1.append(line)
|
||||
|
||||
print(f"[DEBUG] processing f2")
|
||||
for line in f2:
|
||||
if "func.func" in line:
|
||||
new_f2.append(line)
|
||||
for global_var in global_var_loading2:
|
||||
if (
|
||||
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
|
||||
in global_var
|
||||
):
|
||||
print(global_var)
|
||||
new_f2.append(global_var)
|
||||
else:
|
||||
new_f2.append(line)
|
||||
|
||||
f1 = new_f1
|
||||
f2 = new_f2
|
||||
|
||||
del new_f1
|
||||
del new_f2
|
||||
gc.collect()
|
||||
|
||||
print(
|
||||
[
|
||||
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in x
|
||||
for x in [maps1, maps2, global_vars, f1, f2]
|
||||
]
|
||||
)
|
||||
|
||||
# doing it this way rather than assembling the whole string
|
||||
# to prevent OOM with 64GiB RAM when encoding the file.
|
||||
|
||||
print(f"[DEBUG] Saving mlir to {output_name}")
|
||||
with open(output_name, "w+") as f_:
|
||||
f_.writelines(line + "\n" for line in maps1)
|
||||
f_.writelines(line + "\n" for line in maps2)
|
||||
f_.writelines(line + "\n" for line in [module_start])
|
||||
f_.writelines(line + "\n" for line in global_vars)
|
||||
f_.writelines(line + "\n" for line in f1)
|
||||
f_.writelines(line + "\n" for line in f2)
|
||||
f_.writelines(line + "\n" for line in [module_end])
|
||||
|
||||
del maps1
|
||||
del maps2
|
||||
del module_start
|
||||
del global_vars
|
||||
del f1
|
||||
del f2
|
||||
del module_end
|
||||
gc.collect()
|
||||
|
||||
if return_ir:
|
||||
print(f"[DEBUG] Reading combined mlir back in")
|
||||
with open(output_name, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def write_in_dynamic_inputs0(module, dynamic_input_size):
|
||||
print("[DEBUG] writing dynamic inputs to first vicuna")
|
||||
# Current solution for ensuring mlir files support dynamic inputs
|
||||
# TODO: find a more elegant way to implement this
|
||||
new_lines = []
|
||||
module = module.splitlines()
|
||||
while module:
|
||||
line = module.pop(0)
|
||||
line = re.sub(f"{dynamic_input_size}x", "?x", line)
|
||||
if "?x" in line:
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
|
||||
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub(f"c{dynamic_input_size}", "dim", line)
|
||||
if "%0 = tensor.empty(%dim) : tensor<?xi64>" in line:
|
||||
new_lines.append("%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>")
|
||||
if "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" in line:
|
||||
continue
|
||||
|
||||
new_lines.append(line)
|
||||
return "\n".join(new_lines)
|
||||
|
||||
|
||||
def write_in_dynamic_inputs1(module, model_name, precision):
|
||||
print("[DEBUG] writing dynamic inputs to second vicuna")
|
||||
|
||||
def remove_constant_dim(line):
|
||||
if "c19_i64" in line:
|
||||
line = re.sub("c19_i64", "dim_i64", line)
|
||||
if "19x" in line:
|
||||
line = re.sub("19x", "?x", line)
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)",
|
||||
"tensor.empty(%dim, %dim)",
|
||||
line,
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub("c19", "dim", line)
|
||||
if " 19," in line:
|
||||
line = re.sub(" 19,", " %dim,", line)
|
||||
if "x20x" in line or "<20x" in line:
|
||||
line = re.sub("20x", "?x", line)
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dimp1)", line)
|
||||
if " 20," in line:
|
||||
line = re.sub(" 20,", " %dimp1,", line)
|
||||
return line
|
||||
|
||||
module = module.splitlines()
|
||||
new_lines = []
|
||||
|
||||
# Using a while loop and the pop method to avoid creating a copy of module
|
||||
if "llama2_13b" in model_name:
|
||||
pkv_tensor_shape = "tensor<1x40x?x128x"
|
||||
elif "llama2_70b" in model_name:
|
||||
pkv_tensor_shape = "tensor<1x8x?x128x"
|
||||
else:
|
||||
pkv_tensor_shape = "tensor<1x32x?x128x"
|
||||
if precision in ["fp16", "int4", "int8"]:
|
||||
pkv_tensor_shape += "f16>"
|
||||
else:
|
||||
pkv_tensor_shape += "f32>"
|
||||
|
||||
while module:
|
||||
line = module.pop(0)
|
||||
if "%c19_i64 = arith.constant 19 : i64" in line:
|
||||
new_lines.append("%c2 = arith.constant 2 : index")
|
||||
new_lines.append(
|
||||
f"%dim_4_int = tensor.dim %arg1, %c2 : {pkv_tensor_shape}"
|
||||
)
|
||||
new_lines.append(
|
||||
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
|
||||
)
|
||||
continue
|
||||
if "%c2 = arith.constant 2 : index" in line:
|
||||
continue
|
||||
if "%c20_i64 = arith.constant 20 : i64" in line:
|
||||
new_lines.append("%c1_i64 = arith.constant 1 : i64")
|
||||
new_lines.append("%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64")
|
||||
new_lines.append(
|
||||
"%dimp1 = arith.index_cast %c20_i64 : i64 to index"
|
||||
)
|
||||
continue
|
||||
line = remove_constant_dim(line)
|
||||
new_lines.append(line)
|
||||
|
||||
return "\n".join(new_lines)
|
||||
|
||||
|
||||
def save_dynamic_ir(ir_to_save, output_file):
|
||||
if not ir_to_save:
|
||||
return
|
||||
# We only get string output from the dynamic conversion utility.
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
with open(output_file, "w") as f:
|
||||
with redirect_stdout(f):
|
||||
print(ir_to_save)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="llama ir utility",
|
||||
description="\tThis script can be used as a standalone utility to convert IRs to dynamic + combine them.\n"
|
||||
+ "\tFollowing are the various ways this script can be used :-\n"
|
||||
+ "\t\ta. To convert a single Linalg IR to dynamic IR:\n"
|
||||
+ "\t\t\t--dynamic --first_ir_path=<PATH TO FIRST IR>\n"
|
||||
+ "\t\tb. To convert two Linalg IRs to dynamic IR:\n"
|
||||
+ "\t\t\t--dynamic --first_ir_path=<PATH TO SECOND IR> --first_ir_path=<PATH TO SECOND IR>\n"
|
||||
+ "\t\tc. To combine two Linalg IRs into one:\n"
|
||||
+ "\t\t\t--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n"
|
||||
+ "\t\td. To convert both IRs into dynamic as well as combine the IRs:\n"
|
||||
+ "\t\t\t--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n\n"
|
||||
+ "\tNOTE: For dynamic you'll also need to provide the following set of flags:-\n"
|
||||
+ "\t\t i. For First Llama : --dynamic_input_size (DEFAULT: 19)\n"
|
||||
+ "\t\tii. For Second Llama: --model_name (DEFAULT: llama2_7b)\n"
|
||||
+ "\t\t\t--precision (DEFAULT: 'int4')\n"
|
||||
+ "\t You may use --save_dynamic to also save the dynamic IR in option d above.\n"
|
||||
+ "\t Else for option a. and b. the dynamic IR(s) will get saved by default.\n",
|
||||
formatter_class=RawTextHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
"-p",
|
||||
default="int4",
|
||||
choices=["fp32", "fp16", "int8", "int4"],
|
||||
help="Precision of the concerned IR",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default="llama2_7b",
|
||||
choices=["vicuna", "llama2_7b", "llama2_13b", "llama2_70b"],
|
||||
help="Specify which model to run.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--first_ir_path",
|
||||
default=None,
|
||||
help="path to first llama mlir file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--second_ir_path",
|
||||
default=None,
|
||||
help="path to second llama mlir file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dynamic_input_size",
|
||||
type=int,
|
||||
default=19,
|
||||
help="Specify the static input size to replace with dynamic dim.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dynamic",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Converts the IR(s) to dynamic",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_dynamic",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Save the individual IR(s) after converting to dynamic",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--combine",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Converts the IR(s) to dynamic",
|
||||
)
|
||||
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
dynamic = args.dynamic
|
||||
combine = args.combine
|
||||
assert (
|
||||
dynamic or combine
|
||||
), "neither `dynamic` nor `combine` flag is turned on"
|
||||
first_ir_path = args.first_ir_path
|
||||
second_ir_path = args.second_ir_path
|
||||
assert first_ir_path or second_ir_path, "no input ir has been provided"
|
||||
if combine:
|
||||
assert (
|
||||
first_ir_path and second_ir_path
|
||||
), "you will need to provide both IRs to combine"
|
||||
precision = args.precision
|
||||
model_name = args.model_name
|
||||
dynamic_input_size = args.dynamic_input_size
|
||||
save_dynamic = args.save_dynamic
|
||||
|
||||
print(f"Dynamic conversion utility is turned {'ON' if dynamic else 'OFF'}")
|
||||
print(f"Combining IR utility is turned {'ON' if combine else 'OFF'}")
|
||||
|
||||
if dynamic and not combine:
|
||||
save_dynamic = True
|
||||
|
||||
first_ir = None
|
||||
first_dynamic_ir_name = None
|
||||
second_ir = None
|
||||
second_dynamic_ir_name = None
|
||||
if first_ir_path:
|
||||
first_dynamic_ir_name = f"{Path(first_ir_path).stem}_dynamic"
|
||||
with open(first_ir_path, "r") as f:
|
||||
first_ir = f.read()
|
||||
if second_ir_path:
|
||||
second_dynamic_ir_name = f"{Path(second_ir_path).stem}_dynamic"
|
||||
with open(second_ir_path, "r") as f:
|
||||
second_ir = f.read()
|
||||
if dynamic:
|
||||
first_ir = (
|
||||
write_in_dynamic_inputs0(first_ir, dynamic_input_size)
|
||||
if first_ir
|
||||
else None
|
||||
)
|
||||
second_ir = (
|
||||
write_in_dynamic_inputs1(second_ir, model_name, precision)
|
||||
if second_ir
|
||||
else None
|
||||
)
|
||||
if save_dynamic:
|
||||
save_dynamic_ir(first_ir, f"{first_dynamic_ir_name}.mlir")
|
||||
save_dynamic_ir(second_ir, f"{second_dynamic_ir_name}.mlir")
|
||||
|
||||
if combine:
|
||||
combine_mlir_scripts(
|
||||
first_ir,
|
||||
second_ir,
|
||||
f"{model_name}_{precision}.mlir",
|
||||
return_ir=False,
|
||||
)
|
||||
@@ -1,211 +0,0 @@
|
||||
import torch
|
||||
import torch_mlir
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
StoppingCriteria,
|
||||
)
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from apps.language_models.utils import (
|
||||
get_torch_mlir_module_bytecode,
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
|
||||
|
||||
class StopOnTokens(StoppingCriteria):
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
||||
) -> bool:
|
||||
stop_ids = [50278, 50279, 50277, 1, 0]
|
||||
for stop_id in stop_ids:
|
||||
if input_ids[0][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def shouldStop(tokens):
|
||||
stop_ids = [50278, 50279, 50277, 1, 0]
|
||||
for stop_id in stop_ids:
|
||||
if tokens[0][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
MAX_SEQUENCE_LENGTH = 256
|
||||
|
||||
|
||||
def user(message, history):
|
||||
# Append the user's message to the conversation history
|
||||
return "", history + [[message, ""]]
|
||||
|
||||
|
||||
def compile_stableLM(
|
||||
model,
|
||||
model_inputs,
|
||||
model_name,
|
||||
model_vmfb_name,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
debug=False,
|
||||
):
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
# device = "cuda" # "cpu"
|
||||
# TODO: vmfb and mlir name should include precision and device
|
||||
vmfb_path = (
|
||||
Path(model_name + f"_{device}.vmfb")
|
||||
if model_vmfb_name is None
|
||||
else Path(model_vmfb_name)
|
||||
)
|
||||
shark_module = get_vmfb_from_path(
|
||||
vmfb_path, device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
if shark_module is not None:
|
||||
return shark_module
|
||||
|
||||
mlir_path = Path(model_name + ".mlir")
|
||||
print(
|
||||
f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
if mlir_path.exists():
|
||||
with open(mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*model_inputs],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
f_ = open(model_name + ".mlir", "wb")
|
||||
f_.write(bytecode)
|
||||
print("Saved mlir")
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
path = shark_module.save_module(
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem, debug=debug
|
||||
)
|
||||
print("Saved vmfb at ", str(path))
|
||||
|
||||
return shark_module
|
||||
|
||||
|
||||
class StableLMModel(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
combine_input_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
output = self.model(**combine_input_dict)
|
||||
return output.logits
|
||||
|
||||
|
||||
# Initialize a StopOnTokens object
|
||||
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
||||
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
||||
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
|
||||
- StableLM will refuse to participate in anything that could harm a human.
|
||||
"""
|
||||
|
||||
|
||||
def get_tokenizer():
|
||||
model_path = "stabilityai/stablelm-tuned-alpha-3b"
|
||||
tok = AutoTokenizer.from_pretrained(model_path)
|
||||
tok.add_special_tokens({"pad_token": "<PAD>"})
|
||||
print("Sucessfully loaded the tokenizer to the memory")
|
||||
return tok
|
||||
|
||||
|
||||
# sharkStableLM = compile_stableLM
|
||||
# (
|
||||
# None,
|
||||
# tuple([input_ids, attention_mask]),
|
||||
# "stableLM_linalg_f32_seqLen256",
|
||||
# "/home/shark/vivek/stableLM_shark_f32_seqLen256"
|
||||
# )
|
||||
def generate(
|
||||
new_text,
|
||||
max_new_tokens,
|
||||
sharkStableLM,
|
||||
tokenizer=None,
|
||||
):
|
||||
if tokenizer is None:
|
||||
tokenizer = get_tokenizer()
|
||||
# Construct the input message string for the model by
|
||||
# concatenating the current system message and conversation history
|
||||
# Tokenize the messages string
|
||||
# sharkStableLM = compile_stableLM
|
||||
# (
|
||||
# None,
|
||||
# tuple([input_ids, attention_mask]),
|
||||
# "stableLM_linalg_f32_seqLen256",
|
||||
# "/home/shark/vivek/stableLM_shark_f32_seqLen256"
|
||||
# )
|
||||
words_list = []
|
||||
for i in range(max_new_tokens):
|
||||
# numWords = len(new_text.split())
|
||||
# if(numWords>220):
|
||||
# break
|
||||
params = {
|
||||
"new_text": new_text,
|
||||
}
|
||||
generated_token_op = generate_new_token(
|
||||
sharkStableLM, tokenizer, params
|
||||
)
|
||||
detok = generated_token_op["detok"]
|
||||
stop_generation = generated_token_op["stop_generation"]
|
||||
if stop_generation:
|
||||
break
|
||||
print(detok, end="", flush=True)
|
||||
words_list.append(detok)
|
||||
if detok == "":
|
||||
break
|
||||
new_text = new_text + detok
|
||||
return words_list
|
||||
|
||||
|
||||
def generate_new_token(shark_model, tokenizer, params):
|
||||
new_text = params["new_text"]
|
||||
model_inputs = tokenizer(
|
||||
[new_text],
|
||||
padding="max_length",
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
sum_attentionmask = torch.sum(model_inputs.attention_mask)
|
||||
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
|
||||
output = shark_model(
|
||||
"forward", [model_inputs.input_ids, model_inputs.attention_mask]
|
||||
)
|
||||
output = torch.from_numpy(output)
|
||||
next_toks = torch.topk(output, 1)
|
||||
stop_generation = False
|
||||
if shouldStop(next_toks.indices):
|
||||
stop_generation = True
|
||||
new_token = next_toks.indices[0][int(sum_attentionmask) - 1]
|
||||
detok = tokenizer.decode(
|
||||
new_token,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
ret_dict = {
|
||||
"new_token": new_token,
|
||||
"detok": detok,
|
||||
"stop_generation": stop_generation,
|
||||
}
|
||||
return ret_dict
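# Illustrative sketch (not part of the original file): end-to-end use of the
# helpers above. The model id matches get_tokenizer(); the mlir/vmfb names, the
# chosen device, and the prompt format are assumptions.
def _example_pipeline():
    from transformers import AutoModelForCausalLM

    tok = get_tokenizer()
    model = StableLMModel(
        AutoModelForCausalLM.from_pretrained("stabilityai/stablelm-tuned-alpha-3b")
    )
    inputs = tok(
        [system_prompt],
        padding="max_length",
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        return_tensors="pt",
    )
    shark_stablelm = compile_stableLM(
        model,
        (inputs.input_ids, inputs.attention_mask),
        "stableLM_linalg_f32_seqLen256",
        None,  # let compile_stableLM derive the vmfb name from model_name
        device="cpu",
    )
    prompt = system_prompt + "<|USER|>Hello!<|ASSISTANT|>"
    words = generate(prompt, 32, shark_stablelm, tokenizer=tok)
    return "".join(words)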
|
||||
File diff suppressed because it is too large
@@ -1,94 +0,0 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
from PyInstaller.utils.hooks import collect_data_files
|
||||
from PyInstaller.utils.hooks import collect_submodules
|
||||
from PyInstaller.utils.hooks import copy_metadata
|
||||
|
||||
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
|
||||
|
||||
datas = []
|
||||
datas += collect_data_files('torch')
|
||||
datas += copy_metadata('torch')
|
||||
datas += copy_metadata('tqdm')
|
||||
datas += copy_metadata('regex')
|
||||
datas += copy_metadata('requests')
|
||||
datas += copy_metadata('packaging')
|
||||
datas += copy_metadata('filelock')
|
||||
datas += copy_metadata('numpy')
|
||||
datas += copy_metadata('tokenizers')
|
||||
datas += copy_metadata('importlib_metadata')
|
||||
datas += copy_metadata('torch-mlir')
|
||||
datas += copy_metadata('omegaconf')
|
||||
datas += copy_metadata('safetensors')
|
||||
datas += copy_metadata('huggingface-hub')
|
||||
datas += copy_metadata('sentencepiece')
|
||||
datas += copy_metadata("pyyaml")
|
||||
datas += collect_data_files("tokenizers")
|
||||
datas += collect_data_files("tiktoken")
|
||||
datas += collect_data_files("accelerate")
|
||||
datas += collect_data_files('diffusers')
|
||||
datas += collect_data_files('transformers')
|
||||
datas += collect_data_files('opencv-python')
|
||||
datas += collect_data_files('pytorch_lightning')
|
||||
datas += collect_data_files('skimage')
|
||||
datas += collect_data_files('gradio')
|
||||
datas += collect_data_files('gradio_client')
|
||||
datas += collect_data_files('iree')
|
||||
datas += collect_data_files('google-cloud-storage')
|
||||
datas += collect_data_files('py-cpuinfo')
|
||||
datas += collect_data_files("shark", include_py_files=True)
|
||||
datas += collect_data_files("timm", include_py_files=True)
|
||||
datas += collect_data_files("tqdm")
|
||||
datas += collect_data_files("tkinter")
|
||||
datas += collect_data_files("webview")
|
||||
datas += collect_data_files("sentencepiece")
|
||||
datas += collect_data_files("jsonschema")
|
||||
datas += collect_data_files("jsonschema_specifications")
|
||||
datas += collect_data_files("cpuinfo")
|
||||
datas += collect_data_files("langchain")
|
||||
|
||||
binaries = []
|
||||
|
||||
block_cipher = None
|
||||
|
||||
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
|
||||
|
||||
a = Analysis(
|
||||
['scripts/vicuna.py'],
|
||||
pathex=['.'],
|
||||
binaries=binaries,
|
||||
datas=datas,
|
||||
hiddenimports=hiddenimports,
|
||||
hookspath=[],
|
||||
hooksconfig={},
|
||||
runtime_hooks=[],
|
||||
excludes=[],
|
||||
win_no_prefer_redirects=False,
|
||||
win_private_assemblies=False,
|
||||
cipher=block_cipher,
|
||||
noarchive=False,
|
||||
)
|
||||
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||
|
||||
exe = EXE(
|
||||
pyz,
|
||||
a.scripts,
|
||||
a.binaries,
|
||||
a.zipfiles,
|
||||
a.datas,
|
||||
[],
|
||||
name='shark_llama_cli',
|
||||
debug=False,
|
||||
bootloader_ignore_signals=False,
|
||||
strip=False,
|
||||
upx=True,
|
||||
upx_exclude=[],
|
||||
runtime_tmpdir=None,
|
||||
console=True,
|
||||
disable_windowed_traceback=False,
|
||||
argv_emulation=False,
|
||||
target_arch=None,
|
||||
codesign_identity=None,
|
||||
entitlements_file=None,
|
||||
)
|
||||
@@ -1,22 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
class FalconModel(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
input_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": None,
|
||||
"use_cache": True,
|
||||
}
|
||||
output = self.model(
|
||||
**input_dict,
|
||||
return_dict=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
)[0]
|
||||
return output[:, -1, :]
|
||||
@@ -1,675 +0,0 @@
|
||||
import torch
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
class WordEmbeddingsLayer(torch.nn.Module):
|
||||
def __init__(self, word_embedding_layer):
|
||||
super().__init__()
|
||||
self.model = word_embedding_layer
|
||||
|
||||
def forward(self, input_ids):
|
||||
output = self.model.forward(input=input_ids)
|
||||
return output
|
||||
|
||||
|
||||
class CompiledWordEmbeddingsLayer(torch.nn.Module):
|
||||
def __init__(self, compiled_word_embedding_layer):
|
||||
super().__init__()
|
||||
self.model = compiled_word_embedding_layer
|
||||
|
||||
def forward(self, input_ids):
|
||||
input_ids = input_ids.detach().numpy()
|
||||
new_input_ids = self.model("forward", input_ids)
|
||||
new_input_ids = new_input_ids.reshape(
|
||||
[1, new_input_ids.shape[0], new_input_ids.shape[1]]
|
||||
)
|
||||
return torch.tensor(new_input_ids)
|
||||
|
||||
|
||||
class LNFEmbeddingLayer(torch.nn.Module):
|
||||
def __init__(self, ln_f):
|
||||
super().__init__()
|
||||
self.model = ln_f
|
||||
|
||||
def forward(self, hidden_states):
|
||||
output = self.model.forward(input=hidden_states)
|
||||
return output
|
||||
|
||||
|
||||
class CompiledLNFEmbeddingLayer(torch.nn.Module):
|
||||
def __init__(self, ln_f):
|
||||
super().__init__()
|
||||
self.model = ln_f
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = hidden_states.detach().numpy()
|
||||
new_hidden_states = self.model("forward", (hidden_states,))
|
||||
|
||||
return torch.tensor(new_hidden_states)
|
||||
|
||||
|
||||
class LMHeadEmbeddingLayer(torch.nn.Module):
|
||||
def __init__(self, embedding_layer):
|
||||
super().__init__()
|
||||
self.model = embedding_layer
|
||||
|
||||
def forward(self, hidden_states):
|
||||
output = self.model.forward(input=hidden_states)
|
||||
return output
|
||||
|
||||
|
||||
class CompiledLMHeadEmbeddingLayer(torch.nn.Module):
|
||||
def __init__(self, lm_head):
|
||||
super().__init__()
|
||||
self.model = lm_head
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = hidden_states.detach().numpy()
|
||||
new_hidden_states = self.model("forward", (hidden_states,))
|
||||
return torch.tensor(new_hidden_states)
|
||||
|
||||
|
||||
class FourWayShardingDecoderLayer(torch.nn.Module):
|
||||
def __init__(self, decoder_layer_model, falcon_variant):
|
||||
super().__init__()
|
||||
self.model = decoder_layer_model
|
||||
self.falcon_variant = falcon_variant
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
new_pkvs = []
|
||||
for layer in self.model:
|
||||
outputs = layer(
|
||||
hidden_states=hidden_states,
|
||||
alibi=None,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=True,
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
new_pkvs.append(
|
||||
(
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
)
|
||||
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
(new_pkv80, new_pkv81),
|
||||
(new_pkv90, new_pkv91),
|
||||
(new_pkv100, new_pkv101),
|
||||
(new_pkv110, new_pkv111),
|
||||
(new_pkv120, new_pkv121),
|
||||
(new_pkv130, new_pkv131),
|
||||
(new_pkv140, new_pkv141),
|
||||
(new_pkv150, new_pkv151),
|
||||
(new_pkv160, new_pkv161),
|
||||
(new_pkv170, new_pkv171),
|
||||
(new_pkv180, new_pkv181),
|
||||
(new_pkv190, new_pkv191),
|
||||
) = new_pkvs
|
||||
result = (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
new_pkv80,
|
||||
new_pkv81,
|
||||
new_pkv90,
|
||||
new_pkv91,
|
||||
new_pkv100,
|
||||
new_pkv101,
|
||||
new_pkv110,
|
||||
new_pkv111,
|
||||
new_pkv120,
|
||||
new_pkv121,
|
||||
new_pkv130,
|
||||
new_pkv131,
|
||||
new_pkv140,
|
||||
new_pkv141,
|
||||
new_pkv150,
|
||||
new_pkv151,
|
||||
new_pkv160,
|
||||
new_pkv161,
|
||||
new_pkv170,
|
||||
new_pkv171,
|
||||
new_pkv180,
|
||||
new_pkv181,
|
||||
new_pkv190,
|
||||
new_pkv191,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class CompiledFourWayShardingDecoderLayer(torch.nn.Module):
|
||||
def __init__(
|
||||
self, layer_id, device_idx, falcon_variant, device, precision, model
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
self.device_index = device_idx
|
||||
self.falcon_variant = falcon_variant
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
alibi: Optional[torch.Tensor],
|
||||
attention_mask: torch.Tensor,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
import gc
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
if self.model is None:
|
||||
raise ValueError("Layer vmfb not found")
|
||||
|
||||
hidden_states = hidden_states.to(torch.float32).detach().numpy()
|
||||
attention_mask = attention_mask.to(torch.float32).detach().numpy()
|
||||
|
||||
if alibi is not None or layer_past is not None:
|
||||
raise ValueError("Past Key Values and alibi should be None")
|
||||
else:
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
),
|
||||
)
|
||||
|
||||
result = (
|
||||
torch.tensor(output[0]),
|
||||
(
|
||||
torch.tensor(output[1]),
|
||||
torch.tensor(output[2]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[3]),
|
||||
torch.tensor(output[4]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[5]),
|
||||
torch.tensor(output[6]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[7]),
|
||||
torch.tensor(output[8]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[9]),
|
||||
torch.tensor(output[10]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[11]),
|
||||
torch.tensor(output[12]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[13]),
|
||||
torch.tensor(output[14]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[15]),
|
||||
torch.tensor(output[16]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[17]),
|
||||
torch.tensor(output[18]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[19]),
|
||||
torch.tensor(output[20]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[21]),
|
||||
torch.tensor(output[22]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[23]),
|
||||
torch.tensor(output[24]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[25]),
|
||||
torch.tensor(output[26]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[27]),
|
||||
torch.tensor(output[28]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[29]),
|
||||
torch.tensor(output[30]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[31]),
|
||||
torch.tensor(output[32]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[33]),
|
||||
torch.tensor(output[34]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[35]),
|
||||
torch.tensor(output[36]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[37]),
|
||||
torch.tensor(output[38]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[39]),
|
||||
torch.tensor(output[40]),
|
||||
),
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class TwoWayShardingDecoderLayer(torch.nn.Module):
|
||||
def __init__(self, decoder_layer_model, falcon_variant):
|
||||
super().__init__()
|
||||
self.model = decoder_layer_model
|
||||
self.falcon_variant = falcon_variant
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
new_pkvs = []
|
||||
for layer in self.model:
|
||||
outputs = layer(
|
||||
hidden_states=hidden_states,
|
||||
alibi=None,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=True,
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
new_pkvs.append(
|
||||
(
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
)
|
||||
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
(new_pkv80, new_pkv81),
|
||||
(new_pkv90, new_pkv91),
|
||||
(new_pkv100, new_pkv101),
|
||||
(new_pkv110, new_pkv111),
|
||||
(new_pkv120, new_pkv121),
|
||||
(new_pkv130, new_pkv131),
|
||||
(new_pkv140, new_pkv141),
|
||||
(new_pkv150, new_pkv151),
|
||||
(new_pkv160, new_pkv161),
|
||||
(new_pkv170, new_pkv171),
|
||||
(new_pkv180, new_pkv181),
|
||||
(new_pkv190, new_pkv191),
|
||||
(new_pkv200, new_pkv201),
|
||||
(new_pkv210, new_pkv211),
|
||||
(new_pkv220, new_pkv221),
|
||||
(new_pkv230, new_pkv231),
|
||||
(new_pkv240, new_pkv241),
|
||||
(new_pkv250, new_pkv251),
|
||||
(new_pkv260, new_pkv261),
|
||||
(new_pkv270, new_pkv271),
|
||||
(new_pkv280, new_pkv281),
|
||||
(new_pkv290, new_pkv291),
|
||||
(new_pkv300, new_pkv301),
|
||||
(new_pkv310, new_pkv311),
|
||||
(new_pkv320, new_pkv321),
|
||||
(new_pkv330, new_pkv331),
|
||||
(new_pkv340, new_pkv341),
|
||||
(new_pkv350, new_pkv351),
|
||||
(new_pkv360, new_pkv361),
|
||||
(new_pkv370, new_pkv371),
|
||||
(new_pkv380, new_pkv381),
|
||||
(new_pkv390, new_pkv391),
|
||||
) = new_pkvs
|
||||
result = (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
new_pkv80,
|
||||
new_pkv81,
|
||||
new_pkv90,
|
||||
new_pkv91,
|
||||
new_pkv100,
|
||||
new_pkv101,
|
||||
new_pkv110,
|
||||
new_pkv111,
|
||||
new_pkv120,
|
||||
new_pkv121,
|
||||
new_pkv130,
|
||||
new_pkv131,
|
||||
new_pkv140,
|
||||
new_pkv141,
|
||||
new_pkv150,
|
||||
new_pkv151,
|
||||
new_pkv160,
|
||||
new_pkv161,
|
||||
new_pkv170,
|
||||
new_pkv171,
|
||||
new_pkv180,
|
||||
new_pkv181,
|
||||
new_pkv190,
|
||||
new_pkv191,
|
||||
new_pkv200,
|
||||
new_pkv201,
|
||||
new_pkv210,
|
||||
new_pkv211,
|
||||
new_pkv220,
|
||||
new_pkv221,
|
||||
new_pkv230,
|
||||
new_pkv231,
|
||||
new_pkv240,
|
||||
new_pkv241,
|
||||
new_pkv250,
|
||||
new_pkv251,
|
||||
new_pkv260,
|
||||
new_pkv261,
|
||||
new_pkv270,
|
||||
new_pkv271,
|
||||
new_pkv280,
|
||||
new_pkv281,
|
||||
new_pkv290,
|
||||
new_pkv291,
|
||||
new_pkv300,
|
||||
new_pkv301,
|
||||
new_pkv310,
|
||||
new_pkv311,
|
||||
new_pkv320,
|
||||
new_pkv321,
|
||||
new_pkv330,
|
||||
new_pkv331,
|
||||
new_pkv340,
|
||||
new_pkv341,
|
||||
new_pkv350,
|
||||
new_pkv351,
|
||||
new_pkv360,
|
||||
new_pkv361,
|
||||
new_pkv370,
|
||||
new_pkv371,
|
||||
new_pkv380,
|
||||
new_pkv381,
|
||||
new_pkv390,
|
||||
new_pkv391,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class CompiledTwoWayShardingDecoderLayer(torch.nn.Module):
|
||||
def __init__(
|
||||
self, layer_id, device_idx, falcon_variant, device, precision, model
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
self.device_index = device_idx
|
||||
self.falcon_variant = falcon_variant
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
alibi: Optional[torch.Tensor],
|
||||
attention_mask: torch.Tensor,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
        import gc

        torch.cuda.empty_cache()
        gc.collect()

        if self.model is None:
            raise ValueError("Layer vmfb not found")

        hidden_states = hidden_states.to(torch.float32).detach().numpy()
        attention_mask = attention_mask.to(torch.float32).detach().numpy()

        if alibi is not None or layer_past is not None:
            raise ValueError("Past Key Values and alibi should be None")
        else:
            output = self.model(
                "forward",
                (
                    hidden_states,
                    attention_mask,
                ),
            )

        result = (
            torch.tensor(output[0]),
            *(
                (torch.tensor(output[i]), torch.tensor(output[i + 1]))
                for i in range(1, 81, 2)
            ),
        )
        return result


class ShardedFalconModel:
    def __init__(self, model, layers, word_embeddings, ln_f, lm_head):
        super().__init__()
        self.model = model
        self.model.transformer.h = torch.nn.modules.container.ModuleList(
            layers
        )
        self.model.transformer.word_embeddings = word_embeddings
        self.model.transformer.ln_f = ln_f
        self.model.lm_head = lm_head

    def forward(
        self,
        input_ids,
        attention_mask=None,
    ):
        return self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).logits[:, -1, :]
@@ -1,503 +0,0 @@
import torch
import dataclasses
from enum import auto, Enum
from typing import List, Any
from transformers import StoppingCriteria


from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl


class LayerNorm(torch.nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class VisionModel(torch.nn.Module):
    def __init__(
        self,
        ln_vision,
        visual_encoder,
        precision="fp32",
        weight_group_size=128,
    ):
        super().__init__()
        self.ln_vision = ln_vision
        self.visual_encoder = visual_encoder
        if precision in ["int4", "int8"]:
            print("Vision Model applying weight quantization to ln_vision")
            weight_bit_width = 4 if precision == "int4" else 8
            quantize_model(
                self.ln_vision,
                dtype=torch.float32,
                weight_bit_width=weight_bit_width,
                weight_param_method="stats",
                weight_scale_precision="float_scale",
                weight_quant_type="asym",
                weight_quant_granularity="per_group",
                weight_group_size=weight_group_size,
                quantize_weight_zero_point=False,
            )
            print("Weight quantization applied.")
            print(
                "Vision Model applying weight quantization to visual_encoder"
            )
            quantize_model(
                self.visual_encoder,
                dtype=torch.float32,
                weight_bit_width=weight_bit_width,
                weight_param_method="stats",
                weight_scale_precision="float_scale",
                weight_quant_type="asym",
                weight_quant_granularity="per_group",
                weight_group_size=weight_group_size,
                quantize_weight_zero_point=False,
            )
            print("Weight quantization applied.")

    def forward(self, image):
        image_embeds = self.ln_vision(self.visual_encoder(image))
        return image_embeds


class QformerBertModel(torch.nn.Module):
    def __init__(self, qformer_bert):
        super().__init__()
        self.qformer_bert = qformer_bert

    def forward(self, query_tokens, image_embeds, image_atts):
        query_output = self.qformer_bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
        return query_output.last_hidden_state


class FirstLlamaModel(torch.nn.Module):
    def __init__(self, model, precision="fp32", weight_group_size=128):
        super().__init__()
        self.model = model
        print("SHARK: Loading LLAMA Done")
        if precision in ["int4", "int8"]:
            print("First Llama applying weight quantization")
            weight_bit_width = 4 if precision == "int4" else 8
            quantize_model(
                self.model,
                dtype=torch.float32,
                weight_bit_width=weight_bit_width,
                weight_param_method="stats",
                weight_scale_precision="float_scale",
                weight_quant_type="asym",
                weight_quant_granularity="per_group",
                weight_group_size=weight_group_size,
                quantize_weight_zero_point=False,
            )
            print("Weight quantization applied.")

    def forward(self, inputs_embeds, position_ids, attention_mask):
        print("************************************")
        print(
            "inputs_embeds: ",
            inputs_embeds.shape,
            " dtype: ",
            inputs_embeds.dtype,
        )
        print(
            "position_ids: ",
            position_ids.shape,
            " dtype: ",
            position_ids.dtype,
        )
        print(
            "attention_mask: ",
            attention_mask.shape,
            " dtype: ",
            attention_mask.dtype,
        )
        print("************************************")
        config = {
            "inputs_embeds": inputs_embeds,
            "position_ids": position_ids,
            "past_key_values": None,
            "use_cache": True,
            "attention_mask": attention_mask,
        }
        output = self.model(
            **config,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )
        return_vals = []
        return_vals.append(output.logits)
        temp_past_key_values = output.past_key_values
        for item in temp_past_key_values:
            return_vals.append(item[0])
            return_vals.append(item[1])
        return tuple(return_vals)


class SecondLlamaModel(torch.nn.Module):
    def __init__(self, model, precision="fp32", weight_group_size=128):
        super().__init__()
        self.model = model
        print("SHARK: Loading LLAMA Done")
        if precision in ["int4", "int8"]:
            print("Second Llama applying weight quantization")
            weight_bit_width = 4 if precision == "int4" else 8
            quantize_model(
                self.model,
                dtype=torch.float32,
                weight_bit_width=weight_bit_width,
                weight_param_method="stats",
                weight_scale_precision="float_scale",
                weight_quant_type="asym",
                weight_quant_granularity="per_group",
                weight_group_size=weight_group_size,
                quantize_weight_zero_point=False,
            )
            print("Weight quantization applied.")

    def forward(
        self,
        input_ids,
        position_ids,
        attention_mask,
        i1, i2, i3, i4, i5, i6, i7, i8,
        i9, i10, i11, i12, i13, i14, i15, i16,
        i17, i18, i19, i20, i21, i22, i23, i24,
        i25, i26, i27, i28, i29, i30, i31, i32,
        i33, i34, i35, i36, i37, i38, i39, i40,
        i41, i42, i43, i44, i45, i46, i47, i48,
        i49, i50, i51, i52, i53, i54, i55, i56,
        i57, i58, i59, i60, i61, i62, i63, i64,
    ):
        print("************************************")
        print("input_ids: ", input_ids.shape, " dtype: ", input_ids.dtype)
        print(
            "position_ids: ",
            position_ids.shape,
            " dtype: ",
            position_ids.dtype,
        )
        print(
            "attention_mask: ",
            attention_mask.shape,
            " dtype: ",
            attention_mask.dtype,
        )
        print("past_key_values: ", i1.shape, i2.shape, i63.shape, i64.shape)
        print("past_key_values dtype: ", i1.dtype)
        print("************************************")
        config = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "past_key_values": (
                (i1, i2), (i3, i4), (i5, i6), (i7, i8),
                (i9, i10), (i11, i12), (i13, i14), (i15, i16),
                (i17, i18), (i19, i20), (i21, i22), (i23, i24),
                (i25, i26), (i27, i28), (i29, i30), (i31, i32),
                (i33, i34), (i35, i36), (i37, i38), (i39, i40),
                (i41, i42), (i43, i44), (i45, i46), (i47, i48),
                (i49, i50), (i51, i52), (i53, i54), (i55, i56),
                (i57, i58), (i59, i60), (i61, i62), (i63, i64),
            ),
            "use_cache": True,
            "attention_mask": attention_mask,
        }
        output = self.model(
            **config,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )
        return_vals = []
        return_vals.append(output.logits)
        temp_past_key_values = output.past_key_values
        for item in temp_past_key_values:
            return_vals.append(item[0])
            return_vals.append(item[1])
        return tuple(return_vals)


class SeparatorStyle(Enum):
    """Different separator style."""

    SINGLE = auto()
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""

    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None

    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id,
        )

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }


class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop) :])).item():
                return True

        return False


CONV_VISION = Conversation(
    system="Give the following image: <Img>ImageContent</Img>. "
    "You will be able to see the image once I provide it to you. Please answer my questions.",
    roles=("Human", "Assistant"),
    messages=[],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)
@@ -1,15 +0,0 @@
import torch


class StableLMModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        combine_input_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        output = self.model(**combine_input_dict)
        return output.logits
@@ -1,876 +0,0 @@
import argparse
import json
import re
from io import BytesIO
from pathlib import Path
from tqdm import tqdm
from typing import List, Optional, Tuple, Union
import numpy as np
import iree.runtime
import itertools
import subprocess

import torch
import torch_mlir
from torch_mlir import TensorPlaceholder
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    LlamaPreTrainedModel,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)

from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
    FirstVicunaLayer,
    SecondVicunaLayer,
    CompiledVicunaLayer,
    ShardedVicunaModel,
    LMHead,
    LMHeadCompiled,
    VicunaEmbedding,
    VicunaEmbeddingCompiled,
    VicunaNorm,
    VicunaNormCompiled,
)
from apps.language_models.src.model_wrappers.vicuna_model import (
    FirstVicuna,
    SecondVicuna7B,
)
from apps.language_models.utils import (
    get_vmfb_from_path,
)
from shark.shark_downloader import download_public_file
from shark.shark_importer import get_f16_inputs
from shark.shark_inference import SharkInference

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaRMSNorm,
    _make_causal_mask,
    _expand_mask,
)
from torch import nn
from time import time


class LlamaModel(LlamaPreTrainedModel):
|
||||
"""
|
||||
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
|
||||
|
||||
Args:
|
||||
config: LlamaConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(
|
||||
config.vocab_size, config.hidden_size, self.padding_idx
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
LlamaDecoderLayer(config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
self,
|
||||
attention_mask,
|
||||
input_shape,
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
):
|
||||
# create causal mask
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = None
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape,
|
||||
inputs_embeds.dtype,
|
||||
device=inputs_embeds.device,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
expanded_attn_mask = _expand_mask(
|
||||
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
||||
).to(inputs_embeds.device)
|
||||
combined_attention_mask = (
|
||||
expanded_attn_mask
|
||||
if combined_attention_mask is None
|
||||
else expanded_attn_mask + combined_attention_mask
|
||||
)
|
||||
|
||||
return combined_attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
t1 = time()
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = (
|
||||
use_cache if use_cache is not None else self.config.use_cache
|
||||
)
|
||||
|
||||
return_dict = (
|
||||
return_dict
|
||||
if return_dict is not None
|
||||
else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||
)
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||
)
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = (
|
||||
seq_length_with_past + past_key_values_length
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
device = (
|
||||
input_ids.device
|
||||
if input_ids is not None
|
||||
else inputs_embeds.device
|
||||
)
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.compressedlayers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = (
|
||||
past_key_values[8 * idx : 8 * (idx + 1)]
|
||||
if past_key_values is not None
|
||||
else None
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer.forward(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[1:],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
try:
    hidden_states = np.asarray(hidden_states, hidden_states.dtype)
except Exception:
    # hidden_states may already be a torch.Tensor rather than an IREE buffer.
    pass
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
next_cache = tuple(itertools.chain.from_iterable(next_cache))
|
||||
print(f"Token generated in {time() - t1} seconds")
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_cache,
|
||||
all_hidden_states,
|
||||
all_self_attns,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class EightLayerLayerSV(torch.nn.Module):
|
||||
def __init__(self, layers):
|
||||
super().__init__()
|
||||
assert len(layers) == 8
|
||||
self.layers = layers
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
pkv00,
|
||||
pkv01,
|
||||
pkv10,
|
||||
pkv11,
|
||||
pkv20,
|
||||
pkv21,
|
||||
pkv30,
|
||||
pkv31,
|
||||
pkv40,
|
||||
pkv41,
|
||||
pkv50,
|
||||
pkv51,
|
||||
pkv60,
|
||||
pkv61,
|
||||
pkv70,
|
||||
pkv71,
|
||||
):
|
||||
pkvs = [
|
||||
(pkv00, pkv01),
|
||||
(pkv10, pkv11),
|
||||
(pkv20, pkv21),
|
||||
(pkv30, pkv31),
|
||||
(pkv40, pkv41),
|
||||
(pkv50, pkv51),
|
||||
(pkv60, pkv61),
|
||||
(pkv70, pkv71),
|
||||
]
|
||||
new_pkvs = []
|
||||
for layer, pkv in zip(self.layers, pkvs):
|
||||
outputs = layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=(
|
||||
pkv[0],
|
||||
pkv[1],
|
||||
),
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
new_pkvs.append(
|
||||
(
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
)
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
) = new_pkvs
|
||||
return (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
)
|
||||
|
||||
|
||||
class EightLayerLayerFV(torch.nn.Module):
|
||||
def __init__(self, layers):
|
||||
super().__init__()
|
||||
assert len(layers) == 8
|
||||
self.layers = layers
|
||||
|
||||
def forward(self, hidden_states, attention_mask, position_ids):
|
||||
new_pkvs = []
|
||||
for layer in self.layers:
|
||||
outputs = layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=None,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
new_pkvs.append(
|
||||
(
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
)
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
) = new_pkvs
|
||||
return (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
)
|
||||
|
||||
|
||||
class CompiledEightLayerLayerSV(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value,
        output_attentions=False,
        use_cache=True,
    ):
        hidden_states = hidden_states.detach()
        attention_mask = attention_mask.detach()
        position_ids = position_ids.detach()
        # Detach the eight (key, value) pairs of the past key/value cache,
        # keeping them in the flat order expected by the compiled module.
        pkvs = [t.detach() for pair in past_key_value for t in pair]

        output = self.model(
            "forward",
            (
                hidden_states,
                attention_mask,
                position_ids,
                *pkvs,
            ),
            send_to_host=False,
        )
        return (
            output[0],
            *((output[i][0], output[i][1]) for i in range(1, 9)),
        )


def forward_compressed(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||
)
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||
)
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if position_ids is None:
|
||||
device = (
|
||||
input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
)
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.compressedlayers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = (
|
||||
past_key_values[8 * idx : 8 * (idx + 1)]
|
||||
if past_key_values is not None
|
||||
else None
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (
|
||||
layer_outputs[2 if output_attentions else 1],
|
||||
)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_cache,
|
||||
all_hidden_states,
|
||||
all_self_attns,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class CompiledEightLayerLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
t2 = time()
|
||||
if past_key_value is None:
|
||||
try:
    hidden_states = np.asarray(hidden_states, hidden_states.dtype)
except Exception:
    # hidden_states may already be a host-side tensor; leave it unchanged.
    pass
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
t1 = time()
|
||||
|
||||
output = self.model(
|
||||
"first_vicuna_forward",
|
||||
(hidden_states, attention_mask, position_ids),
|
||||
send_to_host=False,
|
||||
)
|
||||
output2 = (
|
||||
output[0],
|
||||
(
|
||||
output[1],
|
||||
output[2],
|
||||
),
|
||||
(
|
||||
output[3],
|
||||
output[4],
|
||||
),
|
||||
(
|
||||
output[5],
|
||||
output[6],
|
||||
),
|
||||
(
|
||||
output[7],
|
||||
output[8],
|
||||
),
|
||||
(
|
||||
output[9],
|
||||
output[10],
|
||||
),
|
||||
(
|
||||
output[11],
|
||||
output[12],
|
||||
),
|
||||
(
|
||||
output[13],
|
||||
output[14],
|
||||
),
|
||||
(
|
||||
output[15],
|
||||
output[16],
|
||||
),
|
||||
)
|
||||
return output2
|
||||
else:
|
||||
(
|
||||
(pkv00, pkv01),
|
||||
(pkv10, pkv11),
|
||||
(pkv20, pkv21),
|
||||
(pkv30, pkv31),
|
||||
(pkv40, pkv41),
|
||||
(pkv50, pkv51),
|
||||
(pkv60, pkv61),
|
||||
(pkv70, pkv71),
|
||||
) = past_key_value
|
||||
|
||||
try:
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
pkv00 = pkv00.detach()
|
||||
pkv01 = pkv01.detach()
|
||||
pkv10 = pkv10.detach()
|
||||
pkv11 = pkv11.detach()
|
||||
pkv20 = pkv20.detach()
|
||||
pkv21 = pkv21.detach()
|
||||
pkv30 = pkv30.detach()
|
||||
pkv31 = pkv31.detach()
|
||||
pkv40 = pkv40.detach()
|
||||
pkv41 = pkv41.detach()
|
||||
pkv50 = pkv50.detach()
|
||||
pkv51 = pkv51.detach()
|
||||
pkv60 = pkv60.detach()
|
||||
pkv61 = pkv61.detach()
|
||||
pkv70 = pkv70.detach()
|
||||
pkv71 = pkv71.detach()
|
||||
except Exception:
    # Some inputs may already be detached host tensors; ignore the failure.
    pass
|
||||
|
||||
t1 = time()
|
||||
if type(hidden_states) == iree.runtime.array_interop.DeviceArray:
|
||||
hidden_states = np.array(hidden_states, hidden_states.dtype)
|
||||
hidden_states = torch.tensor(hidden_states)
|
||||
hidden_states = hidden_states.detach()
|
||||
|
||||
output = self.model(
|
||||
"second_vicuna_forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
pkv00,
|
||||
pkv01,
|
||||
pkv10,
|
||||
pkv11,
|
||||
pkv20,
|
||||
pkv21,
|
||||
pkv30,
|
||||
pkv31,
|
||||
pkv40,
|
||||
pkv41,
|
||||
pkv50,
|
||||
pkv51,
|
||||
pkv60,
|
||||
pkv61,
|
||||
pkv70,
|
||||
pkv71,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
print(f"{time() - t1}")
|
||||
del pkv00
|
||||
del pkv01
|
||||
del pkv10
|
||||
del pkv11
|
||||
del pkv20
|
||||
del pkv21
|
||||
del pkv30
|
||||
del pkv31
|
||||
del pkv40
|
||||
del pkv41
|
||||
del pkv50
|
||||
del pkv51
|
||||
del pkv60
|
||||
del pkv61
|
||||
del pkv70
|
||||
del pkv71
|
||||
output2 = (
|
||||
output[0],
|
||||
(
|
||||
output[1],
|
||||
output[2],
|
||||
),
|
||||
(
|
||||
output[3],
|
||||
output[4],
|
||||
),
|
||||
(
|
||||
output[5],
|
||||
output[6],
|
||||
),
|
||||
(
|
||||
output[7],
|
||||
output[8],
|
||||
),
|
||||
(
|
||||
output[9],
|
||||
output[10],
|
||||
),
|
||||
(
|
||||
output[11],
|
||||
output[12],
|
||||
),
|
||||
(
|
||||
output[13],
|
||||
output[14],
|
||||
),
|
||||
(
|
||||
output[15],
|
||||
output[16],
|
||||
),
|
||||
)
|
||||
return output2
|
||||
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,247 +0,0 @@
|
||||
import torch
|
||||
import time
|
||||
|
||||
|
||||
class FirstVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, hidden_states, attention_mask, position_ids):
|
||||
outputs = self.model(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
use_cache=True,
|
||||
)
|
||||
next_hidden_states = outputs[0]
|
||||
past_key_value_out0, past_key_value_out1 = (
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
|
||||
return (
|
||||
next_hidden_states,
|
||||
past_key_value_out0,
|
||||
past_key_value_out1,
|
||||
)
|
||||
|
||||
|
||||
class SecondVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
):
|
||||
outputs = self.model(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=(
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
),
|
||||
use_cache=True,
|
||||
)
|
||||
next_hidden_states = outputs[0]
|
||||
past_key_value_out0, past_key_value_out1 = (
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
|
||||
return (
|
||||
next_hidden_states,
|
||||
past_key_value_out0,
|
||||
past_key_value_out1,
|
||||
)
|
||||
|
||||
|
||||
class ShardedVicunaModel(torch.nn.Module):
|
||||
def __init__(self, model, layers, lmhead, embedding, norm):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.model.model.config.use_cache = True
|
||||
self.model.model.config.output_attentions = False
|
||||
self.layers = layers
|
||||
self.norm = norm
|
||||
self.embedding = embedding
|
||||
self.lmhead = lmhead
|
||||
self.model.model.norm = self.norm
|
||||
self.model.model.embed_tokens = self.embedding
|
||||
self.model.lm_head = self.lmhead
|
||||
self.model.model.layers = torch.nn.modules.container.ModuleList(
|
||||
self.layers
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
is_first=True,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
return self.model.forward(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
|
||||
class LMHead(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, hidden_states):
|
||||
output = self.model(hidden_states)
|
||||
return output
|
||||
|
||||
|
||||
class LMHeadCompiled(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states_sample = hidden_states.detach()
|
||||
|
||||
output = self.model("forward", (hidden_states,))
|
||||
output = torch.tensor(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class VicunaNorm(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, hidden_states):
|
||||
output = self.model(hidden_states)
|
||||
return output
|
||||
|
||||
|
||||
class VicunaNormCompiled(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(self, hidden_states):
|
||||
try:
|
||||
hidden_states.detach()
|
||||
except:
|
||||
pass
|
||||
output = self.model("forward", (hidden_states,), send_to_host=True)
|
||||
output = torch.tensor(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class VicunaEmbedding(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, input_ids):
|
||||
output = self.model(input_ids)
|
||||
return output
|
||||
|
||||
|
||||
class VicunaEmbeddingCompiled(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(self, input_ids):
|
||||
input_ids.detach()
|
||||
output = self.model("forward", (input_ids,), send_to_host=True)
|
||||
output = torch.tensor(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class CompiledVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, shark_module, idx, breakpoints):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
self.idx = idx
|
||||
self.breakpoints = breakpoints
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
if self.breakpoints is None:
|
||||
is_breakpoint = False
|
||||
else:
|
||||
is_breakpoint = self.idx + 1 in self.breakpoints
|
||||
if past_key_value is None:
|
||||
output = self.model(
|
||||
"first_vicuna_forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
),
|
||||
send_to_host=is_breakpoint,
|
||||
)
|
||||
|
||||
if is_breakpoint:
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
else:
|
||||
output0 = output[0]
|
||||
output1 = output[1]
|
||||
output2 = output[2]
|
||||
|
||||
return (
|
||||
output0,
|
||||
(
|
||||
output1,
|
||||
output2,
|
||||
),
|
||||
)
|
||||
else:
|
||||
pkv0 = past_key_value[0]
|
||||
pkv1 = past_key_value[1]
|
||||
output = self.model(
|
||||
"second_vicuna_forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
pkv0,
|
||||
pkv1,
|
||||
),
|
||||
send_to_host=is_breakpoint,
|
||||
)
|
||||
|
||||
if is_breakpoint:
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
else:
|
||||
output0 = output[0]
|
||||
output1 = output[1]
|
||||
output2 = output[2]
|
||||
|
||||
return (
|
||||
output0,
|
||||
(
|
||||
output1,
|
||||
output2,
|
||||
),
|
||||
)
|
||||
@@ -1,44 +0,0 @@
from abc import ABC, abstractmethod


class SharkLLMBase(ABC):
    def __init__(
        self,
        model_name,
        hf_model_path=None,
        max_num_tokens=512,
    ) -> None:
        self.model_name = model_name
        self.hf_model_path = hf_model_path
        self.max_num_tokens = max_num_tokens
        self.shark_model = None
        self.device = "cpu"
        self.precision = "fp32"

    @classmethod
    @abstractmethod
    def compile(self):
        pass

    @classmethod
    @abstractmethod
    def generate(self, prompt):
        pass

    @classmethod
    @abstractmethod
    def generate_new_token(self, params):
        pass

    @classmethod
    @abstractmethod
    def get_tokenizer(self):
        pass

    @classmethod
    @abstractmethod
    def get_src_model(self):
        pass

    def load_init_from_config(self):
        pass
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,68 +0,0 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from omegaconf import OmegaConf
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode


class BaseProcessor:
    def __init__(self):
        self.transform = lambda x: x
        return

    def __call__(self, item):
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        return cls()

    def build(self, **kwargs):
        cfg = OmegaConf.create(kwargs)

        return self.from_config(cfg)


class BlipImageBaseProcessor(BaseProcessor):
    def __init__(self, mean=None, std=None):
        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)

        self.normalize = transforms.Normalize(mean, std)


class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
    def __init__(self, image_size=224, mean=None, std=None):
        super().__init__(mean=mean, std=std)

        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        if cfg is None:
            cfg = OmegaConf.create()

        image_size = cfg.get("image_size", 224)

        mean = cfg.get("mean", None)
        std = cfg.get("std", None)

        return cls(image_size=image_size, mean=mean, std=std)
@@ -1,5 +0,0 @@
datasets:
  cc_sbu_align:
    data_type: images
    build_info:
      storage: /path/to/cc_sbu_align/
@@ -1,33 +0,0 @@
model:
  arch: mini_gpt4

  # vit encoder
  image_size: 224
  drop_path_rate: 0
  use_grad_checkpoint: False
  vit_precision: "fp16"
  freeze_vit: True
  freeze_qformer: True

  # Q-Former
  num_query_token: 32

  # Vicuna
  llama_model: "lmsys/vicuna-7b-v1.3"

  # generation configs
  prompt: ""

preprocess:
  vis_processor:
    train:
      name: "blip2_image_train"
      image_size: 224
    eval:
      name: "blip2_image_eval"
      image_size: 224
  text_processor:
    train:
      name: "blip_caption"
    eval:
      name: "blip_caption"
@@ -1,25 +0,0 @@
model:
  arch: mini_gpt4
  model_type: pretrain_vicuna
  freeze_vit: True
  freeze_qformer: True
  max_txt_len: 160
  end_sym: "###"
  low_resource: False
  prompt_path: "apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt"
  prompt_template: '###Human: {} ###Assistant: '
  ckpt: 'prerained_minigpt4_7b.pth'


datasets:
  cc_sbu_align:
    vis_processor:
      train:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"

run:
  task: image_text_pretrain
@@ -1,629 +0,0 @@
# Based on EVA, BEIT, timm and DeiT code bases
# https://github.com/baaivision/EVA
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dino
# --------------------------------------------------------'
import math
import requests
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_


def _cfg(url="", **kwargs):
    return {
        "url": url,
        "num_classes": 1000,
        "input_size": (3, 224, 224),
        "pool_size": None,
        "crop_pct": 0.9,
        "interpolation": "bicubic",
        "mean": (0.5, 0.5, 0.5),
        "std": (0.5, 0.5, 0.5),
        **kwargs,
    }


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        # x = self.drop(x)
        # commented out to match the original BERT implementation
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
attn_drop=0.0,
|
||||
proj_drop=0.0,
|
||||
window_size=None,
|
||||
attn_head_dim=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
if attn_head_dim is not None:
|
||||
head_dim = attn_head_dim
|
||||
all_head_dim = head_dim * self.num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
||||
if qkv_bias:
|
||||
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||
else:
|
||||
self.q_bias = None
|
||||
self.v_bias = None
|
||||
|
||||
if window_size:
|
||||
self.window_size = window_size
|
||||
self.num_relative_distance = (2 * window_size[0] - 1) * (
|
||||
2 * window_size[1] - 1
|
||||
) + 3
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros(self.num_relative_distance, num_heads)
|
||||
) # 2*Wh-1 * 2*Ww-1, nH
|
||||
# cls to token & token 2 cls & cls to cls
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(window_size[0])
|
||||
coords_w = torch.arange(window_size[1])
|
||||
coords = torch.stack(
|
||||
torch.meshgrid([coords_h, coords_w])
|
||||
) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = (
|
||||
coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
||||
) # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(
|
||||
1, 2, 0
|
||||
).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += (
|
||||
window_size[0] - 1
|
||||
) # shift to start from 0
|
||||
relative_coords[:, :, 1] += window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
||||
relative_position_index = torch.zeros(
|
||||
size=(window_size[0] * window_size[1] + 1,) * 2,
|
||||
dtype=relative_coords.dtype,
|
||||
)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(
|
||||
-1
|
||||
) # Wh*Ww, Wh*Ww
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
|
||||
self.register_buffer(
|
||||
"relative_position_index", relative_position_index
|
||||
)
|
||||
else:
|
||||
self.window_size = None
|
||||
self.relative_position_bias_table = None
|
||||
self.relative_position_index = None
|
||||
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(all_head_dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x, rel_pos_bias=None):
|
||||
B, N, C = x.shape
|
||||
qkv_bias = None
|
||||
if self.q_bias is not None:
|
||||
qkv_bias = torch.cat(
|
||||
(
|
||||
self.q_bias,
|
||||
torch.zeros_like(self.v_bias, requires_grad=False),
|
||||
self.v_bias,
|
||||
)
|
||||
)
|
||||
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = (
|
||||
qkv[0],
|
||||
qkv[1],
|
||||
qkv[2],
|
||||
) # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
|
||||
if self.relative_position_bias_table is not None:
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)
|
||||
].view(
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
-1,
|
||||
) # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.permute(
|
||||
2, 0, 1
|
||||
).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if rel_pos_bias is not None:
|
||||
attn = attn + rel_pos_bias
|
||||
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop=0.0,
|
||||
attn_drop=0.0,
|
||||
drop_path=0.0,
|
||||
init_values=None,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
window_size=None,
|
||||
attn_head_dim=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
window_size=window_size,
|
||||
attn_head_dim=attn_head_dim,
|
||||
)
|
||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||
self.drop_path = (
|
||||
DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
)
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
if init_values is not None and init_values > 0:
|
||||
self.gamma_1 = nn.Parameter(
|
||||
init_values * torch.ones((dim)), requires_grad=True
|
||||
)
|
||||
self.gamma_2 = nn.Parameter(
|
||||
init_values * torch.ones((dim)), requires_grad=True
|
||||
)
|
||||
else:
|
||||
self.gamma_1, self.gamma_2 = None, None
|
||||
|
||||
def forward(self, x, rel_pos_bias=None):
|
||||
if self.gamma_1 is None:
|
||||
x = x + self.drop_path(
|
||||
self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
|
||||
)
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
else:
|
||||
x = x + self.drop_path(
|
||||
self.gamma_1
|
||||
* self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
|
||||
)
|
||||
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""Image to Patch Embedding"""
|
||||
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
num_patches = (img_size[1] // patch_size[1]) * (
|
||||
img_size[0] // patch_size[0]
|
||||
)
|
||||
self.patch_shape = (
|
||||
img_size[0] // patch_size[0],
|
||||
img_size[1] // patch_size[1],
|
||||
)
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.num_patches = num_patches
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
|
||||
)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
B, C, H, W = x.shape
|
||||
# FIXME look at relaxing size constraints
|
||||
assert (
|
||||
H == self.img_size[0] and W == self.img_size[1]
|
||||
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
x = self.proj(x).flatten(2).transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
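A minimal sanity check of the `PatchEmbed` class above, assuming the default 224x224 input and 16x16 patches; this sketch is illustrative and not part of the original module:

```python
import torch

# Illustrative sketch: 224 / 16 = 14 patches per side -> 196 tokens of width embed_dim.
patch_embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
out = patch_embed(torch.randn(2, 3, 224, 224))
print(patch_embed.num_patches)  # 196
print(out.shape)                # torch.Size([2, 196, 768])
```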
class RelativePositionBias(nn.Module):
|
||||
def __init__(self, window_size, num_heads):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.num_relative_distance = (2 * window_size[0] - 1) * (
|
||||
2 * window_size[1] - 1
|
||||
) + 3
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros(self.num_relative_distance, num_heads)
|
||||
) # 2*Wh-1 * 2*Ww-1, nH
|
||||
# cls to token, token to cls, and cls to cls
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(window_size[0])
|
||||
coords_w = torch.arange(window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = (
|
||||
coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
||||
) # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(
|
||||
1, 2, 0
|
||||
).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
||||
relative_coords[:, :, 1] += window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
||||
relative_position_index = torch.zeros(
|
||||
size=(window_size[0] * window_size[1] + 1,) * 2,
|
||||
dtype=relative_coords.dtype,
|
||||
)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(
|
||||
-1
|
||||
) # Wh*Ww, Wh*Ww
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
|
||||
self.register_buffer(
|
||||
"relative_position_index", relative_position_index
|
||||
)
|
||||
|
||||
# trunc_normal_(self.relative_position_bias_table, std=.02)
|
||||
|
||||
def forward(self):
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)
|
||||
].view(
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
-1,
|
||||
) # Wh*Ww,Wh*Ww,nH
|
||||
return relative_position_bias.permute(
|
||||
2, 0, 1
|
||||
).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
|
||||
|
||||
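To make the index construction in `RelativePositionBias.__init__` concrete, here is a small illustrative sketch (assumptions: a tiny 2x2 window and patch tokens only, i.e. without the three extra cls rows/columns) that mirrors the same arithmetic:

```python
import torch

def patch_relative_position_index(window_size=(2, 2)):
    # Same construction as above, restricted to the Wh*Ww patch tokens.
    coords_h = torch.arange(window_size[0])
    coords_w = torch.arange(window_size[1])
    coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)                   # 2, Wh*Ww
    rel = coords_flatten[:, :, None] - coords_flatten[:, None, :]
    rel = rel.permute(1, 2, 0).contiguous()                     # Wh*Ww, Wh*Ww, 2
    rel[:, :, 0] += window_size[0] - 1                          # shift to start from 0
    rel[:, :, 1] += window_size[1] - 1
    rel[:, :, 0] *= 2 * window_size[1] - 1
    return rel.sum(-1)                                          # Wh*Ww, Wh*Ww

# Each entry is an index into a (2*Wh-1)*(2*Ww-1) bias table, here 3*3 = 9 entries.
print(patch_relative_position_index())
```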
class VisionTransformer(nn.Module):
|
||||
"""Vision Transformer with support for patch or hybrid CNN input stage"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop_rate=0.0,
|
||||
attn_drop_rate=0.0,
|
||||
drop_path_rate=0.0,
|
||||
norm_layer=nn.LayerNorm,
|
||||
init_values=None,
|
||||
use_abs_pos_emb=True,
|
||||
use_rel_pos_bias=False,
|
||||
use_shared_rel_pos_bias=False,
|
||||
use_mean_pooling=True,
|
||||
init_scale=0.001,
|
||||
use_checkpoint=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.image_size = img_size
|
||||
self.num_classes = num_classes
|
||||
self.num_features = (
|
||||
self.embed_dim
|
||||
) = embed_dim # num_features for consistency with other models
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
if use_abs_pos_emb:
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, num_patches + 1, embed_dim)
|
||||
)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
if use_shared_rel_pos_bias:
|
||||
self.rel_pos_bias = RelativePositionBias(
|
||||
window_size=self.patch_embed.patch_shape, num_heads=num_heads
|
||||
)
|
||||
else:
|
||||
self.rel_pos_bias = None
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
||||
] # stochastic depth decay rule
|
||||
self.use_rel_pos_bias = use_rel_pos_bias
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Block(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
init_values=init_values,
|
||||
window_size=self.patch_embed.patch_shape
|
||||
if use_rel_pos_bias
|
||||
else None,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
# self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
|
||||
# self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
|
||||
# self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
if self.pos_embed is not None:
|
||||
trunc_normal_(self.pos_embed, std=0.02)
|
||||
trunc_normal_(self.cls_token, std=0.02)
|
||||
# trunc_normal_(self.mask_token, std=.02)
|
||||
# if isinstance(self.head, nn.Linear):
|
||||
# trunc_normal_(self.head.weight, std=.02)
|
||||
self.apply(self._init_weights)
|
||||
self.fix_init_weight()
|
||||
|
||||
# if isinstance(self.head, nn.Linear):
|
||||
# self.head.weight.data.mul_(init_scale)
|
||||
# self.head.bias.data.mul_(init_scale)
|
||||
|
||||
def fix_init_weight(self):
|
||||
def rescale(param, layer_id):
|
||||
param.div_(math.sqrt(2.0 * layer_id))
|
||||
|
||||
for layer_id, layer in enumerate(self.blocks):
|
||||
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
||||
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=""):
|
||||
self.num_classes = num_classes
|
||||
self.head = (
|
||||
nn.Linear(self.embed_dim, num_classes)
|
||||
if num_classes > 0
|
||||
else nn.Identity()
|
||||
)
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
batch_size, seq_len, _ = x.size()
|
||||
|
||||
cls_tokens = self.cls_token.expand(
|
||||
batch_size, -1, -1
|
||||
) # stole cls_tokens impl from Phil Wang, thanks
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
rel_pos_bias = (
|
||||
self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
||||
)
|
||||
for blk in self.blocks:
|
||||
if self.use_checkpoint:
|
||||
x = checkpoint.checkpoint(blk, x, rel_pos_bias)
|
||||
else:
|
||||
x = blk(x, rel_pos_bias)
|
||||
return x
|
||||
|
||||
# x = self.norm(x)
|
||||
|
||||
# if self.fc_norm is not None:
|
||||
# t = x[:, 1:, :]
|
||||
# return self.fc_norm(t.mean(1))
|
||||
# else:
|
||||
# return x[:, 0]
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
# x = self.head(x)
|
||||
return x
|
||||
|
||||
def get_intermediate_layers(self, x):
|
||||
x = self.patch_embed(x)
|
||||
batch_size, seq_len, _ = x.size()
|
||||
|
||||
cls_tokens = self.cls_token.expand(
|
||||
batch_size, -1, -1
|
||||
) # stole cls_tokens impl from Phil Wang, thanks
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
features = []
|
||||
rel_pos_bias = (
|
||||
self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
||||
)
|
||||
for blk in self.blocks:
|
||||
x = blk(x, rel_pos_bias)
|
||||
features.append(x)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
def interpolate_pos_embed(model, checkpoint_model):
|
||||
if "pos_embed" in checkpoint_model:
|
||||
pos_embed_checkpoint = checkpoint_model["pos_embed"].float()
|
||||
embedding_size = pos_embed_checkpoint.shape[-1]
|
||||
num_patches = model.patch_embed.num_patches
|
||||
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
|
||||
# height (== width) for the checkpoint position embedding
|
||||
orig_size = int(
|
||||
(pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5
|
||||
)
|
||||
# height (== width) for the new position embedding
|
||||
new_size = int(num_patches**0.5)
|
||||
# class_token and dist_token are kept unchanged
|
||||
if orig_size != new_size:
|
||||
print(
|
||||
"Position interpolate from %dx%d to %dx%d"
|
||||
% (orig_size, orig_size, new_size, new_size)
|
||||
)
|
||||
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
||||
# only the position tokens are interpolated
|
||||
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
||||
pos_tokens = pos_tokens.reshape(
|
||||
-1, orig_size, orig_size, embedding_size
|
||||
).permute(0, 3, 1, 2)
|
||||
pos_tokens = torch.nn.functional.interpolate(
|
||||
pos_tokens,
|
||||
size=(new_size, new_size),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
||||
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
||||
checkpoint_model["pos_embed"] = new_pos_embed
|
||||
|
||||
|
||||
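A hedged usage sketch of `interpolate_pos_embed`: the helper resizes a checkpoint's position embedding when the target model uses a different input resolution. The checkpoint file name and the 336px target below are assumptions for illustration only.

```python
import torch

# Build a model at a higher resolution than the 224px checkpoint (illustrative values).
model = VisionTransformer(
    img_size=336, patch_size=14, embed_dim=1408, depth=39,
    num_heads=16, mlp_ratio=4.3637, qkv_bias=True, use_mean_pooling=False,
)
state_dict = torch.load("eva_vit_g.pth", map_location="cpu")  # assumed local copy
interpolate_pos_embed(model, state_dict)  # resizes state_dict["pos_embed"] in place
model.load_state_dict(state_dict, strict=False)
```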
def convert_weights_to_fp16(model: nn.Module):
|
||||
"""Convert applicable model parameters to fp16"""
|
||||
|
||||
def _convert_weights_to_fp16(l):
|
||||
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
||||
# l.weight.data = l.weight.data.half()
|
||||
l.weight.data = l.weight.data
|
||||
if l.bias is not None:
|
||||
# l.bias.data = l.bias.data.half()
|
||||
l.bias.data = l.bias.data
|
||||
|
||||
# if isinstance(l, (nn.MultiheadAttention, Attention)):
|
||||
# for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
||||
# tensor = getattr(l, attr)
|
||||
# if tensor is not None:
|
||||
# tensor.data = tensor.data.half()
|
||||
|
||||
model.apply(_convert_weights_to_fp16)
|
||||
|
||||
|
||||
def create_eva_vit_g(
|
||||
img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16"
|
||||
):
|
||||
model = VisionTransformer(
|
||||
img_size=img_size,
|
||||
patch_size=14,
|
||||
use_mean_pooling=False,
|
||||
embed_dim=1408,
|
||||
depth=39,
|
||||
num_heads=1408 // 88,
|
||||
mlp_ratio=4.3637,
|
||||
qkv_bias=True,
|
||||
drop_path_rate=drop_path_rate,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
use_checkpoint=use_checkpoint,
|
||||
)
|
||||
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
|
||||
|
||||
local_filename = "eva_vit_g.pth"
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
with open(local_filename, "wb") as f:
|
||||
f.write(response.content)
|
||||
print("File downloaded successfully.")
|
||||
state_dict = torch.load(local_filename, map_location="cpu")
|
||||
interpolate_pos_embed(model, state_dict)
|
||||
|
||||
incompatible_keys = model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if precision == "fp16":
|
||||
# model.to("cuda")
|
||||
convert_weights_to_fp16(model)
|
||||
return model
|
||||
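A rough usage sketch of `create_eva_vit_g` (note that, as written, it downloads the full eva_vit_g.pth checkpoint on every call); the expected output shape follows from 224/14 = 16 patches per side plus one cls token:

```python
import torch

vit = create_eva_vit_g(img_size=224, drop_path_rate=0.0, precision="fp32")
with torch.no_grad():
    feats = vit(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 257, 1408]) -- 256 patch tokens + 1 cls token
```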
@@ -1,4 +0,0 @@
|
||||
<Img><ImageHere></Img> Describe this image in detail.
|
||||
<Img><ImageHere></Img> Take a look at this image and describe what you notice.
|
||||
<Img><ImageHere></Img> Please provide a detailed description of the picture.
|
||||
<Img><ImageHere></Img> Could you describe the contents of this image for me?
|
||||
@@ -1,300 +0,0 @@
|
||||
import torch
|
||||
import torch_mlir
|
||||
from transformers import AutoTokenizer, StoppingCriteria, AutoModelForCausalLM
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from apps.language_models.utils import (
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from apps.language_models.src.model_wrappers.stablelm_model import (
|
||||
StableLMModel,
|
||||
)
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="stablelm runner",
|
||||
description="runs a StableLM model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"]
|
||||
)
|
||||
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
|
||||
parser.add_argument(
|
||||
"--stablelm_vmfb_path", default=None, help="path to StableLM's vmfb"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stablelm_mlir_path",
|
||||
default=None,
|
||||
help="path to StableLM's mlir file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_precompiled_model",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="use the precompiled vmfb",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_mlir_from_shark_tank",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="download precompile mlir from shark tank",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf_auth_token",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specify your own huggingface authentication token for stablelm-3B model.",
|
||||
)
|
||||
|
||||
|
||||
class StopOnTokens(StoppingCriteria):
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
||||
) -> bool:
|
||||
stop_ids = [50278, 50279, 50277, 1, 0]
|
||||
for stop_id in stop_ids:
|
||||
if input_ids[0][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class SharkStableLM(SharkLLMBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path="stabilityai/stablelm-tuned-alpha-3b",
|
||||
max_num_tokens=256,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
debug="False",
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_sequence_len = 256
|
||||
self.device = device
|
||||
if precision != "int4" and args.hf_auth_token == None:
|
||||
raise ValueError(
|
||||
""" HF auth token required for StableLM-3B. Pass it using
|
||||
--hf_auth_token flag. You can request access to the model
|
||||
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
|
||||
)
|
||||
self.hf_auth_token = args.hf_auth_token
|
||||
|
||||
self.precision = precision
|
||||
self.debug = debug
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.shark_model = self.compile()
|
||||
|
||||
def shouldStop(self, tokens):
|
||||
stop_ids = [50278, 50279, 50277, 1, 0]
|
||||
for stop_id in stop_ids:
|
||||
if tokens[0][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_src_model(self):
|
||||
kwargs = {}
|
||||
if self.precision == "int4":
|
||||
self.hf_model_path = "TheBloke/stablelm-zephyr-3b-GPTQ"
|
||||
from transformers import GPTQConfig
|
||||
|
||||
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
|
||||
kwargs["quantization_config"] = quantization_config
|
||||
kwargs["device_map"] = "cpu"
|
||||
print("[DEBUG] Loading Model")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path,
|
||||
trust_remote_code=True,
|
||||
torch_dtype=torch.float32,
|
||||
use_auth_token=self.hf_auth_token,
|
||||
**kwargs,
|
||||
)
|
||||
print("[DEBUG] Model loaded successfully")
|
||||
return model
|
||||
|
||||
def get_model_inputs(self):
|
||||
input_ids = torch.randint(3, (1, self.max_sequence_len))
|
||||
attention_mask = torch.randint(3, (1, self.max_sequence_len))
|
||||
return input_ids, attention_mask
|
||||
|
||||
def compile(self):
|
||||
tmp_model_name = f"{self.model_name}_linalg_{self.precision}_seqLen{self.max_sequence_len}"
|
||||
|
||||
# device = "cuda" # "cpu"
|
||||
# TODO: vmfb and mlir name should include precision and device
|
||||
model_vmfb_name = None
|
||||
vmfb_path = (
|
||||
Path(tmp_model_name + f"_{self.device}.vmfb")
|
||||
if model_vmfb_name is None
|
||||
else Path(model_vmfb_name)
|
||||
)
|
||||
shark_module = get_vmfb_from_path(
|
||||
vmfb_path, self.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
if shark_module is not None:
|
||||
return shark_module
|
||||
|
||||
mlir_path = Path(tmp_model_name + ".mlir")
|
||||
print(
|
||||
f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
if not mlir_path.exists():
|
||||
model = StableLMModel(self.get_src_model())
|
||||
model_inputs = self.get_model_inputs()
|
||||
from shark.shark_importer import import_with_fx
|
||||
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
model_inputs,
|
||||
is_f16=True if self.precision in ["fp16"] else False,
|
||||
precision=self.precision,
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*model_inputs],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
f_ = open(mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
print("Saved mlir at: ", mlir_path)
|
||||
f_.close()
|
||||
del bytecode
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=mlir_path, device=self.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
path = shark_module.save_module(
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem, debug=self.debug
|
||||
)
|
||||
print("Saved vmfb at ", str(path))
|
||||
|
||||
return shark_module
|
||||
|
||||
def get_tokenizer(self):
|
||||
tok = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path,
|
||||
use_auth_token=self.hf_auth_token,
|
||||
)
|
||||
tok.add_special_tokens({"pad_token": "<PAD>"})
|
||||
# print("[DEBUG] Sucessfully loaded the tokenizer to the memory")
|
||||
return tok
|
||||
|
||||
def generate(self, prompt):
|
||||
words_list = []
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
count = 0
|
||||
for i in range(self.max_num_tokens):
|
||||
count = count + 1
|
||||
params = {
|
||||
"new_text": prompt,
|
||||
}
|
||||
|
||||
generated_token_op = self.generate_new_token(params)
|
||||
|
||||
detok = generated_token_op["detok"]
|
||||
stop_generation = generated_token_op["stop_generation"]
|
||||
|
||||
if stop_generation:
|
||||
break
|
||||
|
||||
print(detok, end="", flush=True) # this is for CLI and DEBUG
|
||||
words_list.append(detok)
|
||||
if detok == "":
|
||||
break
|
||||
prompt = prompt + detok
|
||||
end = time.time()
|
||||
print(
|
||||
"\n\nTime taken is {:.2f} tokens/second\n".format(
|
||||
count / (end - start)
|
||||
)
|
||||
)
|
||||
return words_list
|
||||
|
||||
def generate_new_token(self, params):
|
||||
new_text = params["new_text"]
|
||||
model_inputs = self.tokenizer(
|
||||
[new_text],
|
||||
padding="max_length",
|
||||
max_length=self.max_sequence_len,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
sum_attentionmask = torch.sum(model_inputs.attention_mask)
|
||||
output = self.shark_model(
|
||||
"forward", [model_inputs.input_ids, model_inputs.attention_mask]
|
||||
)
|
||||
output = torch.from_numpy(output)
|
||||
next_toks = torch.topk(output, 1)
|
||||
stop_generation = False
|
||||
if self.shouldStop(next_toks.indices):
|
||||
stop_generation = True
|
||||
new_token = next_toks.indices[0][int(sum_attentionmask) - 1]
|
||||
detok = self.tokenizer.decode(
|
||||
new_token,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
ret_dict = {
|
||||
"new_token": new_token,
|
||||
"detok": detok,
|
||||
"stop_generation": stop_generation,
|
||||
}
|
||||
return ret_dict
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
stable_lm = SharkStableLM(
|
||||
model_name="stablelm_zephyr_3b",
|
||||
hf_model_path="stabilityai/stablelm-zephyr-3b",
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
)
|
||||
|
||||
default_prompt_text = "The weather is always wonderful"
|
||||
continue_execution = True
|
||||
|
||||
print("\n-----\nScript executing for the following config: \n")
|
||||
print("StableLM Model: ", stable_lm.hf_model_path)
|
||||
print("Precision: ", args.precision)
|
||||
print("Device: ", args.device)
|
||||
|
||||
while continue_execution:
|
||||
use_default_prompt = input(
|
||||
"\nDo you wish to use the default prompt text? Y/N ?: "
|
||||
)
|
||||
if use_default_prompt in ["Y", "y"]:
|
||||
prompt = default_prompt_text
|
||||
else:
|
||||
prompt = input("Please enter the prompt text: ")
|
||||
print("\nPrompt Text: ", prompt)
|
||||
|
||||
res_str = stable_lm.generate(prompt)
|
||||
torch.cuda.empty_cache()
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
print(
|
||||
"\n\n-----\nHere's the complete formatted result: \n\n",
|
||||
prompt + "".join(res_str),
|
||||
)
|
||||
continue_execution = input(
|
||||
"\nDo you wish to run script one more time? Y/N ?: "
|
||||
)
|
||||
continue_execution = continue_execution in ["Y", "y"]
|
||||
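Besides the interactive loop above, the runner can be driven programmatically. A sketch under stated assumptions: the flags must be parsed first because the class reads the module-level `args` for the auth token, and the int4 path avoids needing a HuggingFace token.

```python
# Illustrative only; assumes parser and SharkStableLM from this module are in scope.
args = parser.parse_args(["--device", "cpu", "--precision", "int4"])

lm = SharkStableLM(
    model_name="stablelm_zephyr_3b",
    hf_model_path="stabilityai/stablelm-zephyr-3b",
    device=args.device,
    precision=args.precision,
)
print("".join(lm.generate("The weather is always wonderful")))
```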
@@ -1,48 +0,0 @@
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
from shark.shark_downloader import download_public_file
|
||||
|
||||
|
||||
# expects a Path / str as arg
|
||||
# returns None if path not found or SharkInference module
|
||||
def get_vmfb_from_path(vmfb_path, device, mlir_dialect, device_id=None):
|
||||
if not isinstance(vmfb_path, Path):
|
||||
vmfb_path = Path(vmfb_path)
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
if not vmfb_path.exists():
|
||||
return None
|
||||
|
||||
print("Loading vmfb from: ", vmfb_path)
|
||||
print("Device from get_vmfb_from_path - ", device)
|
||||
shark_module = SharkInference(
|
||||
None, device=device, mlir_dialect=mlir_dialect, device_idx=device_id
|
||||
)
|
||||
shark_module.load_module(vmfb_path)
|
||||
print("Successfully loaded vmfb")
|
||||
return shark_module
|
||||
|
||||
|
||||
def get_vmfb_from_config(
|
||||
shark_container,
|
||||
model,
|
||||
precision,
|
||||
device,
|
||||
vmfb_path,
|
||||
padding=None,
|
||||
device_id=None,
|
||||
):
|
||||
vmfb_url = (
|
||||
f"gs://shark_tank/{shark_container}/{model}_{precision}_{device}"
|
||||
)
|
||||
if padding:
|
||||
vmfb_url = vmfb_url + f"_{padding}"
|
||||
vmfb_url = vmfb_url + ".vmfb"
|
||||
download_public_file(vmfb_url, vmfb_path.absolute(), single_file=True)
|
||||
return get_vmfb_from_path(
|
||||
vmfb_path, device, "tm_tensor", device_id=device_id
|
||||
)
|
||||
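A sketch of how these two helpers are typically chained: try a local .vmfb first and fall back to downloading one from shark_tank. The container and model names below are placeholders, not real artifacts.

```python
from pathlib import Path

vmfb_path = Path("my_model_fp32_cpu.vmfb")  # hypothetical file name
shark_module = get_vmfb_from_path(vmfb_path, device="cpu", mlir_dialect="tm_tensor")
if shark_module is None:
    shark_module = get_vmfb_from_config(
        shark_container="example_container",  # placeholder gs://shark_tank/<container>
        model="my_model",                     # placeholder model name
        precision="fp32",
        device="cpu",
        vmfb_path=vmfb_path,
    )
```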
@@ -1,87 +0,0 @@
|
||||
Compile / Run Instructions:
|
||||
|
||||
To compile a .vmfb for SD (vae, unet, CLIP), run the following commands with the .mlir files in your local shark_tank cache (the default location on Linux is `~/.local/shark_tank`). These files become available the first time the script from [this README](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md) is run.
|
||||
Running the script mentioned above with the `--save_vmfb` flag will also save the .vmfb in your SHARK base directory if you want to skip straight to benchmarks.
|
||||
|
||||
Compile Commands FP32/FP16:
|
||||
|
||||
```shell
|
||||
Vulkan AMD:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux /path/to/input/mlir -o /path/to/output/vmfb
|
||||
|
||||
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
|
||||
# use --iree-input-type=auto or "mhlo_legacy" or "stablehlo" for TF models
|
||||
|
||||
CUDA NVIDIA:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda /path/to/input/mlir -o /path/to/output/vmfb
|
||||
|
||||
CPU:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu /path/to/input/mlir -o /path/to/output/vmfb
|
||||
```
|
||||
|
||||
|
||||
|
||||
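If you prefer scripting the step above instead of invoking `iree-compile` by hand, a minimal Python wrapper over the same command line could look like this sketch (paths and the extra Vulkan flag are placeholders to adjust per target):

```python
import subprocess

def compile_vmfb(mlir_path, vmfb_path, backend="vulkan",
                 extra_flags=("--iree-vulkan-target-triple=rdna2-unknown-linux",)):
    # Mirrors the iree-compile invocations shown above.
    cmd = [
        "iree-compile",
        "--iree-input-type=none",
        f"--iree-hal-target-backends={backend}",
        *extra_flags,
        str(mlir_path),
        "-o",
        str(vmfb_path),
    ]
    subprocess.run(cmd, check=True)

# compile_vmfb("/path/to/input.mlir", "/path/to/output.vmfb")
```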
Run / Benchmark Command (FP32 - NCHW):
|
||||
(NOTE: batch size must be 2, since classifier-free guidance requires two forward passes through the UNet.)
|
||||
|
||||
```shell
|
||||
## Vulkan AMD:
|
||||
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --device=vulkan --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
|
||||
|
||||
## CUDA:
|
||||
iree-benchmark-module --module=/path/to/vmfb --function=forward --device=cuda --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
|
||||
|
||||
## CPU:
|
||||
iree-benchmark-module --module=/path/to/vmfb --function=forward --device=local-task --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
|
||||
|
||||
```
|
||||
|
||||
Run via vulkan_gui for RGP Profiling:
|
||||
|
||||
To build the vulkan app for profiling UNet follow the instructions [here](https://github.com/nod-ai/SHARK/tree/main/cpp) and then run the following command from the cpp directory with your compiled stable_diff.vmfb
|
||||
```shell
|
||||
./build/vulkan_gui/iree-vulkan-gui --module=/path/to/unet.vmfb --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
|
||||
```
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>Debug Commands</summary>
|
||||
|
||||
## Debug commands and other advanced usage follow.
|
||||
|
||||
```shell
|
||||
python txt2img.py --precision="fp32"|"fp16" --device="cpu"|"cuda"|"vulkan" --import_mlir|--no-import_mlir --prompt "enter the text"
|
||||
```
|
||||
|
||||
## dump all dispatch .spv and isa using amdllpc
|
||||
|
||||
```shell
|
||||
python txt2img.py --precision="fp16" --device="vulkan" --iree-vulkan-target-triple=rdna3-unknown-linux --no-load_vmfb --dispatch_benchmarks="all" --dispatch_benchmarks_dir="SD_dispatches" --dump_isa
|
||||
```
|
||||
|
||||
## Compile and save the .vmfb (using vulkan fp16 as an example):
|
||||
|
||||
```shell
|
||||
python txt2img.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb
|
||||
```
|
||||
|
||||
## Capture an RGP trace
|
||||
|
||||
```shell
|
||||
python txt2img.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb --enable_rgp
|
||||
```
|
||||
|
||||
## Run the vae module with iree-benchmark-module (NCHW, fp16, vulkan, for example):
|
||||
|
||||
```shell
|
||||
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --device=vulkan --input=1x4x64x64xf16
|
||||
```
|
||||
|
||||
## Run the unet module with iree-benchmark-module (same config as above):
|
||||
```shell
|
||||
##if you want to use .npz inputs:
|
||||
unzip ~/.local/shark_tank/<your unet>/inputs.npz
|
||||
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --input=@arr_0.npy --input=1xf16 --input=@arr_2.npy --input=@arr_3.npy --input=@arr_4.npy
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -1 +0,0 @@
|
||||
from apps.stable_diffusion.scripts.train_lora_word import lora_train
|
||||
@@ -1,128 +0,0 @@
|
||||
import sys
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
import transformers
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Image2ImagePipeline,
|
||||
StencilPipeline,
|
||||
resize_stencil,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
|
||||
def main():
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
if args.img_path is None:
|
||||
print("Flag --img_path is required.")
|
||||
exit()
|
||||
|
||||
image = Image.open(args.img_path).convert("RGB")
|
||||
# When the models get uploaded, this should default to False.
|
||||
args.import_mlir = True
|
||||
|
||||
use_stencil = args.use_stencil
|
||||
if use_stencil:
|
||||
args.scheduler = "DDIM"
|
||||
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
|
||||
image, args.width, args.height = resize_stencil(image)
|
||||
elif "Shark" in args.scheduler:
|
||||
print(
|
||||
f"Shark schedulers are not supported. Switching to EulerDiscrete scheduler"
|
||||
)
|
||||
args.scheduler = "EulerDiscrete"
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
set_init_device_flags()
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
seed = utils.sanitize_seed(args.seed)
|
||||
# Adjust for height and width based on model
|
||||
|
||||
if use_stencil:
|
||||
img2img_obj = StencilPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
use_stencil=use_stencil,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
else:
|
||||
img2img_obj = Image2ImagePipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
generated_imgs = img2img_obj.generate_images(
|
||||
args.prompts,
|
||||
args.negative_prompts,
|
||||
image,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.steps,
|
||||
args.strength,
|
||||
args.guidance_scale,
|
||||
seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
use_stencil=use_stencil,
|
||||
control_mode=args.control_mode,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
||||
text_output += f"\nsteps={args.steps}, strength={args.strength}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
text_output += img2img_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
extra_info = {"STRENGTH": args.strength}
|
||||
save_output_img(generated_imgs[0], seed, extra_info)
|
||||
print(text_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,105 +0,0 @@
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
import transformers
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
InpaintPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
|
||||
def main():
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
if args.img_path is None:
|
||||
print("Flag --img_path is required.")
|
||||
exit()
|
||||
if args.mask_path is None:
|
||||
print("Flag --mask_path is required.")
|
||||
exit()
|
||||
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
set_init_device_flags()
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if "inpaint" in args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-inpainting"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
seed = args.seed
|
||||
image = Image.open(args.img_path)
|
||||
mask_image = Image.open(args.mask_path)
|
||||
|
||||
inpaint_obj = InpaintPipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
model_id=args.hf_model_id,
|
||||
ckpt_loc=args.ckpt_loc,
|
||||
custom_vae=args.custom_vae,
|
||||
precision=args.precision,
|
||||
max_length=args.max_length,
|
||||
batch_size=args.batch_size,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
use_tuned=args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
|
||||
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
|
||||
for current_batch in range(args.batch_count):
|
||||
start_time = time.time()
|
||||
generated_imgs = inpaint_obj.generate_images(
|
||||
args.prompts,
|
||||
args.negative_prompts,
|
||||
image,
|
||||
mask_image,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.inpaint_full_res,
|
||||
args.inpaint_full_res_padding,
|
||||
args.steps,
|
||||
args.guidance_scale,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += (
|
||||
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
)
|
||||
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
||||
text_output += (
|
||||
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
|
||||
)
|
||||
text_output += f"seed={seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
text_output += inpaint_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
save_output_img(generated_imgs[0], seed)
|
||||
print(text_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,19 +0,0 @@
|
||||
from apps.stable_diffusion.src import args
|
||||
from apps.stable_diffusion.scripts import (
|
||||
img2img,
|
||||
txt2img,
|
||||
# inpaint,
|
||||
# outpaint,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.app == "txt2img":
|
||||
txt2img.main()
|
||||
elif args.app == "img2img":
|
||||
img2img.main()
|
||||
# elif args.app == "inpaint":
|
||||
# inpaint.main()
|
||||
# elif args.app == "outpaint":
|
||||
# outpaint.main()
|
||||
else:
|
||||
print(f"args.app value is {args.app} but this isn't supported")
|
||||
@@ -1,120 +0,0 @@
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
import transformers
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
OutpaintPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
if args.img_path is None:
|
||||
print("Flag --img_path is required.")
|
||||
exit()
|
||||
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
set_init_device_flags()
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if "inpaint" in args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-inpainting"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
seed = args.seed
|
||||
image = Image.open(args.img_path)
|
||||
|
||||
outpaint_obj = OutpaintPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
use_lora=args.use_lora,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
|
||||
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
|
||||
for current_batch in range(args.batch_count):
|
||||
start_time = time.time()
|
||||
generated_imgs = outpaint_obj.generate_images(
|
||||
args.prompts,
|
||||
args.negative_prompts,
|
||||
image,
|
||||
args.pixels,
|
||||
args.mask_blur,
|
||||
args.left,
|
||||
args.right,
|
||||
args.top,
|
||||
args.bottom,
|
||||
args.noise_q,
|
||||
args.color_variation,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.steps,
|
||||
args.guidance_scale,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += (
|
||||
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
)
|
||||
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
||||
text_output += (
|
||||
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
|
||||
)
|
||||
text_output += f"seed={seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
text_output += outpaint_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
# save this information as metadata of output generated image.
|
||||
directions = []
|
||||
if args.left:
|
||||
directions.append("left")
|
||||
if args.right:
|
||||
directions.append("right")
|
||||
if args.top:
|
||||
directions.append("up")
|
||||
if args.bottom:
|
||||
directions.append("down")
|
||||
extra_info = {
|
||||
"PIXELS": args.pixels,
|
||||
"MASK_BLUR": args.mask_blur,
|
||||
"DIRECTIONS": directions,
|
||||
"NOISE_Q": args.noise_q,
|
||||
"COLOR_VARIATION": args.color_variation,
|
||||
}
|
||||
save_output_img(generated_imgs[0], seed, extra_info)
|
||||
print(text_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,240 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from models.stable_diffusion.main import stable_diff_inf
|
||||
from models.stable_diffusion.utils import get_available_devices
|
||||
from dotenv import load_dotenv
|
||||
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
|
||||
from telegram import BotCommand
|
||||
from telegram.ext import Application, ApplicationBuilder, CallbackQueryHandler
|
||||
from telegram.ext import ContextTypes, MessageHandler, CommandHandler, filters
|
||||
from io import BytesIO
|
||||
import random
|
||||
|
||||
log = logging.getLogger("TG.Bot")
|
||||
logging.basicConfig()
|
||||
log.warning("Start")
|
||||
load_dotenv()
|
||||
os.environ["AMD_ENABLE_LLPC"] = "0"
|
||||
TG_TOKEN = os.getenv("TG_TOKEN")
|
||||
SELECTED_MODEL = "stablediffusion"
|
||||
SELECTED_SCHEDULER = "EulerAncestralDiscrete"
|
||||
STEPS = 30
|
||||
NEGATIVE_PROMPT = (
|
||||
"Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra"
|
||||
" limbs,Gross proportions,Missing arms,Mutated hands,Long"
|
||||
" neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad"
|
||||
" anatomy,Cloned face,Malformed limbs,Missing legs,Too many"
|
||||
" fingers,blurry, lowres, text, error, cropped, worst quality, low"
|
||||
" quality, jpeg artifacts, out of frame, extra fingers, mutated hands,"
|
||||
" poorly drawn hands, poorly drawn face, bad anatomy, extra limbs, cloned"
|
||||
" face, malformed limbs, missing arms, missing legs, extra arms, extra"
|
||||
" legs, fused fingers, too many fingers"
|
||||
)
|
||||
GUIDANCE_SCALE = 6
|
||||
available_devices = get_available_devices()
|
||||
models_list = [
|
||||
"stablediffusion",
|
||||
"anythingv3",
|
||||
"analogdiffusion",
|
||||
"openjourney",
|
||||
"dreamlike",
|
||||
]
|
||||
sheds_list = [
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
"LMSDiscrete",
|
||||
"DPMSolverMultistep",
|
||||
"EulerDiscrete",
|
||||
"EulerAncestralDiscrete",
|
||||
"SharkEulerDiscrete",
|
||||
]
|
||||
|
||||
|
||||
def image_to_bytes(image):
|
||||
bio = BytesIO()
|
||||
bio.name = "image.jpeg"
|
||||
image.save(bio, "JPEG")
|
||||
bio.seek(0)
|
||||
return bio
|
||||
|
||||
|
||||
def get_try_again_markup():
|
||||
keyboard = [[InlineKeyboardButton("Try again", callback_data="TRYAGAIN")]]
|
||||
reply_markup = InlineKeyboardMarkup(keyboard)
|
||||
return reply_markup
|
||||
|
||||
|
||||
def generate_image(prompt):
|
||||
seed = random.randint(1, 10000)
|
||||
log.warning(SELECTED_MODEL)
|
||||
log.warning(STEPS)
|
||||
image, text = stable_diff_inf(
|
||||
prompt=prompt,
|
||||
negative_prompt=NEGATIVE_PROMPT,
|
||||
steps=STEPS,
|
||||
guidance_scale=GUIDANCE_SCALE,
|
||||
seed=seed,
|
||||
scheduler_key=SELECTED_SCHEDULER,
|
||||
variant=SELECTED_MODEL,
|
||||
device_key=available_devices[0],
|
||||
)
|
||||
|
||||
return image, seed
|
||||
|
||||
|
||||
async def generate_and_send_photo(
|
||||
update: Update, context: ContextTypes.DEFAULT_TYPE
|
||||
) -> None:
|
||||
progress_msg = await update.message.reply_text(
|
||||
"Generating image...", reply_to_message_id=update.message.message_id
|
||||
)
|
||||
im, seed = generate_image(prompt=update.message.text)
|
||||
await context.bot.delete_message(
|
||||
chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
|
||||
)
|
||||
await context.bot.send_photo(
|
||||
update.effective_user.id,
|
||||
image_to_bytes(im),
|
||||
caption=f'"{update.message.text}" (Seed: {seed})',
|
||||
reply_markup=get_try_again_markup(),
|
||||
reply_to_message_id=update.message.message_id,
|
||||
)
|
||||
|
||||
|
||||
async def button(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
query = update.callback_query
|
||||
if query.data in models_list:
|
||||
global SELECTED_MODEL
|
||||
SELECTED_MODEL = query.data
|
||||
await query.answer()
|
||||
await query.edit_message_text(text=f"Selected model: {query.data}")
|
||||
return
|
||||
if query.data in sheds_list:
|
||||
global SELECTED_SCHEDULER
|
||||
SELECTED_SCHEDULER = query.data
|
||||
await query.answer()
|
||||
await query.edit_message_text(text=f"Selected scheduler: {query.data}")
|
||||
return
|
||||
replied_message = query.message.reply_to_message
|
||||
await query.answer()
|
||||
progress_msg = await query.message.reply_text(
|
||||
"Generating image...", reply_to_message_id=replied_message.message_id
|
||||
)
|
||||
|
||||
if query.data == "TRYAGAIN":
|
||||
prompt = replied_message.text
|
||||
im, seed = generate_image(prompt)
|
||||
|
||||
await context.bot.delete_message(
|
||||
chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
|
||||
)
|
||||
await context.bot.send_photo(
|
||||
update.effective_user.id,
|
||||
image_to_bytes(im),
|
||||
caption=f'"{prompt}" (Seed: {seed})',
|
||||
reply_markup=get_try_again_markup(),
|
||||
reply_to_message_id=replied_message.message_id,
|
||||
)
|
||||
|
||||
|
||||
async def select_model_handler(update, context):
|
||||
text = "Select model"
|
||||
keyboard = []
|
||||
for model in models_list:
|
||||
keyboard.append(
|
||||
[
|
||||
InlineKeyboardButton(text=model, callback_data=model),
|
||||
]
|
||||
)
|
||||
markup = InlineKeyboardMarkup(keyboard)
|
||||
await update.message.reply_text(text=text, reply_markup=markup)
|
||||
|
||||
|
||||
async def select_scheduler_handler(update, context):
|
||||
text = "Select schedule"
|
||||
keyboard = []
|
||||
for shed in sheds_list:
|
||||
keyboard.append(
|
||||
[
|
||||
InlineKeyboardButton(text=shed, callback_data=shed),
|
||||
]
|
||||
)
|
||||
markup = InlineKeyboardMarkup(keyboard)
|
||||
await update.message.reply_text(text=text, reply_markup=markup)
|
||||
|
||||
|
||||
async def set_steps_handler(update, context):
|
||||
input_mex = update.message.text
|
||||
log.warning(input_mex)
|
||||
try:
|
||||
input_args = input_mex.split("/set_steps ")[1]
|
||||
global STEPS
|
||||
STEPS = int(input_args)
|
||||
except Exception:
|
||||
input_args = (
|
||||
"Invalid parameter for command. Correct command looks like\n"
|
||||
" /set_steps 30"
|
||||
)
|
||||
await update.message.reply_text(input_args)
|
||||
|
||||
|
||||
async def set_negative_prompt_handler(update, context):
|
||||
input_mex = update.message.text
|
||||
log.warning(input_mex)
|
||||
try:
|
||||
input_args = input_mex.split("/set_negative_prompt ")[1]
|
||||
global NEGATIVE_PROMPT
|
||||
NEGATIVE_PROMPT = input_args
|
||||
except Exception:
|
||||
input_args = (
|
||||
"Invalid parameter for command. Correct command looks like\n"
|
||||
" /set_negative_prompt ugly, bad art, mutated"
|
||||
)
|
||||
await update.message.reply_text(input_args)
|
||||
|
||||
|
||||
async def set_guidance_scale_handler(update, context):
|
||||
input_mex = update.message.text
|
||||
log.warning(input_mex)
|
||||
try:
|
||||
input_args = input_mex.split("/set_guidance_scale ")[1]
|
||||
global GUIDANCE_SCALE
|
||||
GUIDANCE_SCALE = int(input_args)
|
||||
except Exception:
|
||||
input_args = (
|
||||
"Invalid parameter for command. Correct command looks like\n"
|
||||
" /set_guidance_scale 7"
|
||||
)
|
||||
await update.message.reply_text(input_args)
|
||||
|
||||
|
||||
async def setup_bot_commands(application: Application) -> None:
|
||||
await application.bot.set_my_commands(
|
||||
[
|
||||
BotCommand("select_model", "to select model"),
|
||||
BotCommand("select_scheduler", "to select scheduler"),
|
||||
BotCommand("set_steps", "to set steps"),
|
||||
BotCommand("set_guidance_scale", "to set guidance scale"),
|
||||
BotCommand("set_negative_prompt", "to set negative prompt"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
app = (
|
||||
ApplicationBuilder().token(TG_TOKEN).post_init(setup_bot_commands).build()
|
||||
)
|
||||
app.add_handler(CommandHandler("select_model", select_model_handler))
|
||||
app.add_handler(CommandHandler("select_scheduler", select_scheduler_handler))
|
||||
app.add_handler(CommandHandler("set_steps", set_steps_handler))
|
||||
app.add_handler(
|
||||
CommandHandler("set_guidance_scale", set_guidance_scale_handler)
|
||||
)
|
||||
app.add_handler(
|
||||
CommandHandler("set_negative_prompt", set_negative_prompt_handler)
|
||||
)
|
||||
app.add_handler(
|
||||
MessageHandler(filters.TEXT & ~filters.COMMAND, generate_and_send_photo)
|
||||
)
|
||||
app.add_handler(CallbackQueryHandler(button))
|
||||
log.warning("Start bot")
|
||||
app.run_polling()
|
||||
@@ -1,693 +0,0 @@
|
||||
# Install the required libs
|
||||
# pip install -U git+https://github.com/huggingface/diffusers.git
|
||||
# pip install accelerate transformers ftfy
|
||||
|
||||
# HuggingFace Token
|
||||
# YOUR_TOKEN = "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"
|
||||
|
||||
|
||||
# Import required libraries
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
from typing import List
|
||||
import random
|
||||
import torch_mlir
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import PIL
|
||||
import logging
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models.attention_processor import LoRAXFormersAttnProcessor
|
||||
|
||||
from torch_mlir.dynamo import make_simple_dynamo_backend
import torch._dynamo as dynamo
from torch.fx.experimental.proxy_tensor import make_fx
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
from shark.shark_inference import SharkInference

torch._dynamo.config.verbose = True

from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import (
StableDiffusionSafetyChecker,
)
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
from dataclasses import dataclass
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
clear_all,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import update_lora_weight
|
||||
|
||||
|
||||
# Setup the dataset
|
||||
class LoraDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
tokenizer,
|
||||
size=512,
|
||||
repeats=100,
|
||||
interpolation="bicubic",
|
||||
set="train",
|
||||
prompt="myloraprompt",
|
||||
center_crop=False,
|
||||
):
|
||||
self.data_root = data_root
|
||||
self.tokenizer = tokenizer
|
||||
self.size = size
|
||||
self.center_crop = center_crop
|
||||
self.prompt = prompt
|
||||
|
||||
self.image_paths = [
|
||||
os.path.join(self.data_root, file_path)
|
||||
for file_path in os.listdir(self.data_root)
|
||||
]
|
||||
|
||||
self.num_images = len(self.image_paths)
|
||||
self._length = self.num_images
|
||||
|
||||
if set == "train":
|
||||
self._length = self.num_images * repeats
|
||||
|
||||
self.interpolation = {
|
||||
"linear": PIL.Image.LINEAR,
|
||||
"bilinear": PIL.Image.BILINEAR,
|
||||
"bicubic": PIL.Image.BICUBIC,
|
||||
"lanczos": PIL.Image.LANCZOS,
|
||||
}[interpolation]
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def __getitem__(self, i):
|
||||
example = {}
|
||||
image = Image.open(self.image_paths[i % self.num_images])
|
||||
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
example["input_ids"] = self.tokenizer(
|
||||
self.prompt,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids[0]
|
||||
|
||||
# default to score-sde preprocessing
|
||||
img = np.array(image).astype(np.uint8)
|
||||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
(
|
||||
h,
|
||||
w,
|
||||
) = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
img = img[
|
||||
(h - crop) // 2 : (h + crop) // 2,
|
||||
(w - crop) // 2 : (w + crop) // 2,
|
||||
]
|
||||
|
||||
image = Image.fromarray(img)
|
||||
image = image.resize(
|
||||
(self.size, self.size), resample=self.interpolation
|
||||
)
|
||||
|
||||
image = np.array(image).astype(np.uint8)
|
||||
image = (image / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
|
||||
return example
|
||||
|
||||
|
||||
def torch_device(device):
|
||||
device_tokens = device.split("=>")
|
||||
if len(device_tokens) == 1:
|
||||
device_str = device_tokens[0].strip()
|
||||
else:
|
||||
device_str = device_tokens[1].strip()
|
||||
device_type_tokens = device_str.split("://")
|
||||
if device_type_tokens[0] == "metal":
|
||||
device_type_tokens[0] = "vulkan"
|
||||
if len(device_type_tokens) > 1:
|
||||
return device_type_tokens[0] + ":" + device_type_tokens[1]
|
||||
else:
|
||||
return device_type_tokens[0]
|
||||
|
||||
|
||||
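A couple of illustrative calls to the `torch_device` helper above; the example device strings are assumptions about how the UI labels devices, not values taken from this file:

```python
print(torch_device("vulkan"))                # "vulkan"
print(torch_device("AMD GPU => metal://1"))  # "vulkan:1" -- metal is mapped to vulkan
```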
########## Setting up the model ##########
|
||||
def lora_train(
|
||||
prompt: str,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
max_length: int,
|
||||
training_images_dir: str,
|
||||
lora_save_dir: str,
|
||||
use_lora: str,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
|
||||
print(
|
||||
"Note LoRA training is not compatible with the latest torch-mlir branch"
|
||||
)
|
||||
print(
|
||||
"To run LoRA training you'll need this to follow this guide for the torch-mlir branch: https://github.com/nod-ai/SHARK/tree/main/shark/examples/shark_training/stable_diffusion"
|
||||
)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.steps = steps
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
".safetensors",
|
||||
) # the tuple of file types
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be "
|
||||
"empty.",
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = custom_model
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
|
||||
args.training_images_dir = training_images_dir
|
||||
args.lora_save_dir = lora_save_dir
|
||||
|
||||
args.precision = precision
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
args.width = width
|
||||
args.device = torch_device(device)
|
||||
args.use_lora = use_lora
|
||||
|
||||
# Load the Stable Diffusion model
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.hf_model_id, subfolder="text_encoder"
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(args.hf_model_id, subfolder="vae")
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.hf_model_id, subfolder="unet"
|
||||
)
|
||||
|
||||
def freeze_params(params):
|
||||
for param in params:
|
||||
param.requires_grad = False
|
||||
|
||||
# Freeze everything but LoRA
|
||||
freeze_params(vae.parameters())
|
||||
freeze_params(unet.parameters())
|
||||
freeze_params(text_encoder.parameters())
|
||||
|
||||
# Move vae and unet to device
|
||||
vae.to(args.device)
|
||||
unet.to(args.device)
|
||||
text_encoder.to(args.device)
|
||||
|
||||
if use_lora != "":
|
||||
update_lora_weight(unet, args.use_lora, "unet")
|
||||
else:
|
||||
lora_attn_procs = {}
|
||||
for name in unet.attn_processors.keys():
|
||||
cross_attention_dim = (
|
||||
None
|
||||
if name.endswith("attn1.processor")
|
||||
else unet.config.cross_attention_dim
|
||||
)
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = unet.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(unet.config.block_out_channels))[
|
||||
block_id
|
||||
]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
|
||||
lora_attn_procs[name] = LoRAXFormersAttnProcessor(
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
|
||||
unet.set_attn_processor(lora_attn_procs)
|
||||
lora_layers = AttnProcsLayers(unet.attn_processors)
|
||||
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = vae
|
||||
|
||||
def forward(self, input):
|
||||
x = self.vae.encode(input, return_dict=False)[0]
|
||||
return x
|
||||
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.unet = unet
|
||||
|
||||
def forward(self, x, y, z):
|
||||
return self.unet.forward(x, y, z, return_dict=False)[0]
|
||||
|
||||
shark_vae = VaeModel()
|
||||
shark_unet = UnetModel()
|
||||
|
||||
####### Creating our training data ########
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
args.hf_model_id,
|
||||
subfolder="tokenizer",
|
||||
)
|
||||
|
||||
# Let's create the Dataset and Dataloader
|
||||
train_dataset = LoraDataset(
|
||||
data_root=args.training_images_dir,
|
||||
tokenizer=tokenizer,
|
||||
size=vae.sample_size,
|
||||
prompt=args.prompts[0],
|
||||
repeats=100,
|
||||
center_crop=False,
|
||||
set="train",
|
||||
)
|
||||
|
||||
def create_dataloader(train_batch_size=1):
|
||||
return torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=train_batch_size, shuffle=True
|
||||
)
|
||||
|
||||
# Create noise_scheduler for training
|
||||
noise_scheduler = DDPMScheduler.from_config(
|
||||
args.hf_model_id, subfolder="scheduler"
|
||||
)
|
||||
|
||||
######## Training ###########
|
||||
|
||||
# Define hyperparameters for our training. If you are not happy with your results,
|
||||
# you can tune the `learning_rate` and the `max_train_steps`
|
||||
|
||||
# Setting up all training args
|
||||
hyperparameters = {
|
||||
"learning_rate": 5e-04,
|
||||
"scale_lr": True,
|
||||
"max_train_steps": steps,
|
||||
"train_batch_size": batch_size,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"gradient_checkpointing": True,
|
||||
"mixed_precision": "fp16",
|
||||
"seed": 42,
|
||||
"output_dir": "sd-concept-output",
|
||||
}
|
||||
# creating output directory
|
||||
cwd = os.getcwd()
|
||||
out_dir = os.path.join(cwd, hyperparameters["output_dir"])
|
||||
if not os.path.exists(out_dir):
|
||||
try:
|
||||
os.mkdir(out_dir)
|
||||
except OSError as error:
|
||||
print("Output directory not created")
|
||||
|
||||
###### Torch-MLIR Compilation ######
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
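"""
Remove None entries from the output node of the FX graph.
Returns the sorted indexes that were removed so the caller can
reinsert None at those positions in the runtime results.
"""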
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
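"""
Return True if the FX graph's output node returns an empty tuple.
"""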
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
return len(node_arg) == 0
|
||||
return False
|
||||
|
||||
def transform_fx(fx_g):
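"""
Rewrite aten.empty calls so their results are zero-filled via
aten.zero_, avoiding uninitialized tensors in the compiled graph.
"""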
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
torch.ops.aten.empty,
|
||||
]:
|
||||
# aten.empty should be filled with zeros.
|
||||
if node.target in [torch.ops.aten.empty]:
|
||||
with fx_g.graph.inserting_after(node):
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten.zero_,
|
||||
args=(node,),
|
||||
)
|
||||
node.append(new_node)
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
@make_simple_dynamo_backend
|
||||
def refbackend_torchdynamo_backend(
|
||||
fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
|
||||
):
|
||||
# Handle aten.empty tensors that would otherwise be used uninitialized.
|
||||
transform_fx(fx_graph)
|
||||
fx_graph.recompile()
|
||||
if _returns_nothing(fx_graph):
|
||||
return fx_graph
|
||||
removed_none_indexes = _remove_nones(fx_graph)
|
||||
was_unwrapped = _unwrap_single_tuple_return(fx_graph)
|
||||
|
||||
mlir_module = torch_mlir.compile(
|
||||
fx_graph, example_inputs, output_type="linalg-on-tensors"
|
||||
)
|
||||
|
||||
bytecode_stream = BytesIO()
|
||||
mlir_module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=args.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
def compiled_callable(*inputs):
|
||||
inputs = [x.numpy() for x in inputs]
|
||||
result = shark_module("forward", inputs)
|
||||
if was_unwrapped:
|
||||
result = [
|
||||
result,
|
||||
]
|
||||
if not isinstance(result, list):
|
||||
result = torch.from_numpy(result)
|
||||
else:
|
||||
result = tuple(torch.from_numpy(x) for x in result)
|
||||
result = list(result)
|
||||
for removed_index in removed_none_indexes:
|
||||
result.insert(removed_index, None)
|
||||
result = tuple(result)
|
||||
return result
|
||||
|
||||
return compiled_callable
|
||||
|
||||
def predictions(torch_func, jit_func, batchA, batchB):
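"""
Run the compiled function on the two batches (converted to numpy)
and return its result; the torch_func argument is currently unused.
"""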
|
||||
res = jit_func(batchA.numpy(), batchB.numpy())
|
||||
if res is not None:
|
||||
# prediction = torch.from_numpy(res)
|
||||
prediction = res
|
||||
else:
|
||||
prediction = None
|
||||
return prediction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
train_batch_size = hyperparameters["train_batch_size"]
|
||||
gradient_accumulation_steps = hyperparameters[
|
||||
"gradient_accumulation_steps"
|
||||
]
|
||||
learning_rate = hyperparameters["learning_rate"]
|
||||
if hyperparameters["scale_lr"]:
|
||||
learning_rate = (
|
||||
learning_rate
|
||||
* gradient_accumulation_steps
|
||||
* train_batch_size
|
||||
# * accelerator.num_processes
|
||||
)
|
||||
|
||||
# Initialize the optimizer
|
||||
optimizer = torch.optim.AdamW(
|
||||
lora_layers.parameters(), # only optimize the embeddings
|
||||
lr=learning_rate,
|
||||
)
|
||||
|
||||
# Training function
|
||||
def train_func(batch_pixel_values, batch_input_ids):
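"""
Single LoRA training step: encode the images to latents with the
SHARK VAE, add scheduler noise, predict it with the SHARK UNet,
backpropagate the MSE loss and step the optimizer. Returns the loss.
"""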
|
||||
# Convert images to latent space
|
||||
latents = shark_vae(batch_pixel_values).sample().detach()
|
||||
latents = latents * 0.18215
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0,
|
||||
noise_scheduler.num_train_timesteps,
|
||||
(bsz,),
|
||||
device=latents.device,
|
||||
).long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch_input_ids)[0]
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = shark_unet(
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
encoder_hidden_states,
|
||||
)
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown prediction type {noise_scheduler.config.prediction_type}"
|
||||
)
|
||||
|
||||
loss = (
|
||||
F.mse_loss(noise_pred, target, reduction="none")
|
||||
.mean([1, 2, 3])
|
||||
.mean()
|
||||
)
|
||||
loss.backward()
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
return loss
|
||||
|
||||
def training_function():
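"""
Run the LoRA training loop for max_train_steps batches, compiling
train_func through torchdynamo with the torch-MLIR/SHARK backend.
"""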
|
||||
max_train_steps = hyperparameters["max_train_steps"]
|
||||
output_dir = hyperparameters["output_dir"]
|
||||
gradient_checkpointing = hyperparameters["gradient_checkpointing"]
|
||||
|
||||
train_dataloader = create_dataloader(train_batch_size)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(
|
||||
len(train_dataloader) / gradient_accumulation_steps
|
||||
)
|
||||
num_train_epochs = math.ceil(
|
||||
max_train_steps / num_update_steps_per_epoch
|
||||
)
|
||||
|
||||
# Train!
|
||||
total_batch_size = (
|
||||
train_batch_size
|
||||
* gradient_accumulation_steps
|
||||
# train_batch_size * accelerator.num_processes * gradient_accumulation_steps
|
||||
)
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(
|
||||
f" Instantaneous batch size per device = {train_batch_size}"
|
||||
)
|
||||
logger.info(
|
||||
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
|
||||
)
|
||||
logger.info(
|
||||
f" Gradient Accumulation steps = {gradient_accumulation_steps}"
|
||||
)
|
||||
logger.info(f" Total optimization steps = {max_train_steps}")
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(
|
||||
# range(max_train_steps), disable=not accelerator.is_local_main_process
|
||||
range(max_train_steps)
|
||||
)
|
||||
progress_bar.set_description("Steps")
|
||||
global_step = 0
|
||||
|
||||
params__ = [
|
||||
i for i in text_encoder.get_input_embeddings().parameters()
|
||||
]
|
||||
|
||||
for epoch in range(num_train_epochs):
|
||||
unet.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
dynamo_callable = dynamo.optimize(
|
||||
refbackend_torchdynamo_backend
|
||||
)(train_func)
|
||||
lam_func = lambda x, y: dynamo_callable(
|
||||
torch.from_numpy(x), torch.from_numpy(y)
|
||||
)
|
||||
loss = predictions(
|
||||
train_func,
|
||||
lam_func,
|
||||
batch["pixel_values"],
|
||||
batch["input_ids"],
|
||||
)
|
||||
|
||||
# One optimization step has completed; update the progress bar.
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
logs = {"loss": loss.detach().item()}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= max_train_steps:
|
||||
break
|
||||
|
||||
training_function()
|
||||
|
||||
# Save the lora weights
|
||||
unet.save_attn_procs(args.lora_save_dir)
|
||||
|
||||
for param in itertools.chain(unet.parameters(), text_encoder.parameters()):
|
||||
if param.grad is not None:
|
||||
del param.grad # free some memory
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
set_init_device_flags()
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
seed = args.seed
|
||||
if len(args.prompts) != 1:
|
||||
print("Need exactly one prompt for the LoRA word")
|
||||
lora_train(
|
||||
args.prompts[0],
|
||||
args.height,
|
||||
args.width,
|
||||
args.training_steps,
|
||||
args.guidance_scale,
|
||||
args.seed,
|
||||
args.batch_count,
|
||||
args.batch_size,
|
||||
args.scheduler,
|
||||
"None",
|
||||
args.hf_model_id,
|
||||
args.precision,
|
||||
args.device,
|
||||
args.max_length,
|
||||
args.training_images_dir,
|
||||
args.lora_save_dir,
|
||||
args.use_lora,
|
||||
)
|
||||
@@ -1,131 +0,0 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from shark_tuner.codegen_tuner import SharkCodegenTuner
|
||||
from shark_tuner.iree_utils import (
|
||||
dump_dispatches,
|
||||
create_context,
|
||||
export_module_to_mlir_file,
|
||||
)
|
||||
from shark_tuner.model_annotation import model_annotation
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
from apps.stable_diffusion.src.utils.utils import set_init_device_flags
|
||||
from apps.stable_diffusion.src.utils.sd_annotation import (
|
||||
get_device_args,
|
||||
load_winograd_configs,
|
||||
)
|
||||
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
|
||||
|
||||
|
||||
def load_mlir_module():
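"""
Build the SharkifyStableDiffusionModel for the configured model and
return the MLIR module and model name of the sub-model selected by
--annotation_model (unet or vae).
"""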
|
||||
if "upscaler" in args.hf_model_id:
|
||||
is_upscaler = True
|
||||
else:
|
||||
is_upscaler = False
|
||||
sd_model = SharkifyStableDiffusionModel(
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
max_len=args.max_length,
|
||||
batch_size=args.batch_size,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
is_upscaler=is_upscaler,
|
||||
use_tuned=False,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
return_mlir=True,
|
||||
)
|
||||
|
||||
if args.annotation_model == "unet":
|
||||
mlir_module = sd_model.unet()
|
||||
model_name = sd_model.model_name["unet"]
|
||||
elif args.annotation_model == "vae":
|
||||
mlir_module = sd_model.vae()
|
||||
model_name = sd_model.model_name["vae"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{args.annotation_model} is not supported for tuning."
|
||||
)
|
||||
|
||||
return mlir_module, model_name
|
||||
|
||||
|
||||
def main():
|
||||
args.use_tuned = False
|
||||
set_init_device_flags()
|
||||
mlir_module, model_name = load_mlir_module()
|
||||
|
||||
# Get device and device specific arguments
|
||||
device, device_spec_args = get_device_args()
|
||||
device_spec = ""
|
||||
vulkan_target_triple = ""
|
||||
if device_spec_args:
|
||||
device_spec = device_spec_args[-1].split("=")[-1].strip()
|
||||
if device == "vulkan":
|
||||
vulkan_target_triple = device_spec
|
||||
device_spec = device_spec.split("-")[0]
|
||||
|
||||
# Add winograd annotation for vulkan device
|
||||
use_winograd = (
|
||||
True
|
||||
if device == "vulkan" and args.annotation_model in ["unet", "vae"]
|
||||
else False
|
||||
)
|
||||
winograd_config = (
|
||||
load_winograd_configs()
|
||||
if device == "vulkan" and args.annotation_model in ["unet", "vae"]
|
||||
else ""
|
||||
)
|
||||
with create_context() as ctx:
|
||||
input_module = model_annotation(
|
||||
ctx,
|
||||
input_contents=mlir_module,
|
||||
config_path=winograd_config,
|
||||
search_op="conv",
|
||||
winograd=use_winograd,
|
||||
)
|
||||
|
||||
# Dump model dispatches
|
||||
generates_dir = Path.home() / "tmp"
|
||||
if not os.path.exists(generates_dir):
|
||||
os.makedirs(generates_dir)
|
||||
dump_mlir = generates_dir / "temp.mlir"
|
||||
dispatch_dir = generates_dir / f"{model_name}_{device_spec}_dispatches"
|
||||
export_module_to_mlir_file(input_module, dump_mlir)
|
||||
dump_dispatches(
|
||||
dump_mlir,
|
||||
device,
|
||||
dispatch_dir,
|
||||
vulkan_target_triple,
|
||||
use_winograd=use_winograd,
|
||||
)
|
||||
|
||||
# Tune each dispatch
|
||||
dtype = "f16" if args.precision == "fp16" else "f32"
|
||||
config_filename = f"{model_name}_{device_spec}_configs.json"
|
||||
|
||||
for f_path in os.listdir(dispatch_dir):
|
||||
if not f_path.endswith(".mlir"):
|
||||
continue
|
||||
|
||||
model_dir = os.path.join(dispatch_dir, f_path)
|
||||
|
||||
tuner = SharkCodegenTuner(
|
||||
model_dir,
|
||||
device,
|
||||
"random",
|
||||
args.num_iters,
|
||||
args.tuned_config_dir,
|
||||
dtype,
|
||||
args.search_op,
|
||||
batch_size=1,
|
||||
config_filename=config_filename,
|
||||
use_dispatch=True,
|
||||
vulkan_target_triple=vulkan_target_triple,
|
||||
)
|
||||
tuner.tune()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,88 +0,0 @@
|
||||
import torch
|
||||
import transformers
|
||||
import time
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Text2ImagePipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
set_init_device_flags()
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
seed = args.seed
|
||||
txt2img_obj = Text2ImagePipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
model_id=args.hf_model_id,
|
||||
ckpt_loc=args.ckpt_loc,
|
||||
precision=args.precision,
|
||||
max_length=args.max_length,
|
||||
batch_size=args.batch_size,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
use_tuned=args.use_tuned,
|
||||
custom_vae=args.custom_vae,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
use_quantize=args.use_quantize,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
|
||||
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
|
||||
for current_batch in range(args.batch_count):
|
||||
start_time = time.time()
|
||||
generated_imgs = txt2img_obj.generate_images(
|
||||
args.prompts,
|
||||
args.negative_prompts,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.steps,
|
||||
args.guidance_scale,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += (
|
||||
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
)
|
||||
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
||||
text_output += (
|
||||
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
|
||||
)
|
||||
text_output += (
|
||||
f"seed={seeds[current_batch]}, size={args.height}x{args.width}"
|
||||
)
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
# TODO: when using --batch_count=x, txt2img_obj.log re-displays accumulated info from all previous iterations on each display
|
||||
text_output += txt2img_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
save_output_img(generated_imgs[0], seed)
|
||||
print(text_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,96 +0,0 @@
|
||||
import torch
|
||||
import time
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Text2ImageSDXLPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
# TODO: prompt_embeds and text_embeds form base_model.json requires fixing
|
||||
args.precision = "fp16"
|
||||
args.height = 1024
|
||||
args.width = 1024
|
||||
args.max_length = 77
|
||||
args.scheduler = "DDIM"
|
||||
print(
|
||||
"Using default supported configuration for SDXL :-\nprecision=fp16, width*height= 1024*1024, max_length=77 and scheduler=DDIM"
|
||||
)
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
set_init_device_flags()
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
seed = args.seed
|
||||
txt2img_obj = Text2ImageSDXLPipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
model_id=args.hf_model_id,
|
||||
ckpt_loc=args.ckpt_loc,
|
||||
precision=args.precision,
|
||||
max_length=args.max_length,
|
||||
batch_size=args.batch_size,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
use_tuned=args.use_tuned,
|
||||
custom_vae=args.custom_vae,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
use_quantize=args.use_quantize,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
|
||||
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
|
||||
for current_batch in range(args.batch_count):
|
||||
start_time = time.time()
|
||||
generated_imgs = txt2img_obj.generate_images(
|
||||
args.prompts,
|
||||
args.negative_prompts,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.steps,
|
||||
args.guidance_scale,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += (
|
||||
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
)
|
||||
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
||||
text_output += (
|
||||
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
|
||||
)
|
||||
text_output += (
|
||||
f"seed={seeds[current_batch]}, size={args.height}x{args.width}"
|
||||
)
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
# TODO: when using --batch_count=x, txt2img_obj.log re-displays accumulated info from all previous iterations on each display
|
||||
text_output += txt2img_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
save_output_img(generated_imgs[0], seed)
|
||||
print(text_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,92 +0,0 @@
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
import transformers
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
UpscalerPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
if args.img_path is None:
|
||||
print("Flag --img_path is required.")
|
||||
exit()
|
||||
|
||||
# Once the models are uploaded, this should default to False.
|
||||
args.import_mlir = True
|
||||
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
set_init_device_flags()
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
image = (
|
||||
Image.open(args.img_path)
|
||||
.convert("RGB")
|
||||
.resize((args.height, args.width))
|
||||
)
|
||||
seed = utils.sanitize_seed(args.seed)
|
||||
# Adjust for height and width based on model
|
||||
|
||||
upscaler_obj = UpscalerPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
use_lora=args.use_lora,
|
||||
ddpm_scheduler=schedulers["DDPM"],
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
generated_imgs = upscaler_obj.generate_images(
|
||||
args.prompts,
|
||||
args.negative_prompts,
|
||||
image,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.steps,
|
||||
args.noise_level,
|
||||
args.guidance_scale,
|
||||
seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
||||
text_output += f"\nsteps={args.steps}, noise_level={args.noise_level}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
text_output += upscaler_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
extra_info = {"NOISE LEVEL": args.noise_level}
|
||||
save_output_img(generated_imgs[0], seed, extra_info)
|
||||
print(text_output)
|
||||
@@ -1,48 +0,0 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
from apps.stable_diffusion.shark_studio_imports import pathex, datas, hiddenimports
|
||||
|
||||
binaries = []
|
||||
|
||||
block_cipher = None
|
||||
|
||||
a = Analysis(
|
||||
['web/index.py'],
|
||||
pathex=pathex,
|
||||
binaries=binaries,
|
||||
datas=datas,
|
||||
hiddenimports=hiddenimports,
|
||||
hookspath=[],
|
||||
hooksconfig={},
|
||||
runtime_hooks=[],
|
||||
excludes=[],
|
||||
win_no_prefer_redirects=False,
|
||||
win_private_assemblies=False,
|
||||
cipher=block_cipher,
|
||||
noarchive=False,
|
||||
module_collection_mode={
|
||||
'gradio': 'py', # Collect gradio package as source .py files
|
||||
},
|
||||
)
|
||||
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||
|
||||
exe = EXE(
|
||||
pyz,
|
||||
a.scripts,
|
||||
a.binaries,
|
||||
a.zipfiles,
|
||||
a.datas,
|
||||
[],
|
||||
name='nodai_shark_studio',
|
||||
debug=False,
|
||||
bootloader_ignore_signals=False,
|
||||
strip=False,
|
||||
upx=False,
|
||||
upx_exclude=[],
|
||||
runtime_tmpdir=None,
|
||||
console=True,
|
||||
disable_windowed_traceback=False,
|
||||
argv_emulation=False,
|
||||
target_arch=None,
|
||||
codesign_identity=None,
|
||||
entitlements_file=None,
|
||||
)
|
||||
@@ -1,85 +0,0 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
from PyInstaller.utils.hooks import collect_data_files
|
||||
from PyInstaller.utils.hooks import collect_submodules
|
||||
from PyInstaller.utils.hooks import copy_metadata
|
||||
|
||||
import sys

sys.setrecursionlimit(sys.getrecursionlimit() * 5)
|
||||
|
||||
datas = []
|
||||
datas += collect_data_files('torch')
|
||||
datas += copy_metadata('torch')
|
||||
datas += copy_metadata('tqdm')
|
||||
datas += copy_metadata('regex')
|
||||
datas += copy_metadata('requests')
|
||||
datas += copy_metadata('packaging')
|
||||
datas += copy_metadata('filelock')
|
||||
datas += copy_metadata('numpy')
|
||||
datas += copy_metadata('tokenizers')
|
||||
datas += copy_metadata('importlib_metadata')
|
||||
datas += copy_metadata('torch-mlir')
|
||||
datas += copy_metadata('omegaconf')
|
||||
datas += copy_metadata('safetensors')
|
||||
datas += collect_data_files('diffusers')
|
||||
datas += collect_data_files('transformers')
|
||||
datas += collect_data_files('cv2')
|
||||
datas += collect_data_files('pytorch_lightning')
|
||||
datas += collect_data_files('skimage')
|
||||
datas += collect_data_files('gradio')
|
||||
datas += collect_data_files('gradio_client')
|
||||
datas += collect_data_files('iree')
|
||||
datas += collect_data_files('google-cloud-storage')
|
||||
datas += collect_data_files('shark')
|
||||
datas += collect_data_files('cpuinfo')
|
||||
datas += [
|
||||
( 'src/utils/resources/prompts.json', 'resources' ),
|
||||
( 'src/utils/resources/model_db.json', 'resources' ),
|
||||
( 'src/utils/resources/opt_flags.json', 'resources' ),
|
||||
( 'src/utils/resources/base_model.json', 'resources' ),
|
||||
]
|
||||
|
||||
binaries = []
|
||||
|
||||
block_cipher = None
|
||||
|
||||
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
|
||||
|
||||
a = Analysis(
|
||||
['scripts/main.py'],
|
||||
pathex=['.'],
|
||||
binaries=binaries,
|
||||
datas=datas,
|
||||
hiddenimports=hiddenimports,
|
||||
hookspath=[],
|
||||
hooksconfig={},
|
||||
runtime_hooks=[],
|
||||
excludes=[],
|
||||
win_no_prefer_redirects=False,
|
||||
win_private_assemblies=False,
|
||||
cipher=block_cipher,
|
||||
noarchive=False,
|
||||
)
|
||||
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||
|
||||
exe = EXE(
|
||||
pyz,
|
||||
a.scripts,
|
||||
a.binaries,
|
||||
a.zipfiles,
|
||||
a.datas,
|
||||
[],
|
||||
name='shark_sd_cli',
|
||||
debug=False,
|
||||
bootloader_ignore_signals=False,
|
||||
strip=False,
|
||||
upx=True,
|
||||
upx_exclude=[],
|
||||
runtime_tmpdir=None,
|
||||
console=True,
|
||||
disable_windowed_traceback=False,
|
||||
argv_emulation=False,
|
||||
target_arch=None,
|
||||
codesign_identity=None,
|
||||
entitlements_file=None,
|
||||
)
|
||||
@@ -1,90 +0,0 @@
|
||||
from PyInstaller.utils.hooks import collect_data_files
|
||||
from PyInstaller.utils.hooks import copy_metadata
|
||||
from PyInstaller.utils.hooks import collect_submodules
|
||||
|
||||
import sys
|
||||
|
||||
sys.setrecursionlimit(sys.getrecursionlimit() * 5)
|
||||
|
||||
# python path for pyinstaller
|
||||
pathex = [
|
||||
".",
|
||||
"./apps/language_models/langchain",
|
||||
"./apps/language_models/src/pipelines/minigpt4_utils",
|
||||
]
|
||||
|
||||
# datafiles for pyinstaller
|
||||
datas = []
|
||||
datas += copy_metadata("torch")
|
||||
datas += copy_metadata("tokenizers")
|
||||
datas += copy_metadata("tqdm")
|
||||
datas += copy_metadata("regex")
|
||||
datas += copy_metadata("requests")
|
||||
datas += copy_metadata("packaging")
|
||||
datas += copy_metadata("filelock")
|
||||
datas += copy_metadata("numpy")
|
||||
datas += copy_metadata("importlib_metadata")
|
||||
datas += copy_metadata("torch-mlir")
|
||||
datas += copy_metadata("omegaconf")
|
||||
datas += copy_metadata("safetensors")
|
||||
datas += copy_metadata("Pillow")
|
||||
datas += copy_metadata("sentencepiece")
|
||||
datas += copy_metadata("pyyaml")
|
||||
datas += copy_metadata("huggingface-hub")
|
||||
datas += copy_metadata("gradio")
|
||||
datas += collect_data_files("torch")
|
||||
datas += collect_data_files("tokenizers")
|
||||
datas += collect_data_files("tiktoken")
|
||||
datas += collect_data_files("accelerate")
|
||||
datas += collect_data_files("diffusers")
|
||||
datas += collect_data_files("transformers")
|
||||
datas += collect_data_files("pytorch_lightning")
|
||||
datas += collect_data_files("skimage")
|
||||
datas += collect_data_files("gradio")
|
||||
datas += collect_data_files("gradio_client")
|
||||
datas += collect_data_files("iree")
|
||||
datas += collect_data_files("shark", include_py_files=True)
|
||||
datas += collect_data_files("timm", include_py_files=True)
|
||||
datas += collect_data_files("tqdm")
|
||||
datas += collect_data_files("tkinter")
|
||||
datas += collect_data_files("webview")
|
||||
datas += collect_data_files("sentencepiece")
|
||||
datas += collect_data_files("jsonschema")
|
||||
datas += collect_data_files("jsonschema_specifications")
|
||||
datas += collect_data_files("cpuinfo")
|
||||
datas += collect_data_files("langchain")
|
||||
datas += collect_data_files("cv2")
|
||||
datas += collect_data_files("einops")
|
||||
datas += [
|
||||
("src/utils/resources/prompts.json", "resources"),
|
||||
("src/utils/resources/model_db.json", "resources"),
|
||||
("src/utils/resources/opt_flags.json", "resources"),
|
||||
("src/utils/resources/base_model.json", "resources"),
|
||||
("web/ui/css/*", "ui/css"),
|
||||
("web/ui/logos/*", "logos"),
|
||||
(
|
||||
"../language_models/src/pipelines/minigpt4_utils/configs/*",
|
||||
"minigpt4_utils/configs",
|
||||
),
|
||||
(
|
||||
"../language_models/src/pipelines/minigpt4_utils/prompts/*",
|
||||
"minigpt4_utils/prompts",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# hidden imports for pyinstaller
|
||||
hiddenimports = ["shark", "shark.shark_inference", "apps"]
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x]
|
||||
hiddenimports += [
|
||||
x for x in collect_submodules("diffusers") if "tests" not in x
|
||||
]
|
||||
blacklist = ["tests", "convert"]
|
||||
hiddenimports += [
|
||||
x
|
||||
for x in collect_submodules("transformers")
|
||||
if not any(kw in x for kw in blacklist)
|
||||
]
|
||||
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
|
||||
hiddenimports += ["iree._runtime"]
|
||||
@@ -1,19 +0,0 @@
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
args,
|
||||
set_init_device_flags,
|
||||
prompt_examples,
|
||||
get_available_devices,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
resize_stencil,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines import (
|
||||
Text2ImagePipeline,
|
||||
Text2ImageSDXLPipeline,
|
||||
Image2ImagePipeline,
|
||||
InpaintPipeline,
|
||||
OutpaintPipeline,
|
||||
StencilPipeline,
|
||||
UpscalerPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import get_schedulers
|
||||
@@ -1,12 +0,0 @@
|
||||
from apps.stable_diffusion.src.models.model_wrappers import (
|
||||
SharkifyStableDiffusionModel,
|
||||
)
|
||||
from apps.stable_diffusion.src.models.opt_params import (
|
||||
get_vae_encode,
|
||||
get_vae,
|
||||
get_unet,
|
||||
get_clip,
|
||||
get_tokenizer,
|
||||
get_params,
|
||||
get_variant_version,
|
||||
)
|
||||
File diff suppressed because it is too large
@@ -1,133 +0,0 @@
|
||||
import sys
|
||||
from transformers import CLIPTokenizer
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
models_db,
|
||||
args,
|
||||
get_shark_model,
|
||||
get_opt_flags,
|
||||
)
|
||||
|
||||
|
||||
hf_model_variant_map = {
|
||||
"Linaqruf/anything-v3.0": ["anythingv3", "v1_4"],
|
||||
"dreamlike-art/dreamlike-diffusion-1.0": ["dreamlike", "v1_4"],
|
||||
"prompthero/openjourney": ["openjourney", "v1_4"],
|
||||
"wavymulder/Analog-Diffusion": ["analogdiffusion", "v1_4"],
|
||||
"stabilityai/stable-diffusion-2-1": ["stablediffusion", "v2_1base"],
|
||||
"stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
|
||||
"CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
|
||||
"runwayml/stable-diffusion-inpainting": ["stablediffusion", "inpaint_v1"],
|
||||
"stabilityai/stable-diffusion-2-inpainting": [
|
||||
"stablediffusion",
|
||||
"inpaint_v2",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# TODO: Add the quantized model as part of model_db.json.
|
||||
# This is currently in the experimental phase.
|
||||
def get_quantize_model():
|
||||
bucket_key = "gs://shark_tank/prashant_nod"
|
||||
model_key = "unet_int8"
|
||||
iree_flags = get_opt_flags("unet", precision="fp16")
|
||||
if args.height != 512 or args.width != 512 or args.max_length != 77:
|
||||
sys.exit(
|
||||
"The int8 quantized model currently requires the height and width to be 512, and max_length to be 77"
|
||||
)
|
||||
return bucket_key, model_key, iree_flags
|
||||
|
||||
|
||||
def get_variant_version(hf_model_id):
|
||||
return hf_model_variant_map[hf_model_id]
|
||||
|
||||
|
||||
def get_params(bucket_key, model_key, model, is_tuned, precision):
|
||||
try:
|
||||
bucket = models_db[0][bucket_key]
|
||||
model_name = models_db[1][model_key]
|
||||
except KeyError:
|
||||
raise Exception(
|
||||
f"{bucket_key}/{model_key} is not present in the models database"
|
||||
)
|
||||
iree_flags = get_opt_flags(model, precision="fp16")
|
||||
return bucket, model_name, iree_flags
|
||||
|
||||
|
||||
def get_unet():
|
||||
variant, version = get_variant_version(args.hf_model_id)
|
||||
# Tuned model is present only for `fp16` precision.
|
||||
is_tuned = "tuned" if args.use_tuned else "untuned"
|
||||
|
||||
# TODO: Get the quantize model from model_db.json
|
||||
if args.use_quantize == "int8":
|
||||
bk, mk, flags = get_quantize_model()
|
||||
return get_shark_model(bk, mk, flags)
|
||||
|
||||
if "vulkan" not in args.device and args.use_tuned:
|
||||
bucket_key = f"{variant}/{is_tuned}/{args.device}"
|
||||
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"
|
||||
else:
|
||||
bucket_key = f"{variant}/{is_tuned}"
|
||||
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}"
|
||||
|
||||
bucket, model_name, iree_flags = get_params(
|
||||
bucket_key, model_key, "unet", is_tuned, args.precision
|
||||
)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
def get_vae_encode():
|
||||
variant, version = get_variant_version(args.hf_model_id)
|
||||
# Tuned model is present only for `fp16` precision.
|
||||
is_tuned = "tuned" if args.use_tuned else "untuned"
|
||||
if "vulkan" not in args.device and args.use_tuned:
|
||||
bucket_key = f"{variant}/{is_tuned}/{args.device}"
|
||||
model_key = f"{variant}/{version}/vae_encode/{args.precision}/length_77/{is_tuned}/{args.device}"
|
||||
else:
|
||||
bucket_key = f"{variant}/{is_tuned}"
|
||||
model_key = f"{variant}/{version}/vae_encode/{args.precision}/length_77/{is_tuned}"
|
||||
|
||||
bucket, model_name, iree_flags = get_params(
|
||||
bucket_key, model_key, "vae", is_tuned, args.precision
|
||||
)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
def get_vae():
|
||||
variant, version = get_variant_version(args.hf_model_id)
|
||||
# Tuned model is present only for `fp16` precision.
|
||||
is_tuned = "tuned" if args.use_tuned else "untuned"
|
||||
is_base = "/base" if args.use_base_vae else ""
|
||||
if "vulkan" not in args.device and args.use_tuned:
|
||||
bucket_key = f"{variant}/{is_tuned}/{args.device}"
|
||||
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}/{args.device}"
|
||||
else:
|
||||
bucket_key = f"{variant}/{is_tuned}"
|
||||
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}"
|
||||
|
||||
bucket, model_name, iree_flags = get_params(
|
||||
bucket_key, model_key, "vae", is_tuned, args.precision
|
||||
)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
def get_clip():
|
||||
variant, version = get_variant_version(args.hf_model_id)
|
||||
bucket_key = f"{variant}/untuned"
|
||||
model_key = (
|
||||
f"{variant}/{version}/clip/fp32/length_{args.max_length}/untuned"
|
||||
)
|
||||
bucket, model_name, iree_flags = get_params(
|
||||
bucket_key, model_key, "clip", "untuned", "fp32"
|
||||
)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
def get_tokenizer(subfolder="tokenizer", hf_model_id=None):
|
||||
if hf_model_id is not None:
|
||||
args.hf_model_id = hf_model_id
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
args.hf_model_id, subfolder=subfolder
|
||||
)
|
||||
return tokenizer
|
||||
@@ -1,21 +0,0 @@
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
|
||||
Text2ImagePipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img_sdxl import (
|
||||
Text2ImageSDXLPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
|
||||
Image2ImagePipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_inpaint import (
|
||||
InpaintPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_outpaint import (
|
||||
OutpaintPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_stencil import (
|
||||
StencilPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_upscaler import (
|
||||
UpscalerPipeline,
|
||||
)
|
||||
@@ -1,231 +0,0 @@
|
||||
import torch
|
||||
import time
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
from random import randint
|
||||
from PIL import Image
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Union
|
||||
from shark.shark_inference import SharkInference
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
DDPMScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.models import (
|
||||
SharkifyStableDiffusionModel,
|
||||
get_vae_encode,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
resamplers,
|
||||
resampler_list,
|
||||
)
|
||||
|
||||
|
||||
class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
DDPMScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
],
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
ondemand: bool,
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
self.vae_encode = None
|
||||
|
||||
def load_vae_encode(self):
|
||||
if self.vae_encode is not None:
|
||||
return
|
||||
|
||||
if self.import_mlir or self.use_lora:
|
||||
self.vae_encode = self.sd_model.vae_encode()
|
||||
else:
|
||||
try:
|
||||
self.vae_encode = get_vae_encode()
|
||||
except Exception:
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
self.vae_encode = self.sd_model.vae_encode()
|
||||
|
||||
def unload_vae_encode(self):
|
||||
del self.vae_encode
|
||||
self.vae_encode = None
|
||||
|
||||
def prepare_image_latents(
|
||||
self,
|
||||
image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
generator,
|
||||
num_inference_steps,
|
||||
strength,
|
||||
dtype,
|
||||
resample_type,
|
||||
):
|
||||
# Pre-process image -> encode image -> process latents
|
||||
|
||||
# TODO: process with variable HxW combos
|
||||
|
||||
# Pre-process image
|
||||
resample_type = (
|
||||
resamplers[resample_type]
|
||||
if resample_type in resampler_list
|
||||
# Fallback to Lanczos
|
||||
else Image.Resampling.LANCZOS
|
||||
)
|
||||
|
||||
image = image.resize((width, height), resample=resample_type)
|
||||
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
|
||||
image_arr = image_arr / 255.0
|
||||
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
|
||||
image_arr = 2 * (image_arr - 0.5)
|
||||
|
||||
# set scheduler steps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
init_timestep = min(
|
||||
int(num_inference_steps * strength), num_inference_steps
|
||||
)
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
# timesteps reduced as per strength
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
# new number of steps to be used as per strength will be
|
||||
# num_inference_steps = num_inference_steps - t_start
|
||||
|
||||
# image encode
|
||||
latents = self.encode_image((image_arr,))
|
||||
latents = torch.from_numpy(latents).to(dtype)
|
||||
# add noise to data
|
||||
noise = torch.randn(latents.shape, generator=generator, dtype=dtype)
|
||||
latents = self.scheduler.add_noise(
|
||||
latents, noise, timesteps[0].repeat(1)
|
||||
)
|
||||
|
||||
return latents, timesteps
|
||||
|
||||
def encode_image(self, input_image):
|
||||
self.load_vae_encode()
|
||||
vae_encode_start = time.time()
|
||||
latents = self.vae_encode("forward", input_image)
|
||||
vae_inf_time = (time.time() - vae_encode_start) * 1000
|
||||
if self.ondemand:
|
||||
self.unload_vae_encode()
|
||||
self.log += f"\nVAE Encode Inference time (ms): {vae_inf_time:.3f}"
|
||||
|
||||
return latents
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
prompts,
|
||||
neg_prompts,
|
||||
image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
num_inference_steps,
|
||||
strength,
|
||||
guidance_scale,
|
||||
seed,
|
||||
max_length,
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
stencils,
|
||||
images,
|
||||
resample_type,
|
||||
control_mode,
|
||||
preprocessed_hints=[],
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if isinstance(neg_prompts, str):
|
||||
neg_prompts = [neg_prompts]
|
||||
|
||||
prompts = prompts * batch_size
|
||||
neg_prompts = neg_prompts * batch_size
|
||||
|
||||
# Seed the generator used to create the initial latent noise; also handle out-of-range seeds.
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Get text embeddings with weight emphasis from prompts
|
||||
text_embeddings = self.encode_prompts_weight(
|
||||
prompts,
|
||||
neg_prompts,
|
||||
max_length,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
|
||||
|
||||
# Prepare input image latent
|
||||
image_latents, final_timesteps = self.prepare_image_latents(
|
||||
image=image,
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
strength=strength,
|
||||
dtype=dtype,
|
||||
resample_type=resample_type,
|
||||
)
|
||||
|
||||
# Get Image latents
|
||||
latents = self.produce_img_latents(
|
||||
latents=image_latents,
|
||||
text_embeddings=text_embeddings,
|
||||
guidance_scale=guidance_scale,
|
||||
total_timesteps=final_timesteps,
|
||||
dtype=dtype,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images
|
||||
all_imgs = []
|
||||
self.load_vae()
|
||||
for i in tqdm(range(0, latents.shape[0], batch_size)):
|
||||
imgs = self.decode_latents(
|
||||
latents=latents[i : i + batch_size],
|
||||
use_base_vae=use_base_vae,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
)
|
||||
all_imgs.extend(imgs)
|
||||
if self.ondemand:
|
||||
self.unload_vae()
|
||||
|
||||
return all_imgs
|
||||
@@ -1,487 +0,0 @@
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from PIL import Image, ImageOps
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Union
|
||||
from shark.shark_inference import SharkInference
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
DDPMScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.models import (
|
||||
SharkifyStableDiffusionModel,
|
||||
get_vae_encode,
|
||||
)
|
||||
|
||||
|
||||
class InpaintPipeline(StableDiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
DDPMScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
],
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
ondemand: bool,
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
self.vae_encode = None
|
||||
|
||||
def load_vae_encode(self):
|
||||
if self.vae_encode is not None:
|
||||
return
|
||||
|
||||
if self.import_mlir or self.use_lora:
|
||||
self.vae_encode = self.sd_model.vae_encode()
|
||||
else:
|
||||
try:
|
||||
self.vae_encode = get_vae_encode()
|
||||
except Exception:
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
self.vae_encode = self.sd_model.vae_encode()
|
||||
|
||||
def unload_vae_encode(self):
|
||||
del self.vae_encode
|
||||
self.vae_encode = None
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
generator,
|
||||
num_inference_steps,
|
||||
dtype,
|
||||
):
|
||||
latents = torch.randn(
|
||||
(
|
||||
batch_size,
|
||||
4,
|
||||
height // 8,
|
||||
width // 8,
|
||||
),
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
).to(dtype)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def get_crop_region(self, mask, pad=0):
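"""
Return the (left, top, right, bottom) bounding box of the non-zero
region of the mask, expanded by `pad` pixels and clamped to the mask.
"""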
|
||||
h, w = mask.shape
|
||||
|
||||
crop_left = 0
|
||||
for i in range(w):
|
||||
if not (mask[:, i] == 0).all():
|
||||
break
|
||||
crop_left += 1
|
||||
|
||||
crop_right = 0
|
||||
for i in reversed(range(w)):
|
||||
if not (mask[:, i] == 0).all():
|
||||
break
|
||||
crop_right += 1
|
||||
|
||||
crop_top = 0
|
||||
for i in range(h):
|
||||
if not (mask[i] == 0).all():
|
||||
break
|
||||
crop_top += 1
|
||||
|
||||
crop_bottom = 0
|
||||
for i in reversed(range(h)):
|
||||
if not (mask[i] == 0).all():
|
||||
break
|
||||
crop_bottom += 1
|
||||
|
||||
return (
|
||||
int(max(crop_left - pad, 0)),
|
||||
int(max(crop_top - pad, 0)),
|
||||
int(min(w - crop_right + pad, w)),
|
||||
int(min(h - crop_bottom + pad, h)),
|
||||
)
|
||||
|
||||
def expand_crop_region(
|
||||
self,
|
||||
crop_region,
|
||||
processing_width,
|
||||
processing_height,
|
||||
image_width,
|
||||
image_height,
|
||||
):
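"""
Expand the crop region so its aspect ratio matches the processing
resolution, shifting and clamping the box to stay inside the image.
"""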
|
||||
x1, y1, x2, y2 = crop_region
|
||||
|
||||
ratio_crop_region = (x2 - x1) / (y2 - y1)
|
||||
ratio_processing = processing_width / processing_height
|
||||
|
||||
if ratio_crop_region > ratio_processing:
|
||||
desired_height = (x2 - x1) / ratio_processing
|
||||
desired_height_diff = int(desired_height - (y2 - y1))
|
||||
y1 -= desired_height_diff // 2
|
||||
y2 += desired_height_diff - desired_height_diff // 2
|
||||
if y2 >= image_height:
|
||||
diff = y2 - image_height
|
||||
y2 -= diff
|
||||
y1 -= diff
|
||||
if y1 < 0:
|
||||
y2 -= y1
|
||||
y1 -= y1
|
||||
if y2 >= image_height:
|
||||
y2 = image_height
|
||||
else:
|
||||
desired_width = (y2 - y1) * ratio_processing
|
||||
desired_width_diff = int(desired_width - (x2 - x1))
|
||||
x1 -= desired_width_diff // 2
|
||||
x2 += desired_width_diff - desired_width_diff // 2
|
||||
if x2 >= image_width:
|
||||
diff = x2 - image_width
|
||||
x2 -= diff
|
||||
x1 -= diff
|
||||
if x1 < 0:
|
||||
x2 -= x1
|
||||
x1 -= x1
|
||||
if x2 >= image_width:
|
||||
x2 = image_width
|
||||
|
||||
return x1, y1, x2, y2
|
||||
|
||||
def resize_image(self, resize_mode, im, width, height):
|
||||
"""
|
||||
resize_mode:
|
||||
0: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
|
||||
1: Resize the image to fit within the specified width and height, maintaining the aspect ratio, then center it within the dimensions, filling the empty space with data stretched from the image edges.
|
||||
"""
|
||||
|
||||
if resize_mode == 0:
|
||||
ratio = width / height
|
||||
src_ratio = im.width / im.height
|
||||
|
||||
src_w = (
|
||||
width if ratio > src_ratio else im.width * height // im.height
|
||||
)
|
||||
src_h = (
|
||||
height if ratio <= src_ratio else im.height * width // im.width
|
||||
)
|
||||
|
||||
resized = im.resize((src_w, src_h), resample=Image.LANCZOS)
|
||||
res = Image.new("RGB", (width, height))
|
||||
res.paste(
|
||||
resized,
|
||||
box=(width // 2 - src_w // 2, height // 2 - src_h // 2),
|
||||
)
|
||||
|
||||
else:
|
||||
ratio = width / height
|
||||
src_ratio = im.width / im.height
|
||||
|
||||
src_w = (
|
||||
width if ratio < src_ratio else im.width * height // im.height
|
||||
)
|
||||
src_h = (
|
||||
height if ratio >= src_ratio else im.height * width // im.width
|
||||
)
|
||||
|
||||
resized = im.resize((src_w, src_h), resample=Image.LANCZOS)
|
||||
res = Image.new("RGB", (width, height))
|
||||
res.paste(
|
||||
resized,
|
||||
box=(width // 2 - src_w // 2, height // 2 - src_h // 2),
|
||||
)
|
||||
|
||||
if ratio < src_ratio:
|
||||
fill_height = height // 2 - src_h // 2
|
||||
res.paste(
|
||||
resized.resize((width, fill_height), box=(0, 0, width, 0)),
|
||||
box=(0, 0),
|
||||
)
|
||||
res.paste(
|
||||
resized.resize(
|
||||
(width, fill_height),
|
||||
box=(0, resized.height, width, resized.height),
|
||||
),
|
||||
box=(0, fill_height + src_h),
|
||||
)
|
||||
elif ratio > src_ratio:
|
||||
fill_width = width // 2 - src_w // 2
|
||||
res.paste(
|
||||
resized.resize(
|
||||
(fill_width, height), box=(0, 0, 0, height)
|
||||
),
|
||||
box=(0, 0),
|
||||
)
|
||||
res.paste(
|
||||
resized.resize(
|
||||
(fill_width, height),
|
||||
box=(resized.width, 0, resized.width, height),
|
||||
),
|
||||
box=(fill_width + src_w, 0),
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
def prepare_mask_and_masked_image(
|
||||
self,
|
||||
image,
|
||||
mask,
|
||||
height,
|
||||
width,
|
||||
inpaint_full_res,
|
||||
inpaint_full_res_padding,
|
||||
):
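"""
Resize the image and mask, optionally crop to the masked region for
full-resolution inpainting, and convert both to normalized tensors.
Returns the mask, the masked image, the paste location (empty unless
inpaint_full_res) and the overlay image (None unless inpaint_full_res).
"""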
|
||||
# preprocess image
|
||||
image = image.resize((width, height))
|
||||
mask = mask.resize((width, height))
|
||||
|
||||
paste_to = ()
|
||||
overlay_image = None
|
||||
if inpaint_full_res:
|
||||
# prepare overlay image
|
||||
overlay_image = Image.new("RGB", (image.width, image.height))
|
||||
overlay_image.paste(
|
||||
image.convert("RGB"),
|
||||
mask=ImageOps.invert(mask.convert("L")),
|
||||
)
|
||||
|
||||
# prepare mask
|
||||
mask = mask.convert("L")
|
||||
crop_region = self.get_crop_region(
|
||||
np.array(mask), inpaint_full_res_padding
|
||||
)
|
||||
crop_region = self.expand_crop_region(
|
||||
crop_region, width, height, mask.width, mask.height
|
||||
)
|
||||
x1, y1, x2, y2 = crop_region
|
||||
mask = mask.crop(crop_region)
|
||||
mask = self.resize_image(1, mask, width, height)
|
||||
paste_to = (x1, y1, x2 - x1, y2 - y1)
|
||||
|
||||
# prepare image
|
||||
image = image.crop(crop_region)
|
||||
image = self.resize_image(1, image, width, height)
|
||||
|
||||
if isinstance(image, (Image.Image, np.ndarray)):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image, list) and isinstance(image[0], Image.Image):
|
||||
image = [np.array(i.convert("RGB"))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
|
||||
image = np.concatenate([i[None, :] for i in image], axis=0)
|
||||
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
|
||||
# preprocess mask
|
||||
if isinstance(mask, (Image.Image, np.ndarray)):
|
||||
mask = [mask]
|
||||
|
||||
if isinstance(mask, list) and isinstance(mask[0], Image.Image):
|
||||
mask = np.concatenate(
|
||||
[np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
|
||||
)
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
|
||||
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
|
||||
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
|
||||
masked_image = image * (mask < 0.5)
|
||||
|
||||
return mask, masked_image, paste_to, overlay_image
|
||||
|
||||
def prepare_mask_latents(
|
||||
self,
|
||||
mask,
|
||||
masked_image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
):
|
||||
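# Downsample the mask to latent resolution (height // 8, width // 8) so it lines up
# with the UNet latents it is concatenated with.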
mask = torch.nn.functional.interpolate(
|
||||
mask, size=(height // 8, width // 8)
|
||||
)
|
||||
mask = mask.to(dtype)
|
||||
|
||||
self.load_vae_encode()
|
||||
masked_image = masked_image.to(dtype)
|
||||
masked_image_latents = self.vae_encode("forward", (masked_image,))
|
||||
masked_image_latents = torch.from_numpy(masked_image_latents)
|
||||
if self.ondemand:
|
||||
self.unload_vae_encode()
|
||||
|
||||
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
||||
if mask.shape[0] < batch_size:
|
||||
if batch_size % mask.shape[0] != 0:
|
||||
raise ValueError(
|
||||
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
|
||||
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
|
||||
" of masks that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
|
||||
if masked_image_latents.shape[0] < batch_size:
|
||||
if batch_size % masked_image_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
|
||||
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
|
||||
" Make sure the number of images that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
masked_image_latents = masked_image_latents.repeat(
|
||||
batch_size // masked_image_latents.shape[0], 1, 1, 1
|
||||
)
|
||||
return mask, masked_image_latents
|
||||
|
||||
def apply_overlay(self, image, paste_loc, overlay):
|
||||
x, y, w, h = paste_loc
|
||||
image = self.resize_image(0, image, w, h)
|
||||
overlay.paste(image, (x, y))
|
||||
|
||||
return overlay
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
prompts,
|
||||
neg_prompts,
|
||||
image,
|
||||
mask_image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
inpaint_full_res,
|
||||
inpaint_full_res_padding,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
max_length,
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
):
|
||||
# prompts and negative prompts must be lists.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if isinstance(neg_prompts, str):
|
||||
neg_prompts = [neg_prompts]
|
||||
|
||||
prompts = prompts * batch_size
|
||||
neg_prompts = neg_prompts * batch_size
|
||||
|
||||
# seed generator to create the initial latent noise. Also handle out-of-range seeds.
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Get initial latents
|
||||
init_latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get text embeddings with weight emphasis from prompts
|
||||
text_embeddings = self.encode_prompts_weight(
|
||||
prompts,
|
||||
neg_prompts,
|
||||
max_length,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
|
||||
|
||||
# Preprocess mask and image
|
||||
(
|
||||
mask,
|
||||
masked_image,
|
||||
paste_to,
|
||||
overlay_image,
|
||||
) = self.prepare_mask_and_masked_image(
|
||||
image,
|
||||
mask_image,
|
||||
height,
|
||||
width,
|
||||
inpaint_full_res,
|
||||
inpaint_full_res_padding,
|
||||
)
|
||||
|
||||
# Prepare mask latent variables
|
||||
mask, masked_image_latents = self.prepare_mask_latents(
|
||||
mask=mask,
|
||||
masked_image=masked_image,
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get Image latents
|
||||
latents = self.produce_img_latents(
|
||||
latents=init_latents,
|
||||
text_embeddings=text_embeddings,
|
||||
guidance_scale=guidance_scale,
|
||||
total_timesteps=self.scheduler.timesteps,
|
||||
dtype=dtype,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
mask=mask,
|
||||
masked_image_latents=masked_image_latents,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images
|
||||
all_imgs = []
|
||||
self.load_vae()
|
||||
for i in tqdm(range(0, latents.shape[0], batch_size)):
|
||||
imgs = self.decode_latents(
|
||||
latents=latents[i : i + batch_size],
|
||||
use_base_vae=use_base_vae,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
)
|
||||
all_imgs.extend(imgs)
|
||||
if self.ondemand:
|
||||
self.unload_vae()
|
||||
|
||||
if inpaint_full_res:
|
||||
output_image = self.apply_overlay(
|
||||
all_imgs[0], paste_to, overlay_image
|
||||
)
|
||||
return [output_image]
|
||||
|
||||
return all_imgs
|
||||
@@ -1,581 +0,0 @@
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from PIL import Image, ImageDraw, ImageFilter
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Union
|
||||
from shark.shark_inference import SharkInference
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
DDPMScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
import math
|
||||
from apps.stable_diffusion.src.models import (
|
||||
SharkifyStableDiffusionModel,
|
||||
get_vae_encode,
|
||||
)
|
||||
|
||||
|
||||
class OutpaintPipeline(StableDiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
DDPMScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
],
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
ondemand: bool,
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
self.vae_encode = None
|
||||
|
||||
def load_vae_encode(self):
|
||||
if self.vae_encode is not None:
|
||||
return
|
||||
|
||||
if self.import_mlir or self.use_lora:
|
||||
self.vae_encode = self.sd_model.vae_encode()
|
||||
else:
|
||||
try:
|
||||
self.vae_encode = get_vae_encode()
|
||||
except Exception:
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
self.vae_encode = self.sd_model.vae_encode()
|
||||
|
||||
def unload_vae_encode(self):
|
||||
del self.vae_encode
|
||||
self.vae_encode = None
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
generator,
|
||||
num_inference_steps,
|
||||
dtype,
|
||||
):
|
||||
latents = torch.randn(
|
||||
(
|
||||
batch_size,
|
||||
4,
|
||||
height // 8,
|
||||
width // 8,
|
||||
),
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
).to(dtype)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def prepare_mask_and_masked_image(
|
||||
self, image, mask, mask_blur, width, height
|
||||
):
|
||||
if mask_blur > 0:
|
||||
mask = mask.filter(ImageFilter.GaussianBlur(mask_blur))
|
||||
image = image.resize((width, height))
|
||||
mask = mask.resize((width, height))
|
||||
|
||||
# preprocess image
|
||||
if isinstance(image, (Image.Image, np.ndarray)):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image, list) and isinstance(image[0], Image.Image):
|
||||
image = [np.array(i.convert("RGB"))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
|
||||
image = np.concatenate([i[None, :] for i in image], axis=0)
|
||||
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
|
||||
# preprocess mask
|
||||
if isinstance(mask, (Image.Image, np.ndarray)):
|
||||
mask = [mask]
|
||||
|
||||
if isinstance(mask, list) and isinstance(mask[0], Image.Image):
|
||||
mask = np.concatenate(
|
||||
[np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
|
||||
)
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
|
||||
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
|
||||
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
|
||||
masked_image = image * (mask < 0.5)
|
||||
|
||||
return mask, masked_image
|
||||
|
||||
def prepare_mask_latents(
|
||||
self,
|
||||
mask,
|
||||
masked_image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
):
|
||||
mask = torch.nn.functional.interpolate(
|
||||
mask, size=(height // 8, width // 8)
|
||||
)
|
||||
mask = mask.to(dtype)
|
||||
|
||||
self.load_vae_encode()
|
||||
masked_image = masked_image.to(dtype)
|
||||
masked_image_latents = self.vae_encode("forward", (masked_image,))
|
||||
masked_image_latents = torch.from_numpy(masked_image_latents)
|
||||
if self.ondemand:
|
||||
self.unload_vae_encode()
|
||||
|
||||
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
||||
if mask.shape[0] < batch_size:
|
||||
if batch_size % mask.shape[0] != 0:
|
||||
raise ValueError(
|
||||
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
|
||||
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
|
||||
" of masks that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
|
||||
if masked_image_latents.shape[0] < batch_size:
|
||||
if batch_size % masked_image_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
|
||||
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
|
||||
" Make sure the number of images that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
masked_image_latents = masked_image_latents.repeat(
|
||||
batch_size // masked_image_latents.shape[0], 1, 1, 1
|
||||
)
|
||||
return mask, masked_image_latents
|
||||
|
||||
def get_matched_noise(
|
||||
self, _np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05
|
||||
):
|
||||
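# Outpainting seed noise: shape random noise with the FFT magnitude/phase statistics of the
# unmasked source, then histogram-match it to the source colors so the filled region blends in.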
# helper fft routines that keep ortho normalization and auto-shift before and after fft
|
||||
def _fft2(data):
|
||||
if data.ndim > 2: # has channels
|
||||
out_fft = np.zeros(
|
||||
(data.shape[0], data.shape[1], data.shape[2]),
|
||||
dtype=np.complex128,
|
||||
)
|
||||
for c in range(data.shape[2]):
|
||||
c_data = data[:, :, c]
|
||||
out_fft[:, :, c] = np.fft.fft2(
|
||||
np.fft.fftshift(c_data), norm="ortho"
|
||||
)
|
||||
out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
|
||||
else: # one channel
|
||||
out_fft = np.zeros(
|
||||
(data.shape[0], data.shape[1]), dtype=np.complex128
|
||||
)
|
||||
out_fft[:, :] = np.fft.fft2(
|
||||
np.fft.fftshift(data), norm="ortho"
|
||||
)
|
||||
out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])
|
||||
|
||||
return out_fft
|
||||
|
||||
def _ifft2(data):
|
||||
if data.ndim > 2: # has channels
|
||||
out_ifft = np.zeros(
|
||||
(data.shape[0], data.shape[1], data.shape[2]),
|
||||
dtype=np.complex128,
|
||||
)
|
||||
for c in range(data.shape[2]):
|
||||
c_data = data[:, :, c]
|
||||
out_ifft[:, :, c] = np.fft.ifft2(
|
||||
np.fft.fftshift(c_data), norm="ortho"
|
||||
)
|
||||
out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
|
||||
else: # one channel
|
||||
out_ifft = np.zeros(
|
||||
(data.shape[0], data.shape[1]), dtype=np.complex128
|
||||
)
|
||||
out_ifft[:, :] = np.fft.ifft2(
|
||||
np.fft.fftshift(data), norm="ortho"
|
||||
)
|
||||
out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])
|
||||
|
||||
return out_ifft
|
||||
|
||||
def _get_gaussian_window(width, height, std=3.14, mode=0):
|
||||
window_scale_x = float(width / min(width, height))
|
||||
window_scale_y = float(height / min(width, height))
|
||||
|
||||
window = np.zeros((width, height))
|
||||
x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
|
||||
for y in range(height):
|
||||
fy = (y / height * 2.0 - 1.0) * window_scale_y
|
||||
if mode == 0:
|
||||
window[:, y] = np.exp(-(x**2 + fy**2) * std)
|
||||
else:
|
||||
window[:, y] = (
|
||||
1 / ((x**2 + 1.0) * (fy**2 + 1.0))
|
||||
) ** (std / 3.14)
|
||||
|
||||
return window
|
||||
|
||||
def _get_masked_window_rgb(np_mask_grey, hardness=1.0):
|
||||
np_mask_rgb = np.zeros(
|
||||
(np_mask_grey.shape[0], np_mask_grey.shape[1], 3)
|
||||
)
|
||||
if hardness != 1.0:
|
||||
hardened = np_mask_grey[:] ** hardness
|
||||
else:
|
||||
hardened = np_mask_grey[:]
|
||||
for c in range(3):
|
||||
np_mask_rgb[:, :, c] = hardened[:]
|
||||
return np_mask_rgb
|
||||
|
||||
def _match_cumulative_cdf(source, template):
|
||||
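# Classic histogram matching: map each source value to the template value at the same
# cumulative quantile (this mirrors skimage.exposure.match_histograms).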
src_values, src_unique_indices, src_counts = np.unique(
|
||||
source.ravel(), return_inverse=True, return_counts=True
|
||||
)
|
||||
tmpl_values, tmpl_counts = np.unique(
|
||||
template.ravel(), return_counts=True
|
||||
)
|
||||
|
||||
# calculate normalized quantiles for each array
|
||||
src_quantiles = np.cumsum(src_counts) / source.size
|
||||
tmpl_quantiles = np.cumsum(tmpl_counts) / template.size
|
||||
|
||||
interp_a_values = np.interp(
|
||||
src_quantiles, tmpl_quantiles, tmpl_values
|
||||
)
|
||||
return interp_a_values[src_unique_indices].reshape(source.shape)
|
||||
|
||||
def _match_histograms(image, reference):
|
||||
if image.ndim != reference.ndim:
|
||||
raise ValueError(
|
||||
"Image and reference must have the same number of channels."
|
||||
)
|
||||
|
||||
if image.shape[-1] != reference.shape[-1]:
|
||||
raise ValueError(
|
||||
"Number of channels in the input image and reference image must match!"
|
||||
)
|
||||
|
||||
matched = np.empty(image.shape, dtype=image.dtype)
|
||||
for channel in range(image.shape[-1]):
|
||||
matched_channel = _match_cumulative_cdf(
|
||||
image[..., channel], reference[..., channel]
|
||||
)
|
||||
matched[..., channel] = matched_channel
|
||||
|
||||
matched = matched.astype(np.float64, copy=False)
|
||||
return matched
|
||||
|
||||
width = _np_src_image.shape[0]
|
||||
height = _np_src_image.shape[1]
|
||||
num_channels = _np_src_image.shape[2]
|
||||
|
||||
np_src_image = _np_src_image[:] * (1.0 - np_mask_rgb)
|
||||
np_mask_grey = np.sum(np_mask_rgb, axis=2) / 3.0
|
||||
img_mask = np_mask_grey > 1e-6
|
||||
ref_mask = np_mask_grey < 1e-3
|
||||
|
||||
# rather than leave the masked area black, we get better results from fft by filling the average unmasked color
|
||||
windowed_image = _np_src_image * (
|
||||
1.0 - _get_masked_window_rgb(np_mask_grey)
|
||||
)
|
||||
windowed_image /= np.max(windowed_image)
|
||||
windowed_image += np.average(_np_src_image) * np_mask_rgb
|
||||
|
||||
src_fft = _fft2(
|
||||
windowed_image
|
||||
) # get feature statistics from masked src img
|
||||
src_dist = np.absolute(src_fft)
|
||||
src_phase = src_fft / src_dist
|
||||
|
||||
# create a generator with a static seed to make outpainting deterministic / only follow global seed
|
||||
rng = np.random.default_rng(0)
|
||||
|
||||
noise_window = _get_gaussian_window(
|
||||
width, height, mode=1
|
||||
) # start with simple gaussian noise
|
||||
noise_rgb = rng.random((width, height, num_channels))
|
||||
noise_grey = np.sum(noise_rgb, axis=2) / 3.0
|
||||
# the colorfulness of the starting noise is blended to greyscale with a parameter
|
||||
noise_rgb *= color_variation
|
||||
for c in range(num_channels):
|
||||
noise_rgb[:, :, c] += (1.0 - color_variation) * noise_grey
|
||||
|
||||
noise_fft = _fft2(noise_rgb)
|
||||
for c in range(num_channels):
|
||||
noise_fft[:, :, c] *= noise_window
|
||||
noise_rgb = np.real(_ifft2(noise_fft))
|
||||
shaped_noise_fft = _fft2(noise_rgb)
|
||||
shaped_noise_fft[:, :, :] = (
|
||||
np.absolute(shaped_noise_fft[:, :, :]) ** 2
|
||||
* (src_dist**noise_q)
|
||||
* src_phase
|
||||
) # perform the actual shaping
|
||||
|
||||
# color_variation
|
||||
brightness_variation = 0.0
|
||||
contrast_adjusted_np_src = (
|
||||
_np_src_image[:] * (brightness_variation + 1.0)
|
||||
- brightness_variation * 2.0
|
||||
)
|
||||
|
||||
shaped_noise = np.real(_ifft2(shaped_noise_fft))
|
||||
shaped_noise -= np.min(shaped_noise)
|
||||
shaped_noise /= np.max(shaped_noise)
|
||||
shaped_noise[img_mask, :] = _match_histograms(
|
||||
shaped_noise[img_mask, :] ** 1.0,
|
||||
contrast_adjusted_np_src[ref_mask, :],
|
||||
)
|
||||
shaped_noise = (
|
||||
_np_src_image[:] * (1.0 - np_mask_rgb) + shaped_noise * np_mask_rgb
|
||||
)
|
||||
|
||||
matched_noise = shaped_noise[:]
|
||||
|
||||
return np.clip(matched_noise, 0.0, 1.0)
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
prompts,
|
||||
neg_prompts,
|
||||
image,
|
||||
pixels,
|
||||
mask_blur,
|
||||
is_left,
|
||||
is_right,
|
||||
is_top,
|
||||
is_bottom,
|
||||
noise_q,
|
||||
color_variation,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
max_length,
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
):
|
||||
# prompts and negative prompts must be lists.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if isinstance(neg_prompts, str):
|
||||
neg_prompts = [neg_prompts]
|
||||
|
||||
prompts = prompts * batch_size
|
||||
neg_prompts = neg_prompts * batch_size
|
||||
|
||||
# seed generator to create the initial latent noise. Also handle out-of-range seeds.
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Get initial latents
|
||||
init_latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get text embeddings with weight emphasis from prompts
|
||||
text_embeddings = self.encode_prompts_weight(
|
||||
prompts,
|
||||
neg_prompts,
|
||||
max_length,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
|
||||
|
||||
process_width = width
|
||||
process_height = height
|
||||
left = pixels if is_left else 0
|
||||
right = pixels if is_right else 0
|
||||
up = pixels if is_top else 0
|
||||
down = pixels if is_bottom else 0
|
||||
target_w = math.ceil((image.width + left + right) / 64) * 64
|
||||
target_h = math.ceil((image.height + up + down) / 64) * 64
|
||||
|
||||
if left > 0:
|
||||
left = left * (target_w - image.width) // (left + right)
|
||||
if right > 0:
|
||||
right = target_w - image.width - left
|
||||
if up > 0:
|
||||
up = up * (target_h - image.height) // (up + down)
|
||||
if down > 0:
|
||||
down = target_h - image.height - up
|
||||
|
||||
def expand(
|
||||
init_img,
|
||||
expand_pixels,
|
||||
is_left=False,
|
||||
is_right=False,
|
||||
is_top=False,
|
||||
is_bottom=False,
|
||||
):
|
||||
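# Grow the canvas by expand_pixels on the requested side(s), fill the new border with
# FFT-matched noise, inpaint only that region, then paste the result back and crop to the final size.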
is_horiz = is_left or is_right
|
||||
is_vert = is_top or is_bottom
|
||||
pixels_horiz = expand_pixels if is_horiz else 0
|
||||
pixels_vert = expand_pixels if is_vert else 0
|
||||
|
||||
res_w = init_img.width + pixels_horiz
|
||||
res_h = init_img.height + pixels_vert
|
||||
process_res_w = math.ceil(res_w / 64) * 64
|
||||
process_res_h = math.ceil(res_h / 64) * 64
|
||||
|
||||
img = Image.new("RGB", (process_res_w, process_res_h))
|
||||
img.paste(
|
||||
init_img,
|
||||
(pixels_horiz if is_left else 0, pixels_vert if is_top else 0),
|
||||
)
|
||||
|
||||
msk = Image.new("RGB", (process_res_w, process_res_h), "white")
|
||||
draw = ImageDraw.Draw(msk)
|
||||
draw.rectangle(
|
||||
(
|
||||
expand_pixels + mask_blur if is_left else 0,
|
||||
expand_pixels + mask_blur if is_top else 0,
|
||||
msk.width - expand_pixels - mask_blur
|
||||
if is_right
|
||||
else res_w,
|
||||
msk.height - expand_pixels - mask_blur
|
||||
if is_bottom
|
||||
else res_h,
|
||||
),
|
||||
fill="black",
|
||||
)
|
||||
|
||||
np_image = (np.asarray(img) / 255.0).astype(np.float64)
|
||||
np_mask = (np.asarray(msk) / 255.0).astype(np.float64)
|
||||
noised = self.get_matched_noise(
|
||||
np_image, np_mask, noise_q, color_variation
|
||||
)
|
||||
output_image = Image.fromarray(
|
||||
np.clip(noised * 255.0, 0.0, 255.0).astype(np.uint8),
|
||||
mode="RGB",
|
||||
)
|
||||
|
||||
target_width = (
|
||||
min(width, init_img.width + pixels_horiz)
|
||||
if is_horiz
|
||||
else img.width
|
||||
)
|
||||
target_height = (
|
||||
min(height, init_img.height + pixels_vert)
|
||||
if is_vert
|
||||
else img.height
|
||||
)
|
||||
crop_region = (
|
||||
0 if is_left else output_image.width - target_width,
|
||||
0 if is_top else output_image.height - target_height,
|
||||
target_width if is_left else output_image.width,
|
||||
target_height if is_top else output_image.height,
|
||||
)
|
||||
mask_to_process = msk.crop(crop_region)
|
||||
image_to_process = output_image.crop(crop_region)
|
||||
|
||||
# Preprocess mask and image
|
||||
mask, masked_image = self.prepare_mask_and_masked_image(
|
||||
image_to_process, mask_to_process, mask_blur, width, height
|
||||
)
|
||||
|
||||
# Prepare mask latent variables
|
||||
mask, masked_image_latents = self.prepare_mask_latents(
|
||||
mask=mask,
|
||||
masked_image=masked_image,
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get Image latents
|
||||
latents = self.produce_img_latents(
|
||||
latents=init_latents,
|
||||
text_embeddings=text_embeddings,
|
||||
guidance_scale=guidance_scale,
|
||||
total_timesteps=self.scheduler.timesteps,
|
||||
dtype=dtype,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
mask=mask,
|
||||
masked_image_latents=masked_image_latents,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images
|
||||
all_imgs = []
|
||||
self.load_vae()
|
||||
for i in tqdm(range(0, latents.shape[0], batch_size)):
|
||||
imgs = self.decode_latents(
|
||||
latents=latents[i : i + batch_size],
|
||||
use_base_vae=use_base_vae,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
)
|
||||
all_imgs.extend(imgs)
|
||||
|
||||
res_img = all_imgs[0].resize(
|
||||
(image_to_process.width, image_to_process.height)
|
||||
)
|
||||
output_image.paste(
|
||||
res_img,
|
||||
(
|
||||
0 if is_left else output_image.width - res_img.width,
|
||||
0 if is_top else output_image.height - res_img.height,
|
||||
),
|
||||
)
|
||||
output_image = output_image.crop((0, 0, res_w, res_h))
|
||||
|
||||
return output_image
|
||||
|
||||
img = image.resize((width, height))
|
||||
if left > 0:
|
||||
img = expand(img, left, is_left=True)
|
||||
if right > 0:
|
||||
img = expand(img, right, is_right=True)
|
||||
if up > 0:
|
||||
img = expand(img, up, is_top=True)
|
||||
if down > 0:
|
||||
img = expand(img, down, is_bottom=True)
|
||||
|
||||
return [img]
|
||||
@@ -1,603 +0,0 @@
|
||||
import torch
|
||||
import time
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
from random import randint
|
||||
from PIL import Image
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Union
|
||||
from shark.shark_inference import SharkInference
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
DDPMScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
controlnet_hint_conversion,
|
||||
controlnet_hint_reshaping,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
start_profiling,
|
||||
end_profiling,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
resamplers,
|
||||
resampler_list,
|
||||
)
|
||||
from apps.stable_diffusion.src.models import (
|
||||
SharkifyStableDiffusionModel,
|
||||
get_vae_encode,
|
||||
)
|
||||
|
||||
|
||||
class StencilPipeline(StableDiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
DDPMScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
],
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
ondemand: bool,
|
||||
controlnet_names: list[str],
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
self.controlnet = [None] * len(controlnet_names)
|
||||
self.controlnet_512 = [None] * len(controlnet_names)
|
||||
self.controlnet_id = [str] * len(controlnet_names)
|
||||
self.controlnet_512_id = [str] * len(controlnet_names)
|
||||
self.controlnet_names = controlnet_names
|
||||
self.vae_encode = None
|
||||
|
||||
def load_vae_encode(self):
|
||||
if self.vae_encode is not None:
|
||||
return
|
||||
|
||||
if self.import_mlir or self.use_lora:
|
||||
self.vae_encode = self.sd_model.vae_encode()
|
||||
else:
|
||||
try:
|
||||
self.vae_encode = get_vae_encode()
|
||||
except Exception:
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
self.vae_encode = self.sd_model.vae_encode()
|
||||
|
||||
def unload_vae_encode(self):
|
||||
del self.vae_encode
|
||||
self.vae_encode = None
|
||||
|
||||
def load_controlnet(self, index, model_name):
|
||||
if model_name is None:
|
||||
return
|
||||
if (
|
||||
self.controlnet[index] is not None
|
||||
and self.controlnet_id[index] is not None
|
||||
and self.controlnet_id[index] == model_name
|
||||
):
|
||||
return
|
||||
self.controlnet_id[index] = model_name
|
||||
self.controlnet[index] = self.sd_model.controlnet(model_name)
|
||||
|
||||
def unload_controlnet(self, index):
|
||||
del self.controlnet[index]
|
||||
self.controlnet_id[index] = None
|
||||
self.controlnet[index] = None
|
||||
|
||||
def load_controlnet_512(self, index, model_name):
|
||||
if (
|
||||
self.controlnet_512[index] is not None
|
||||
and self.controlnet_512_id[index] == model_name
|
||||
):
|
||||
return
|
||||
self.controlnet_512_id[index] = model_name
|
||||
self.controlnet_512[index] = self.sd_model.controlnet(
|
||||
model_name, use_large=True
|
||||
)
|
||||
|
||||
def unload_controlnet_512(self, index):
|
||||
del self.controlnet_512[index]
|
||||
self.controlnet_512_id[index] = None
|
||||
self.controlnet_512[index] = None
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
generator,
|
||||
num_inference_steps,
|
||||
dtype,
|
||||
):
|
||||
latents = torch.randn(
|
||||
(
|
||||
batch_size,
|
||||
4,
|
||||
height // 8,
|
||||
width // 8,
|
||||
),
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
).to(dtype)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
self.scheduler.is_scale_input_called = True
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def prepare_image_latents(
|
||||
self,
|
||||
image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
generator,
|
||||
num_inference_steps,
|
||||
strength,
|
||||
dtype,
|
||||
resample_type,
|
||||
):
|
||||
# Preprocess the image -> encode it with the VAE -> prepare noised latents
|
||||
|
||||
# TODO: process with variable HxW combos
|
||||
|
||||
# Pre-process image
|
||||
resample_type = (
|
||||
resamplers[resample_type]
|
||||
if resample_type in resampler_list
|
||||
# Fallback to Lanczos
|
||||
else Image.Resampling.LANCZOS
|
||||
)
|
||||
|
||||
image = image.resize((width, height), resample=resample_type)
|
||||
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
|
||||
image_arr = image_arr / 255.0
|
||||
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
|
||||
image_arr = 2 * (image_arr - 0.5)
|
||||
|
||||
# set scheduler steps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
init_timestep = min(
|
||||
int(num_inference_steps * strength), num_inference_steps
|
||||
)
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
# timesteps reduced as per strength
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
# new number of steps to be used as per strength will be
|
||||
# num_inference_steps = num_inference_steps - t_start
|
||||
|
||||
# image encode
|
||||
latents = self.encode_image((image_arr,))
|
||||
latents = torch.from_numpy(latents).to(dtype)
|
||||
# add noise to data
|
||||
noise = torch.randn(latents.shape, generator=generator, dtype=dtype)
|
||||
latents = self.scheduler.add_noise(
|
||||
latents, noise, timesteps[0].repeat(1)
|
||||
)
|
||||
|
||||
return latents, timesteps
|
||||
|
||||
def produce_stencil_latents(
|
||||
self,
|
||||
latents,
|
||||
text_embeddings,
|
||||
guidance_scale,
|
||||
total_timesteps,
|
||||
dtype,
|
||||
cpu_scheduling,
|
||||
stencil_hints=[None],
|
||||
controlnet_conditioning_scale: float = 1.0,
|
||||
control_mode="Balanced", # Prompt, Balanced, or Controlnet
|
||||
mask=None,
|
||||
masked_image_latents=None,
|
||||
return_all_latents=False,
|
||||
):
|
||||
step_time_sum = 0
|
||||
latent_history = [latents]
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
assert control_mode in ["Prompt", "Balanced", "Controlnet"]
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
self.load_unet()
|
||||
else:
|
||||
self.load_unet_512()
|
||||
|
||||
use_names = []  # collect the active names once, outside the loop, so earlier entries are kept
for i, name in enumerate(self.controlnet_names):
|
||||
|
||||
if name is not None:
|
||||
use_names.append(name)
|
||||
else:
|
||||
continue
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
self.load_controlnet(i, name)
|
||||
else:
|
||||
self.load_controlnet_512(i, name)
|
||||
self.controlnet_names = use_names
|
||||
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
timestep = torch.tensor([t]).to(dtype)
|
||||
latent_model_input = self.scheduler.scale_model_input(latents, t)
|
||||
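# Inpainting-style stencils: concatenate the mask and masked-image latents channel-wise so the
# UNet receives a single 9-channel input (4 latent + 1 mask + 4 masked-image channels).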
if mask is not None and masked_image_latents is not None:
|
||||
latent_model_input = torch.cat(
|
||||
[
|
||||
torch.from_numpy(np.asarray(latent_model_input)),
|
||||
mask,
|
||||
masked_image_latents,
|
||||
],
|
||||
dim=1,
|
||||
).to(dtype)
|
||||
if cpu_scheduling:
|
||||
latent_model_input = latent_model_input.detach().numpy()
|
||||
|
||||
if not torch.is_tensor(latent_model_input):
|
||||
latent_model_input_1 = torch.from_numpy(
|
||||
np.asarray(latent_model_input)
|
||||
).to(dtype)
|
||||
else:
|
||||
latent_model_input_1 = latent_model_input
|
||||
|
||||
# Multi-ControlNet: run each active ControlNet and accumulate its block residuals
|
||||
width = latent_model_input_1.shape[2]
|
||||
height = latent_model_input_1.shape[3]
|
||||
dtype = latent_model_input_1.dtype
|
||||
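# Zero-initialized residual accumulators for the 13 UNet down/mid blocks; each ControlNet
# in the chain adds its contribution on top of the previous one's output.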
control_acc = (
|
||||
[torch.zeros((2, 320, height, width), dtype=dtype)] * 3
|
||||
+ [
|
||||
torch.zeros(
|
||||
(2, 320, int(height / 2), int(width / 2)), dtype=dtype
|
||||
)
|
||||
]
|
||||
+ [
|
||||
torch.zeros(
|
||||
(2, 640, int(height / 2), int(width / 2)), dtype=dtype
|
||||
)
|
||||
]
|
||||
* 2
|
||||
+ [
|
||||
torch.zeros(
|
||||
(2, 640, int(height / 4), int(width / 4)), dtype=dtype
|
||||
)
|
||||
]
|
||||
+ [
|
||||
torch.zeros(
|
||||
(2, 1280, int(height / 4), int(width / 4)), dtype=dtype
|
||||
)
|
||||
]
|
||||
* 2
|
||||
+ [
|
||||
torch.zeros(
|
||||
(2, 1280, int(height / 8), int(width / 8)), dtype=dtype
|
||||
)
|
||||
]
|
||||
* 4
|
||||
)
|
||||
for i, controlnet_hint in enumerate(stencil_hints):
|
||||
if controlnet_hint is None:
|
||||
continue  # skip stencils that have no hint rather than calling the ControlNet with None
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
control = self.controlnet[i](
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
*control_acc,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
else:
|
||||
control = self.controlnet_512[i](
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
*control_acc,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
control_acc = control[13:]
|
||||
control = control[:13]
|
||||
|
||||
timestep = timestep.detach().numpy()
|
||||
# Profiling Unet.
|
||||
profile_device = start_profiling(file_path="unet.rdc")
|
||||
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
|
||||
|
||||
dtype = latents.dtype
|
||||
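# Per-block weights for the ControlNet residuals: "Balanced" keeps full strength, "Prompt"
# decays deeper blocks geometrically (0.825**x) so the text prompt dominates, and
# "Controlnet" scales every block by the guidance scale so the hint dominates.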
if control_mode == "Balanced":
|
||||
control_scale = [
|
||||
torch.tensor(1.0, dtype=dtype) for _ in range(len(control))
|
||||
]
|
||||
elif control_mode == "Prompt":
|
||||
control_scale = [
|
||||
torch.tensor(0.825**x, dtype=dtype)
|
||||
for x in range(len(control))
|
||||
]
|
||||
elif control_mode == "Controlnet":
|
||||
control_scale = [
|
||||
torch.tensor(float(guidance_scale), dtype=dtype)
|
||||
for _ in range(len(control))
|
||||
]
|
||||
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
control[0],
|
||||
control[1],
|
||||
control[2],
|
||||
control[3],
|
||||
control[4],
|
||||
control[5],
|
||||
control[6],
|
||||
control[7],
|
||||
control[8],
|
||||
control[9],
|
||||
control[10],
|
||||
control[11],
|
||||
control[12],
|
||||
control_scale[0],
|
||||
control_scale[1],
|
||||
control_scale[2],
|
||||
control_scale[3],
|
||||
control_scale[4],
|
||||
control_scale[5],
|
||||
control_scale[6],
|
||||
control_scale[7],
|
||||
control_scale[8],
|
||||
control_scale[9],
|
||||
control_scale[10],
|
||||
control_scale[11],
|
||||
control_scale[12],
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
else:
|
||||
noise_pred = self.unet_512(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
control[0],
|
||||
control[1],
|
||||
control[2],
|
||||
control[3],
|
||||
control[4],
|
||||
control[5],
|
||||
control[6],
|
||||
control[7],
|
||||
control[8],
|
||||
control[9],
|
||||
control[10],
|
||||
control[11],
|
||||
control[12],
|
||||
control_scale[0],
|
||||
control_scale[1],
|
||||
control_scale[2],
|
||||
control_scale[3],
|
||||
control_scale[4],
|
||||
control_scale[5],
|
||||
control_scale[6],
|
||||
control_scale[7],
|
||||
control_scale[8],
|
||||
control_scale[9],
|
||||
control_scale[10],
|
||||
control_scale[11],
|
||||
control_scale[12],
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
end_profiling(profile_device)
|
||||
|
||||
if cpu_scheduling:
|
||||
noise_pred = torch.from_numpy(noise_pred.to_host())
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents
|
||||
).prev_sample
|
||||
else:
|
||||
latents = self.scheduler.step(noise_pred, t, latents)
|
||||
|
||||
latent_history.append(latents)
|
||||
step_time = (time.time() - step_start_time) * 1000
|
||||
# self.log += (
|
||||
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
|
||||
# )
|
||||
step_time_sum += step_time
|
||||
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
self.unload_unet_512()
|
||||
for i in range(len(self.controlnet_names)):
|
||||
self.unload_controlnet(i)
|
||||
self.unload_controlnet_512(i)
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
if not return_all_latents:
|
||||
return latents
|
||||
all_latents = torch.cat(latent_history, dim=0)
|
||||
return all_latents
|
||||
|
||||
def encode_image(self, input_image):
|
||||
self.load_vae_encode()
|
||||
vae_encode_start = time.time()
|
||||
latents = self.vae_encode("forward", input_image)
|
||||
vae_inf_time = (time.time() - vae_encode_start) * 1000
|
||||
if self.ondemand:
|
||||
self.unload_vae_encode()
|
||||
self.log += f"\nVAE Encode Inference time (ms): {vae_inf_time:.3f}"
|
||||
|
||||
return latents
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
prompts,
|
||||
neg_prompts,
|
||||
image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
num_inference_steps,
|
||||
strength,
|
||||
guidance_scale,
|
||||
seed,
|
||||
max_length,
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
stencils,
|
||||
stencil_images,
|
||||
resample_type,
|
||||
control_mode,
|
||||
preprocessed_hints,
|
||||
):
|
||||
# Control Embedding check & conversion
|
||||
# controlnet_hint = controlnet_hint_conversion(
|
||||
# image, use_stencil, height, width, dtype, num_images_per_prompt=1
|
||||
# )
|
||||
stencil_hints = []
|
||||
self.sd_model.stencils = stencils
|
||||
for i, hint in enumerate(preprocessed_hints):
|
||||
if hint is not None:
|
||||
hint = controlnet_hint_reshaping(
|
||||
hint,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
num_images_per_prompt=1,
|
||||
)
|
||||
stencil_hints.append(hint)
|
||||
|
||||
for i, stencil in enumerate(stencils):
|
||||
if stencil is None:
|
||||
continue
|
||||
if len(stencil_hints) > i:
|
||||
if stencil_hints[i] is not None:
|
||||
print(f"Using preprocessed controlnet hint for {stencil}")
|
||||
continue
|
||||
image = stencil_images[i]
|
||||
stencil_hints.append(
|
||||
controlnet_hint_conversion(
|
||||
image,
|
||||
stencil,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
num_images_per_prompt=1,
|
||||
)
|
||||
)
|
||||
|
||||
# prompts and negative prompts must be lists.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if isinstance(neg_prompts, str):
|
||||
neg_prompts = [neg_prompts]
|
||||
|
||||
prompts = prompts * batch_size
|
||||
neg_prompts = neg_prompts * batch_size
|
||||
|
||||
# seed generator to create the initial latent noise. Also handle out-of-range seeds.
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Get text embeddings with weight emphasis from prompts
|
||||
text_embeddings = self.encode_prompts_weight(
|
||||
prompts,
|
||||
neg_prompts,
|
||||
max_length,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
|
||||
if image is not None:
|
||||
# Prepare input image latent
|
||||
init_latents, final_timesteps = self.prepare_image_latents(
|
||||
image=image,
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
strength=strength,
|
||||
dtype=dtype,
|
||||
resample_type=resample_type,
|
||||
)
|
||||
else:
|
||||
# Prepare initial latent.
|
||||
init_latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
dtype=dtype,
|
||||
)
|
||||
final_timesteps = self.scheduler.timesteps
|
||||
|
||||
# Get Image latents
|
||||
latents = self.produce_stencil_latents(
|
||||
latents=init_latents,
|
||||
text_embeddings=text_embeddings,
|
||||
guidance_scale=guidance_scale,
|
||||
total_timesteps=final_timesteps,
|
||||
dtype=dtype,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
control_mode=control_mode,
|
||||
stencil_hints=stencil_hints,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images
|
||||
all_imgs = []
|
||||
self.load_vae()
|
||||
for i in tqdm(range(0, latents.shape[0], batch_size)):
|
||||
imgs = self.decode_latents(
|
||||
latents=latents[i : i + batch_size],
|
||||
use_base_vae=use_base_vae,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
)
|
||||
all_imgs.extend(imgs)
|
||||
if self.ondemand:
|
||||
self.unload_vae()
|
||||
|
||||
return all_imgs
|
||||
@@ -1,159 +0,0 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Union
|
||||
from shark.shark_inference import SharkInference
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
|
||||
|
||||
|
||||
class Text2ImagePipeline(StableDiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
],
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
ondemand: bool,
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
generator,
|
||||
num_inference_steps,
|
||||
dtype,
|
||||
):
|
||||
latents = torch.randn(
|
||||
(
|
||||
batch_size,
|
||||
4,
|
||||
height // 8,
|
||||
width // 8,
|
||||
),
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
).to(dtype)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
self.scheduler.is_scale_input_called = True
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
prompts,
|
||||
neg_prompts,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
max_length,
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
):
|
||||
# prompts and negative prompts must be lists.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if isinstance(neg_prompts, str):
|
||||
neg_prompts = [neg_prompts]
|
||||
|
||||
prompts = prompts * batch_size
|
||||
neg_prompts = neg_prompts * batch_size
|
||||
|
||||
# seed generator to create the initial latent noise. Also handle out-of-range seeds.
|
||||
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Get initial latents
|
||||
init_latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get text embeddings with weight emphasis from prompts
|
||||
text_embeddings = self.encode_prompts_weight(
|
||||
prompts,
|
||||
neg_prompts,
|
||||
max_length,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
|
||||
|
||||
# Get Image latents
|
||||
latents = self.produce_img_latents(
|
||||
latents=init_latents,
|
||||
text_embeddings=text_embeddings,
|
||||
guidance_scale=guidance_scale,
|
||||
total_timesteps=self.scheduler.timesteps,
|
||||
dtype=dtype,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images
|
||||
all_imgs = []
|
||||
self.load_vae()
|
||||
for i in range(0, latents.shape[0], batch_size):
|
||||
imgs = self.decode_latents(
|
||||
latents=latents[i : i + batch_size],
|
||||
use_base_vae=use_base_vae,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
)
|
||||
all_imgs.extend(imgs)
|
||||
if self.ondemand:
|
||||
self.unload_vae()
|
||||
|
||||
return all_imgs
|
||||
@@ -1,220 +0,0 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from typing import Union
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Text2ImageSDXLPipeline(StableDiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
],
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
ondemand: bool,
|
||||
is_fp32_vae: bool,
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
self.is_fp32_vae = is_fp32_vae
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
generator,
|
||||
num_inference_steps,
|
||||
dtype,
|
||||
):
|
||||
latents = torch.randn(
|
||||
(
|
||||
batch_size,
|
||||
4,
|
||||
height // 8,
|
||||
width // 8,
|
||||
),
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
).to(dtype)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
self.scheduler.is_scale_input_called = True
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def _get_add_time_ids(
|
||||
self, original_size, crops_coords_top_left, target_size, dtype
|
||||
):
|
||||
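# SDXL micro-conditioning: the six integers from (original_size, crops_coords_top_left, target_size)
# are embedded and combined with the pooled text embedding inside the UNet's added-embedding path.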
add_time_ids = list(
|
||||
original_size + crops_coords_top_left + target_size
|
||||
)
|
||||
|
||||
# self.unet.config.addition_time_embed_dim IS 256.
|
||||
# self.text_encoder_2.config.projection_dim IS 1280.
|
||||
passed_add_embed_dim = 256 * len(add_time_ids) + 1280
|
||||
expected_add_embed_dim = 2816
|
||||
# self.unet.add_embedding.linear_1.in_features IS 2816.
|
||||
|
||||
if expected_add_embed_dim != passed_add_embed_dim:
|
||||
raise ValueError(
|
||||
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
||||
)
|
||||
|
||||
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
||||
return add_time_ids
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
prompts,
|
||||
neg_prompts,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
max_length,
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
):
|
||||
# prompts and negative prompts must be lists.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if isinstance(neg_prompts, str):
|
||||
neg_prompts = [neg_prompts]
|
||||
|
||||
prompts = prompts * batch_size
|
||||
neg_prompts = neg_prompts * batch_size
|
||||
|
||||
# seed generator to create the initial latent noise. Also handle out-of-range seeds.
|
||||
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Get initial latents.
|
||||
init_latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get text embeddings.
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt_sdxl(
|
||||
prompt=prompts,
|
||||
num_images_per_prompt=1,
|
||||
do_classifier_free_guidance=True,
|
||||
negative_prompt=neg_prompts,
|
||||
)
|
||||
|
||||
# Prepare timesteps.
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# Prepare added time ids & embeddings.
|
||||
original_size = (height, width)
|
||||
target_size = (height, width)
|
||||
crops_coords_top_left = (0, 0)
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
)
|
||||
|
||||
prompt_embeds = torch.cat(
|
||||
[negative_prompt_embeds, prompt_embeds], dim=0
|
||||
)
|
||||
add_text_embeds = torch.cat(
|
||||
[negative_pooled_prompt_embeds, add_text_embeds], dim=0
|
||||
)
|
||||
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds
|
||||
add_text_embeds = add_text_embeds.to(dtype)
|
||||
add_time_ids = add_time_ids.repeat(batch_size * 1, 1)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
guidance_scale = torch.tensor(guidance_scale).to(dtype)
|
||||
prompt_embeds = prompt_embeds.to(dtype)
|
||||
add_time_ids = add_time_ids.to(dtype)
|
||||
|
||||
# Get Image latents.
|
||||
latents = self.produce_img_latents_sdxl(
|
||||
init_latents,
|
||||
timesteps,
|
||||
add_text_embeds,
|
||||
add_time_ids,
|
||||
prompt_embeds,
|
||||
cpu_scheduling,
|
||||
guidance_scale,
|
||||
dtype,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images.
|
||||
all_imgs = []
|
||||
self.load_vae()
|
||||
for i in range(0, latents.shape[0], batch_size):
|
||||
imgs = self.decode_latents_sdxl(
|
||||
latents[i : i + batch_size], is_fp32_vae=self.is_fp32_vae
|
||||
)
|
||||
all_imgs.extend(imgs)
|
||||
if self.ondemand:
|
||||
self.unload_vae()
|
||||
|
||||
return all_imgs
|
||||
@@ -1,357 +0,0 @@
|
||||
import inspect
|
||||
import torch
|
||||
import time
|
||||
from tqdm.auto import tqdm
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Union
|
||||
from shark.shark_inference import SharkInference
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_IDLE,
|
||||
SD_STATE_CANCEL,
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
start_profiling,
|
||||
end_profiling,
|
||||
)
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
|
||||
|
||||
|
||||
def preprocess(image):
|
||||
if isinstance(image, torch.Tensor):
|
||||
return image
|
||||
elif isinstance(image, Image.Image):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image[0], Image.Image):
|
||||
w, h = image[0].size
|
||||
w, h = map(
|
||||
lambda x: x - x % 64, (w, h)
|
||||
) # resize to integer multiple of 64
|
||||
|
||||
image = [np.array(i.resize((w, h)))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = 2.0 * image - 1.0
|
||||
image = torch.from_numpy(image)
|
||||
elif isinstance(image[0], torch.Tensor):
|
||||
image = torch.cat(image, dim=0)
|
||||
return image
|
||||
|
||||
|
||||
class UpscalerPipeline(StableDiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
],
|
||||
low_res_scheduler: Union[
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
],
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
ondemand: bool,
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
self.low_res_scheduler = low_res_scheduler
|
||||
self.status = SD_STATE_IDLE
|
||||
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
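# eta only applies to DDIM-style schedulers and not every scheduler accepts a generator,
# so inspect scheduler.step() and forward only the kwargs it supports.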
accepts_eta = "eta" in set(
|
||||
inspect.signature(self.scheduler.step).parameters.keys()
|
||||
)
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(
|
||||
inspect.signature(self.scheduler.step).parameters.keys()
|
||||
)
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def decode_latents(self, latents, use_base_vae, cpu_scheduling):
|
||||
latents = 1 / 0.08333 * (latents.float())
|
||||
latents_numpy = latents
|
||||
if cpu_scheduling:
|
||||
latents_numpy = latents.detach().numpy()
|
||||
|
||||
profile_device = start_profiling(file_path="vae.rdc")
|
||||
vae_start = time.time()
|
||||
images = self.vae("forward", (latents_numpy,))
|
||||
vae_inf_time = (time.time() - vae_start) * 1000
|
||||
end_profiling(profile_device)
|
||||
self.log += f"\nVAE Inference time (ms): {vae_inf_time:.3f}"
|
||||
|
||||
images = torch.from_numpy(images)
|
||||
images = (images.detach().cpu() * 255.0).numpy()
|
||||
images = images.round()
|
||||
|
||||
images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
|
||||
pil_images = [Image.fromarray(image) for image in images.numpy()]
|
||||
return pil_images
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
generator,
|
||||
num_inference_steps,
|
||||
dtype,
|
||||
):
|
||||
latents = torch.randn(
|
||||
(
|
||||
batch_size,
|
||||
4,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
).to(dtype)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
self.scheduler.is_scale_input_called = True
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def produce_img_latents(
|
||||
self,
|
||||
latents,
|
||||
image,
|
||||
text_embeddings,
|
||||
guidance_scale,
|
||||
noise_level,
|
||||
total_timesteps,
|
||||
dtype,
|
||||
cpu_scheduling,
|
||||
extra_step_kwargs,
|
||||
return_all_latents=False,
|
||||
):
|
||||
step_time_sum = 0
|
||||
latent_history = [latents]
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
self.status = SD_STATE_IDLE
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
self.load_unet()
|
||||
else:
|
||||
self.load_unet_512()
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
latent_model_input = self.scheduler.scale_model_input(
|
||||
latent_model_input, t
|
||||
)
|
||||
latent_model_input = torch.cat([latent_model_input, image], dim=1)
|
||||
timestep = torch.tensor([t]).to(dtype).detach().numpy()
|
||||
if cpu_scheduling:
|
||||
latent_model_input = latent_model_input.detach().numpy()
|
||||
|
||||
# Profiling Unet.
|
||||
profile_device = start_profiling(file_path="unet.rdc")
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
noise_level,
|
||||
),
|
||||
)
|
||||
else:
|
||||
noise_pred = self.unet_512(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
noise_level,
|
||||
),
|
||||
)
|
||||
end_profiling(profile_device)
|
||||
noise_pred = torch.from_numpy(noise_pred)
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
if cpu_scheduling:
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, **extra_step_kwargs
|
||||
).prev_sample
|
||||
else:
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, **extra_step_kwargs
|
||||
)
|
||||
|
||||
latent_history.append(latents)
|
||||
step_time = (time.time() - step_start_time) * 1000
|
||||
# self.log += (
|
||||
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
|
||||
# )
|
||||
step_time_sum += step_time
|
||||
|
||||
if self.status == SD_STATE_CANCEL:
|
||||
break
|
||||
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
self.unload_unet_512()
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
if not return_all_latents:
|
||||
return latents
|
||||
all_latents = torch.cat(latent_history, dim=0)
|
||||
return all_latents
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
prompts,
|
||||
neg_prompts,
|
||||
image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
num_inference_steps,
|
||||
noise_level,
|
||||
guidance_scale,
|
||||
seed,
|
||||
max_length,
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if isinstance(neg_prompts, str):
|
||||
neg_prompts = [neg_prompts]
|
||||
|
||||
prompts = prompts * batch_size
|
||||
neg_prompts = neg_prompts * batch_size
|
||||
|
||||
# Seed the generator to create the initial latent noise. Also handle out-of-range seeds.
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Get text embeddings with weight emphasis from prompts
|
||||
text_embeddings = self.encode_prompts_weight(
|
||||
prompts,
|
||||
neg_prompts,
|
||||
max_length,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
)
|
||||
|
||||
# 4. Preprocess image
|
||||
image = preprocess(image).to(dtype)
|
||||
|
||||
# 5. Add noise to image
|
||||
noise_level = torch.tensor([noise_level], dtype=torch.long)
|
||||
noise = torch.randn(
|
||||
image.shape,
|
||||
generator=generator,
|
||||
).to(dtype)
|
||||
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
|
||||
image = torch.cat([image] * 2)
|
||||
noise_level = torch.cat([noise_level] * image.shape[0])
|
||||
|
||||
height, width = image.shape[2:]
|
||||
# Get initial latents
|
||||
init_latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
eta = 0.0
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
# guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
|
||||
|
||||
# Get Image latents
|
||||
latents = self.produce_img_latents(
|
||||
latents=init_latents,
|
||||
image=image,
|
||||
text_embeddings=text_embeddings,
|
||||
guidance_scale=guidance_scale,
|
||||
noise_level=noise_level,
|
||||
total_timesteps=self.scheduler.timesteps,
|
||||
dtype=dtype,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
extra_step_kwargs=extra_step_kwargs,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images
|
||||
all_imgs = []
|
||||
self.load_vae()
|
||||
for i in tqdm(range(0, latents.shape[0], batch_size)):
|
||||
imgs = self.decode_latents(
|
||||
latents=latents[i : i + batch_size],
|
||||
use_base_vae=use_base_vae,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
)
|
||||
all_imgs.extend(imgs)
|
||||
if self.ondemand:
|
||||
self.unload_vae()
|
||||
|
||||
return all_imgs
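
The guidance arithmetic inside `produce_img_latents` above is the standard classifier-free guidance combination: the UNet runs on a doubled batch (unconditional and text-conditioned halves), and the two predictions are blended with `guidance_scale`. A minimal restatement of that step in plain PyTorch, independent of the compiled SHARK UNet, for reference:

```python
import torch


def classifier_free_guidance(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    """Blend the unconditional and text-conditioned halves of a doubled batch.

    `noise_pred` has shape (2 * batch, C, H, W): the first half comes from the
    negative/empty prompt, the second half from the text prompt, matching the
    torch.cat([latents] * 2) layout used in the pipeline above.
    """
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)


# With guidance_scale == 1.0 this reduces to the text-conditioned prediction;
# larger values push the result further away from the unconditional one.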
|
||||
File diff suppressed because it is too large
@@ -1,7 +0,0 @@
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
    SharkEulerDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import (
    SharkEulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers
@@ -1,128 +0,0 @@
|
||||
from diffusers import (
|
||||
LCMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
DDPMScheduler,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import (
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
)
|
||||
|
||||
|
||||
def get_schedulers(model_id):
|
||||
# TODO: Robust scheduler setup on pipeline creation -- if we don't
|
||||
# set batch_size here, the SHARK schedulers will
|
||||
# compile with batch size = 1 regardless of whether the model
|
||||
# outputs latents of a larger batch size, e.g. SDXL.
|
||||
# However, obviously, searching for whether the base model ID
|
||||
# contains "xl" is not very robust.
|
||||
|
||||
batch_size = 2 if "xl" in model_id.lower() else 1
|
||||
|
||||
schedulers = dict()
|
||||
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["DDPM"] = DDPMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["DDIM"] = DDIMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["LCMScheduler"] = LCMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers[
|
||||
"DPMSolverMultistep"
|
||||
] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id, subfolder="scheduler", algorithm_type="dpmsolver"
|
||||
)
|
||||
schedulers[
|
||||
"DPMSolverMultistep++"
|
||||
] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
|
||||
)
|
||||
schedulers[
|
||||
"DPMSolverMultistepKarras"
|
||||
] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
use_karras_sigmas=True,
|
||||
)
|
||||
schedulers[
|
||||
"DPMSolverMultistepKarras++"
|
||||
] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
algorithm_type="dpmsolver++",
|
||||
use_karras_sigmas=True,
|
||||
)
|
||||
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers[
|
||||
"EulerAncestralDiscrete"
|
||||
] = EulerAncestralDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers[
|
||||
"SharkEulerDiscrete"
|
||||
] = SharkEulerDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers[
|
||||
"SharkEulerAncestralDiscrete"
|
||||
] = SharkEulerAncestralDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers[
|
||||
"DPMSolverSinglestep"
|
||||
] = DPMSolverSinglestepScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers[
|
||||
"KDPM2AncestralDiscrete"
|
||||
] = KDPM2AncestralDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["SharkEulerDiscrete"].compile(batch_size)
|
||||
schedulers["SharkEulerAncestralDiscrete"].compile(batch_size)
|
||||
return schedulers
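
For reference, the removed factory above was consumed roughly as in the sketch below; the model id is illustrative (it is one of the entries in `model_db.json` later in this commit), and the keys match the dictionary populated above.

```python
# Illustrative use of get_schedulers from the removed module above.
schedulers = get_schedulers("stabilityai/stable-diffusion-2-1-base")

# Pick a scheduler by the key used in the dict above. The two Shark* entries
# have already been compiled inside get_schedulers, with batch_size 2 when the
# model id contains "xl" and 1 otherwise.
scheduler = schedulers["SharkEulerDiscrete"]
```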
|
||||
@@ -1,251 +0,0 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from diffusers import (
|
||||
EulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
compile_through_fx,
|
||||
get_shark_model,
|
||||
args,
|
||||
)
|
||||
import torch
|
||||
|
||||
|
||||
class SharkEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
prediction_type: str = "epsilon",
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
super().__init__(
|
||||
num_train_timesteps,
|
||||
beta_start,
|
||||
beta_end,
|
||||
beta_schedule,
|
||||
trained_betas,
|
||||
prediction_type,
|
||||
timestep_spacing,
|
||||
steps_offset,
|
||||
)
|
||||
# TODO: make it dynamic so we don't have to worry about the batch size
|
||||
self.batch_size = None
|
||||
self.init_input_shape = None
|
||||
|
||||
def compile(self, batch_size=1):
|
||||
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
|
||||
device = args.device.split(":", 1)[0].strip()
|
||||
self.batch_size = batch_size
|
||||
|
||||
model_input = {
|
||||
"eulera": {
|
||||
"output": torch.randn(
|
||||
batch_size, 4, args.height // 8, args.width // 8
|
||||
),
|
||||
"latent": torch.randn(
|
||||
batch_size, 4, args.height // 8, args.width // 8
|
||||
),
|
||||
"sigma": torch.tensor(1).to(torch.float32),
|
||||
"sigma_from": torch.tensor(1).to(torch.float32),
|
||||
"sigma_to": torch.tensor(1).to(torch.float32),
|
||||
"noise": torch.randn(
|
||||
batch_size, 4, args.height // 8, args.width // 8
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
example_latent = model_input["eulera"]["latent"]
|
||||
example_output = model_input["eulera"]["output"]
|
||||
example_noise = model_input["eulera"]["noise"]
|
||||
if args.precision == "fp16":
|
||||
example_latent = example_latent.half()
|
||||
example_output = example_output.half()
|
||||
example_noise = example_noise.half()
|
||||
example_sigma = model_input["eulera"]["sigma"]
|
||||
example_sigma_from = model_input["eulera"]["sigma_from"]
|
||||
example_sigma_to = model_input["eulera"]["sigma_to"]
|
||||
|
||||
class ScalingModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, latent, sigma):
|
||||
return latent / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
class SchedulerStepEpsilonModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self, noise_pred, latent, sigma, sigma_from, sigma_to, noise
|
||||
):
|
||||
sigma_up = (
|
||||
sigma_to**2
|
||||
* (sigma_from**2 - sigma_to**2)
|
||||
/ sigma_from**2
|
||||
) ** 0.5
|
||||
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
|
||||
dt = sigma_down - sigma
|
||||
pred_original_sample = latent - sigma * noise_pred
|
||||
derivative = (latent - pred_original_sample) / sigma
|
||||
prev_sample = latent + derivative * dt
|
||||
return prev_sample + noise * sigma_up
|
||||
|
||||
class SchedulerStepVPredictionModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self, noise_pred, sigma, sigma_from, sigma_to, latent, noise
|
||||
):
|
||||
sigma_up = (
|
||||
sigma_to**2
|
||||
* (sigma_from**2 - sigma_to**2)
|
||||
/ sigma_from**2
|
||||
) ** 0.5
|
||||
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
|
||||
dt = sigma_down - sigma
|
||||
pred_original_sample = noise_pred * (
|
||||
-sigma / (sigma**2 + 1) ** 0.5
|
||||
) + (latent / (sigma**2 + 1))
|
||||
derivative = (latent - pred_original_sample) / sigma
|
||||
prev_sample = latent + derivative * dt
|
||||
return prev_sample + noise * sigma_up
|
||||
|
||||
iree_flags = []
|
||||
if len(args.iree_vulkan_target_triple) > 0:
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
|
||||
def _import(self):
|
||||
scaling_model = ScalingModel()
|
||||
self.scaling_model, _ = compile_through_fx(
|
||||
model=scaling_model,
|
||||
inputs=(example_latent, example_sigma),
|
||||
extended_model_name=f"euler_a_scale_model_input_{self.batch_size}_{args.height}_{args.width}_{device}_"
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
|
||||
pred_type_model_dict = {
|
||||
"epsilon": SchedulerStepEpsilonModel(),
|
||||
"v_prediction": SchedulerStepVPredictionModel(),
|
||||
}
|
||||
step_model = pred_type_model_dict[self.config.prediction_type]
|
||||
self.step_model, _ = compile_through_fx(
|
||||
step_model,
|
||||
(
|
||||
example_output,
|
||||
example_latent,
|
||||
example_sigma,
|
||||
example_sigma_from,
|
||||
example_sigma_to,
|
||||
example_noise,
|
||||
),
|
||||
extended_model_name=f"euler_a_step_{self.config.prediction_type}_{self.batch_size}_{args.height}_{args.width}_{device}_"
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
|
||||
if args.import_mlir:
|
||||
_import(self)
|
||||
|
||||
else:
|
||||
try:
|
||||
self.scaling_model = get_shark_model(
|
||||
SCHEDULER_BUCKET,
|
||||
"euler_a_scale_model_input_" + args.precision,
|
||||
iree_flags,
|
||||
)
|
||||
self.step_model = get_shark_model(
|
||||
SCHEDULER_BUCKET,
|
||||
"euler_a_step_"
|
||||
+ self.config.prediction_type
|
||||
+ args.precision,
|
||||
iree_flags,
|
||||
)
|
||||
except:
|
||||
print(
|
||||
"failed to download model, falling back and using import_mlir"
|
||||
)
|
||||
args.import_mlir = True
|
||||
_import(self)
|
||||
|
||||
def scale_model_input(self, sample, timestep):
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
sigma = self.sigmas[self.step_index]
|
||||
return self.scaling_model(
|
||||
"forward",
|
||||
(
|
||||
sample,
|
||||
sigma,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
|
||||
def step(
|
||||
self,
|
||||
noise_pred,
|
||||
timestep,
|
||||
latent,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: Optional[bool] = False,
|
||||
):
|
||||
step_inputs = []
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
|
||||
sigma_from = self.sigmas[self.step_index]
|
||||
sigma_to = self.sigmas[self.step_index + 1]
|
||||
noise = randn_tensor(
|
||||
torch.Size(noise_pred.shape),
|
||||
dtype=torch.float16,
|
||||
device="cpu",
|
||||
generator=generator,
|
||||
)
|
||||
step_inputs = [
|
||||
noise_pred,
|
||||
latent,
|
||||
sigma,
|
||||
sigma_from,
|
||||
sigma_to,
|
||||
noise,
|
||||
]
|
||||
# TODO: deal with dynamic inputs in turbine flow.
|
||||
# Update the step index here, since we are done with it and will return the compiled module's output.
|
||||
self._step_index += 1
|
||||
|
||||
if noise_pred.shape[0] < self.batch_size:
|
||||
for i in [0, 1, 5]:
|
||||
try:
|
||||
step_inputs[i] = torch.tensor(step_inputs[i])
|
||||
except:
|
||||
step_inputs[i] = torch.tensor(step_inputs[i].to_host())
|
||||
step_inputs[i] = torch.cat(
|
||||
(step_inputs[i], step_inputs[i]), axis=0
|
||||
)
|
||||
return self.step_model(
|
||||
"forward",
|
||||
tuple(step_inputs),
|
||||
send_to_host=True,
|
||||
)
|
||||
|
||||
return self.step_model(
|
||||
"forward",
|
||||
tuple(step_inputs),
|
||||
send_to_host=False,
|
||||
)
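
The epsilon-prediction step compiled above implements the usual Euler-ancestral update. Written out as a plain function (a readability sketch mirroring `SchedulerStepEpsilonModel.forward`, not the compiled SHARK module):

```python
import torch


def euler_ancestral_step_epsilon(noise_pred, latent, sigma, sigma_from, sigma_to, noise):
    """Plain-PyTorch mirror of SchedulerStepEpsilonModel.forward above."""
    # Split the transition sigma_from -> sigma_to into a deterministic part
    # (sigma_down) and a stochastic part (sigma_up).
    sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    dt = sigma_down - sigma
    pred_original_sample = latent - sigma * noise_pred   # x0 estimate
    derivative = (latent - pred_original_sample) / sigma  # equals noise_pred here
    prev_sample = latent + derivative * dt                # deterministic Euler step
    return prev_sample + noise * sigma_up                 # re-inject fresh noise
```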
|
||||
@@ -1,245 +0,0 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from diffusers import (
|
||||
EulerDiscreteScheduler,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
compile_through_fx,
|
||||
get_shark_model,
|
||||
args,
|
||||
)
|
||||
import torch
|
||||
|
||||
|
||||
class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
prediction_type: str = "epsilon",
|
||||
interpolation_type: str = "linear",
|
||||
use_karras_sigmas: bool = False,
|
||||
sigma_min: Optional[float] = None,
|
||||
sigma_max: Optional[float] = None,
|
||||
timestep_spacing: str = "linspace",
|
||||
timestep_type: str = "discrete",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
super().__init__(
|
||||
num_train_timesteps,
|
||||
beta_start,
|
||||
beta_end,
|
||||
beta_schedule,
|
||||
trained_betas,
|
||||
prediction_type,
|
||||
interpolation_type,
|
||||
use_karras_sigmas,
|
||||
sigma_min,
|
||||
sigma_max,
|
||||
timestep_spacing,
|
||||
timestep_type,
|
||||
steps_offset,
|
||||
)
|
||||
# TODO: make it dynamic so we don't have to worry about the batch size
|
||||
self.batch_size = 1
|
||||
|
||||
def compile(self, batch_size=1):
|
||||
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
|
||||
device = args.device.split(":", 1)[0].strip()
|
||||
self.batch_size = batch_size
|
||||
|
||||
model_input = {
|
||||
"euler": {
|
||||
"latent": torch.randn(
|
||||
batch_size, 4, args.height // 8, args.width // 8
|
||||
),
|
||||
"output": torch.randn(
|
||||
batch_size, 4, args.height // 8, args.width // 8
|
||||
),
|
||||
"sigma": torch.tensor(1).to(torch.float32),
|
||||
"dt": torch.tensor(1).to(torch.float32),
|
||||
},
|
||||
}
|
||||
|
||||
example_latent = model_input["euler"]["latent"]
|
||||
example_output = model_input["euler"]["output"]
|
||||
if args.precision == "fp16":
|
||||
example_latent = example_latent.half()
|
||||
example_output = example_output.half()
|
||||
example_sigma = model_input["euler"]["sigma"]
|
||||
example_dt = model_input["euler"]["dt"]
|
||||
|
||||
class ScalingModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, latent, sigma):
|
||||
return latent / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
class SchedulerStepEpsilonModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, noise_pred, sigma_hat, latent, dt):
|
||||
pred_original_sample = latent - sigma_hat * noise_pred
|
||||
derivative = (latent - pred_original_sample) / sigma_hat
|
||||
return latent + derivative * dt
|
||||
|
||||
class SchedulerStepSampleModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, noise_pred, sigma_hat, latent, dt):
|
||||
pred_original_sample = noise_pred
|
||||
derivative = (latent - pred_original_sample) / sigma_hat
|
||||
return latent + derivative * dt
|
||||
|
||||
class SchedulerStepVPredictionModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, noise_pred, sigma, latent, dt):
|
||||
pred_original_sample = noise_pred * (
|
||||
-sigma / (sigma**2 + 1) ** 0.5
|
||||
) + (latent / (sigma**2 + 1))
|
||||
derivative = (latent - pred_original_sample) / sigma
|
||||
return latent + derivative * dt
|
||||
|
||||
iree_flags = []
|
||||
if len(args.iree_vulkan_target_triple) > 0:
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
|
||||
def _import(self):
|
||||
scaling_model = ScalingModel()
|
||||
self.scaling_model, _ = compile_through_fx(
|
||||
model=scaling_model,
|
||||
inputs=(example_latent, example_sigma),
|
||||
extended_model_name=f"euler_scale_model_input_{self.batch_size}_{args.height}_{args.width}_{device}_"
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
|
||||
pred_type_model_dict = {
|
||||
"epsilon": SchedulerStepEpsilonModel(),
|
||||
"v_prediction": SchedulerStepVPredictionModel(),
|
||||
"sample": SchedulerStepSampleModel(),
|
||||
"original_sample": SchedulerStepSampleModel(),
|
||||
}
|
||||
step_model = pred_type_model_dict[self.config.prediction_type]
|
||||
self.step_model, _ = compile_through_fx(
|
||||
step_model,
|
||||
(example_output, example_sigma, example_latent, example_dt),
|
||||
extended_model_name=f"euler_step_{self.config.prediction_type}_{self.batch_size}_{args.height}_{args.width}_{device}_"
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
|
||||
if args.import_mlir:
|
||||
_import(self)
|
||||
|
||||
else:
|
||||
try:
|
||||
step_model_type = (
|
||||
"sample"
|
||||
if "sample" in self.config.prediction_type
|
||||
else self.config.prediction_type
|
||||
)
|
||||
self.scaling_model = get_shark_model(
|
||||
SCHEDULER_BUCKET,
|
||||
"euler_scale_model_input_" + args.precision,
|
||||
iree_flags,
|
||||
)
|
||||
self.step_model = get_shark_model(
|
||||
SCHEDULER_BUCKET,
|
||||
"euler_step_" + step_model_type + args.precision,
|
||||
iree_flags,
|
||||
)
|
||||
except:
|
||||
print(
|
||||
"failed to download model, falling back and using import_mlir"
|
||||
)
|
||||
args.import_mlir = True
|
||||
_import(self)
|
||||
|
||||
def scale_model_input(self, sample, timestep):
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
sigma = self.sigmas[self.step_index]
|
||||
return self.scaling_model(
|
||||
"forward",
|
||||
(
|
||||
sample,
|
||||
sigma,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
|
||||
def step(
|
||||
self,
|
||||
noise_pred,
|
||||
timestep,
|
||||
latent,
|
||||
s_churn: float = 0.0,
|
||||
s_tmin: float = 0.0,
|
||||
s_tmax: float = float("inf"),
|
||||
s_noise: float = 1.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: Optional[bool] = False,
|
||||
):
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
|
||||
gamma = (
|
||||
min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
|
||||
if s_tmin <= sigma <= s_tmax
|
||||
else 0.0
|
||||
)
|
||||
|
||||
sigma_hat = sigma * (gamma + 1)
|
||||
|
||||
noise_pred = (
|
||||
torch.from_numpy(noise_pred)
|
||||
if isinstance(noise_pred, np.ndarray)
|
||||
else noise_pred
|
||||
)
|
||||
|
||||
noise = randn_tensor(
|
||||
torch.Size(noise_pred.shape),
|
||||
dtype=torch.float16,
|
||||
device="cpu",
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
eps = noise * s_noise
|
||||
|
||||
if gamma > 0:
|
||||
latent = latent + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
||||
|
||||
if self.config.prediction_type == "v_prediction":
|
||||
sigma_hat = sigma
|
||||
|
||||
dt = self.sigmas[self.step_index + 1] - sigma_hat
|
||||
|
||||
self._step_index += 1
|
||||
|
||||
return self.step_model(
|
||||
"forward",
|
||||
(
|
||||
noise_pred,
|
||||
sigma_hat,
|
||||
latent,
|
||||
dt,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
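
In the Euler discrete variant above, the gamma/churn bookkeeping stays on the host and only the inner update is compiled. The host-side part amounts to the sketch below (restated from `step` above; when gamma > 0 the pipeline additionally perturbs the latent with `eps * (sigma_hat**2 - sigma**2) ** 0.5` before calling the compiled step):

```python
def churned_sigma_and_dt(sigmas, step_index, s_churn=0.0, s_tmin=0.0,
                         s_tmax=float("inf"), prediction_type="epsilon"):
    """Host-side portion of SharkEulerDiscreteScheduler.step, for clarity."""
    sigma = sigmas[step_index]
    gamma = (
        min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
        if s_tmin <= sigma <= s_tmax
        else 0.0
    )
    sigma_hat = sigma * (gamma + 1)      # "churned" sigma
    if prediction_type == "v_prediction":
        sigma_hat = sigma                # v-prediction skips the churn
    dt = sigmas[step_index + 1] - sigma_hat
    return sigma_hat, dt
```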
|
||||
@@ -1,49 +0,0 @@
|
||||
from apps.stable_diffusion.src.utils.profiler import (
|
||||
start_profiling,
|
||||
end_profiling,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.resources import (
|
||||
prompt_examples,
|
||||
models_db,
|
||||
base_models,
|
||||
opt_flags,
|
||||
resource_path,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
from apps.stable_diffusion.src.utils.stencils.stencil_utils import (
|
||||
controlnet_hint_conversion,
|
||||
controlnet_hint_reshaping,
|
||||
get_stencil_model_id,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.utils import (
|
||||
get_shark_model,
|
||||
compile_through_fx,
|
||||
set_iree_runtime_flags,
|
||||
map_device_to_name_path,
|
||||
set_init_device_flags,
|
||||
get_available_devices,
|
||||
get_opt_flags,
|
||||
preprocessCKPT,
|
||||
convert_original_vae,
|
||||
fetch_and_update_base_model_id,
|
||||
get_path_to_diffusers_checkpoint,
|
||||
sanitize_seed,
|
||||
parse_seed_input,
|
||||
batch_seeds,
|
||||
get_path_stem,
|
||||
get_extended_name,
|
||||
get_generated_imgs_path,
|
||||
get_generated_imgs_todays_subdir,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
get_generation_text_info,
|
||||
update_lora_weight,
|
||||
resize_stencil,
|
||||
_compile_module,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.civitai import get_civitai_checkpoint
|
||||
from apps.stable_diffusion.src.utils.resamplers import (
|
||||
resamplers,
|
||||
resampler_list,
|
||||
)
|
||||
@@ -1,42 +0,0 @@
import re
import requests
from apps.stable_diffusion.src.utils.stable_args import args

from pathlib import Path
from tqdm import tqdm


def get_civitai_checkpoint(url: str):
    with requests.get(url, allow_redirects=True, stream=True) as response:
        response.raise_for_status()

        # The Civitai API returns the filename in the Content-Disposition header.
        base_filename = re.findall(
            '"([^"]*)"', response.headers["Content-Disposition"]
        )[0]
        destination_path = (
            Path.cwd() / (args.ckpt_dir or "models") / base_filename
        )

        # we don't have this model downloaded yet
        if not destination_path.is_file():
            print(
                f"downloading civitai model from {url} to {destination_path}"
            )

            size = int(response.headers["content-length"], 0)
            progress_bar = tqdm(total=size, unit="iB", unit_scale=True)

            with open(destination_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=65536):
                    f.write(chunk)
                    progress_bar.update(len(chunk))

            progress_bar.close()

        # we already have this model downloaded
        else:
            print(f"civitai model already downloaded to {destination_path}")

        response.close()
        return destination_path.as_posix()
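
A hypothetical call to the removed helper above; the URL is a placeholder for a real Civitai download link, not a value taken from this commit.

```python
# Placeholder URL; substitute a real Civitai model-version download link.
ckpt_path = get_civitai_checkpoint(
    "https://civitai.com/api/download/models/<MODEL_VERSION_ID>"
)
print(ckpt_path)  # <cwd>/models/<filename taken from Content-Disposition>
```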
@@ -1,20 +0,0 @@
from apps.stable_diffusion.src.utils.stable_args import args


# Helper function to profile the vulkan device.
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
    from shark.parser import shark_args

    if shark_args.vulkan_debug_utils and "vulkan" in args.device:
        import iree

        print(f"Profiling and saving to {file_path}.")
        vulkan_device = iree.runtime.get_device(args.device)
        vulkan_device.begin_profiling(mode=profiling_mode, file_path=file_path)
        return vulkan_device
    return None


def end_profiling(device):
    if device:
        return device.end_profiling()
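
The usage pattern (as seen in the pipelines earlier in this commit) brackets a compiled-module call with the two helpers above; `unet` and `inputs` below are placeholder names standing in for a compiled SharkInference module and its argument tuple.

```python
# No-op unless Vulkan debug utils are enabled and the device is vulkan.
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = unet("forward", inputs)  # any compiled-module invocation
end_profiling(profile_device)
```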
@@ -1,12 +0,0 @@
import PIL.Image as Image

resamplers = {
    "Lanczos": Image.Resampling.LANCZOS,
    "Nearest Neighbor": Image.Resampling.NEAREST,
    "Bilinear": Image.Resampling.BILINEAR,
    "Bicubic": Image.Resampling.BICUBIC,
    "Hamming": Image.Resampling.HAMMING,
    "Box": Image.Resampling.BOX,
}

resampler_list = resamplers.keys()
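
An illustrative use of the `resamplers` table above; the file path is a placeholder.

```python
import PIL.Image as Image

# "input.png" is a placeholder path for any image on disk.
img = Image.open("input.png")
resized = img.resize((512, 512), resample=resamplers["Lanczos"])
```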
@@ -1,37 +0,0 @@
import os
import json
import sys


def resource_path(relative_path):
    """Get absolute path to resource, works for dev and for PyInstaller"""
    base_path = getattr(
        sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
    )
    return os.path.join(base_path, relative_path)


def get_json_file(path):
    json_var = []
    loc_json = resource_path(path)
    if os.path.exists(loc_json):
        with open(loc_json, encoding="utf-8") as fopen:
            json_var = json.load(fopen)

    if not json_var:
        print(f"Unable to fetch {path}")

    return json_var


# TODO: This shouldn't run at import time; every import of this file
# re-evaluates all of these globals.
prompt_examples = get_json_file("resources/prompts.json")
models_db = get_json_file("resources/model_db.json")

# base_models contains the input configuration for the different
# models and also provides information about the variants.
base_models = get_json_file("resources/base_model.json")

# Contains optimization flags for different models.
opt_flags = get_json_file("resources/opt_flags.json")
@@ -1,495 +0,0 @@
|
||||
{
|
||||
"clip": {
|
||||
"token" : {
|
||||
"shape" : [
|
||||
"2*batch_size",
|
||||
"max_len"
|
||||
],
|
||||
"dtype":"i64"
|
||||
}
|
||||
},
|
||||
"sdxl_clip": {
|
||||
"token" : {
|
||||
"shape" : [
|
||||
"1*batch_size",
|
||||
"max_len"
|
||||
],
|
||||
"dtype":"i64"
|
||||
}
|
||||
},
|
||||
"vae_encode": {
|
||||
"image" : {
|
||||
"shape" : [
|
||||
"1*batch_size",3,"8*height","8*width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"vae": {
|
||||
"latents" : {
|
||||
"shape" : [
|
||||
"1*batch_size",4,"height","width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"vae_upscaler": {
|
||||
"latents" : {
|
||||
"shape" : [
|
||||
"1*batch_size",4,"8*height","8*width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
}
|
||||
},
|
||||
"unet": {
|
||||
"stabilityai/stable-diffusion-2-1": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
4,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
1024
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 2,
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"CompVis/stable-diffusion-v1-4": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
4,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
768
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 2,
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"stabilityai/stable-diffusion-2-inpainting": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
9,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
1024
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 2,
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"runwayml/stable-diffusion-inpainting": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
9,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
768
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 2,
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"stabilityai/stable-diffusion-x4-upscaler": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
7,
|
||||
"8*height",
|
||||
"8*width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
1024
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"noise_level": {
|
||||
"shape": [2],
|
||||
"dtype": "i64"
|
||||
}
|
||||
},
|
||||
"stabilityai/sdxl-turbo": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
4,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"prompt_embeds": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
2048
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"text_embeds": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
1280
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"time_ids": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
6
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"stabilityai/stable-diffusion-xl-base-1.0": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
4,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"prompt_embeds": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
2048
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"text_embeds": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
1280
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"time_ids": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
6
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
}
|
||||
}
|
||||
},
|
||||
"stencil_adapter": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
4,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
768
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"controlnet_hint": {
|
||||
"shape": [1, 3, "8*height", "8*width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc1": {
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc2": {
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc3": {
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc4": {
|
||||
"shape": [2, 320, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc5": {
|
||||
"shape": [2, 640, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc6": {
|
||||
"shape": [2, 640, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc7": {
|
||||
"shape": [2, 640, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc8": {
|
||||
"shape": [2, 1280, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc9": {
|
||||
"shape": [2, 1280, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc10": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc11": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc12": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc13": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"stencil_unet": {
|
||||
"CompVis/stable-diffusion-v1-4": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
4,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
768
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 2,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control1": {
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control2": {
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control3": {
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control4": {
|
||||
"shape": [2, 320, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control5": {
|
||||
"shape": [2, 640, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control6": {
|
||||
"shape": [2, 640, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control7": {
|
||||
"shape": [2, 640, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control8": {
|
||||
"shape": [2, 1280, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control9": {
|
||||
"shape": [2, 1280, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control10": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control11": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control12": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control13": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale1": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale2": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale3": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale4": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale5": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale6": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale7": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale8": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale9": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale10": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale11": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale12": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale13": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
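
The shape entries in the file above use symbolic dimensions such as "2*batch_size", "max_len", and "height/2". The sketch below shows one way such strings can be resolved into concrete shapes; it is purely illustrative, the actual substitution lives in the removed model wrappers, and `resolve_shape` is a hypothetical helper, not part of this commit.

```python
def resolve_shape(shape, batch_size=1, height=64, width=64, max_len=64):
    """Illustrative resolver for the symbolic shapes in base_model.json."""
    dims = {"batch_size": batch_size, "height": height, "width": width, "max_len": max_len}
    if isinstance(shape, int):          # e.g. "shape": 2
        return (shape,)
    resolved = []
    for dim in shape:
        if isinstance(dim, str):
            # Evaluate expressions like "2*batch_size" or "height/2" against dims.
            resolved.append(int(eval(dim, {"__builtins__": {}}, dims)))
        else:
            resolved.append(dim)
    return tuple(resolved)


# Example: the SD 2.1 UNet "latents" entry for a 512x512 image (64x64 latent).
print(resolve_shape(["1*batch_size", 4, "height", "width"], batch_size=1, height=64, width=64))
# -> (1, 4, 64, 64)
```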
|
||||
@@ -1,23 +0,0 @@
[
    {
        "stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4",
        "stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base",
        "stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1",
        "stablediffusion/inpaint_v1":"runwayml/stable-diffusion-inpainting",
        "stablediffusion/inpaint_v2":"stabilityai/stable-diffusion-2-inpainting",
        "anythingv3/v1_4":"Linaqruf/anything-v3.0",
        "analogdiffusion/v1_4":"wavymulder/Analog-Diffusion",
        "openjourney/v1_4":"prompthero/openjourney",
        "dreamlike/v1_4":"dreamlike-art/dreamlike-diffusion-1.0"
    },
    {
        "stablediffusion/fp16":"fp16",
        "stablediffusion/fp32":"main",
        "anythingv3/fp16":"diffusers",
        "anythingv3/fp32":"diffusers",
        "analogdiffusion/fp16":"main",
        "analogdiffusion/fp32":"main",
        "openjourney/fp16":"main",
        "openjourney/fp32":"main"
    }
]
@@ -1,19 +0,0 @@
|
||||
[
|
||||
{
|
||||
"stablediffusion/untuned":"gs://shark_tank/nightly"
|
||||
},
|
||||
{
|
||||
"stablediffusion/v1_4/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
|
||||
"stablediffusion/v1_4/vae/fp16/length_64/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
|
||||
"stablediffusion/v1_4/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
|
||||
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
|
||||
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
|
||||
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
|
||||
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan",
|
||||
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
|
||||
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan"
|
||||
}
|
||||
]
|
||||
@@ -1,88 +0,0 @@
|
||||
{
|
||||
"unet": {
|
||||
"tuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": []
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": []
|
||||
}
|
||||
},
|
||||
"untuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"tuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [],
|
||||
"specified_compilation_flags": {
|
||||
"cuda": [],
|
||||
"default_device": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
}
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [],
|
||||
"specified_compilation_flags": {
|
||||
"cuda": [],
|
||||
"default_device": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"untuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"clip": {
|
||||
"tuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
|
||||
"--iree-opt-data-tiling=False"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
|
||||
"--iree-opt-data-tiling=False"
|
||||
]
|
||||
}
|
||||
},
|
||||
"untuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
|
||||
"--iree-opt-data-tiling=False"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
|
||||
"--iree-opt-data-tiling=False"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
[["A high tech solarpunk utopia in the Amazon rainforest"],
|
||||
["Astrophotography, the shark nebula, nebula with a tiny shark-like cloud in the middle in the middle, hubble telescope, vivid colors"],
|
||||
["A pikachu fine dining with a view to the Eiffel Tower"],
|
||||
["A mecha robot in a favela in expressionist style"],
|
||||
["an insect robot preparing a delicious meal"],
|
||||
["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"],
|
||||
["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"],
|
||||
["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"],
|
||||
["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"],
|
||||
["A photo of a beach, sunset, calm, beautiful landscape, waves, water"],
|
||||
["(a large body of water with snowy mountains in the background), (fog, foggy, rolling fog), (clouds, cloudy, rolling clouds), dramatic sky and landscape, extraordinary landscape, (beautiful snow capped mountain background), (forest, dirt path)"],
|
||||
["a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smokes coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"]]
|
||||
@@ -1,300 +0,0 @@
|
||||
import os
|
||||
import io
|
||||
from shark.model_annotation import model_annotation, create_context
|
||||
from shark.iree_utils._common import iree_target_map, run_cmd
|
||||
from shark.shark_downloader import (
|
||||
download_model,
|
||||
download_public_file,
|
||||
WORKDIR,
|
||||
)
|
||||
from shark.parser import shark_args
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
|
||||
|
||||
def get_device():
|
||||
device = (
|
||||
args.device
|
||||
if "://" not in args.device
|
||||
else args.device.split("://")[0]
|
||||
)
|
||||
return device
|
||||
|
||||
|
||||
def get_device_args():
|
||||
device = get_device()
|
||||
device_spec_args = []
|
||||
if device == "cuda":
|
||||
from shark.iree_utils.gpu_utils import get_iree_gpu_args
|
||||
|
||||
gpu_flags = get_iree_gpu_args()
|
||||
for flag in gpu_flags:
|
||||
device_spec_args.append(flag)
|
||||
elif device == "vulkan":
|
||||
device_spec_args.append(
|
||||
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
|
||||
)
|
||||
return device, device_spec_args
|
||||
|
||||
|
||||
# Download the model (Unet or VAE fp16) from shark_tank
|
||||
def load_model_from_tank():
|
||||
from apps.stable_diffusion.src.models import (
|
||||
get_params,
|
||||
get_variant_version,
|
||||
)
|
||||
|
||||
variant, version = get_variant_version(args.hf_model_id)
|
||||
|
||||
shark_args.local_tank_cache = args.local_tank_cache
|
||||
bucket_key = f"{variant}/untuned"
|
||||
if args.annotation_model == "unet":
|
||||
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/untuned"
|
||||
elif args.annotation_model == "vae":
|
||||
is_base = "/base" if args.use_base_vae else ""
|
||||
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/untuned{is_base}"
|
||||
|
||||
bucket, model_name, iree_flags = get_params(
|
||||
bucket_key, model_key, args.annotation_model, "untuned", args.precision
|
||||
)
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
model_name,
|
||||
tank_url=bucket,
|
||||
frontend="torch",
|
||||
)
|
||||
return mlir_model, model_name
|
||||
|
||||
|
||||
# Download the tuned config files from shark_tank
|
||||
def load_winograd_configs():
|
||||
device = get_device()
|
||||
config_bucket = "gs://shark_tank/sd_tuned/configs/"
|
||||
config_name = f"{args.annotation_model}_winograd_{device}.json"
|
||||
full_gs_url = config_bucket + config_name
|
||||
if not os.path.exists(WORKDIR):
|
||||
os.mkdir(WORKDIR)
|
||||
winograd_config_dir = os.path.join(WORKDIR, "configs", config_name)
|
||||
print("Loading Winograd config file from ", winograd_config_dir)
|
||||
download_public_file(full_gs_url, winograd_config_dir, True)
|
||||
return winograd_config_dir
|
||||
|
||||
|
||||
def load_lower_configs(base_model_id=None):
|
||||
from apps.stable_diffusion.src.models import get_variant_version
|
||||
from apps.stable_diffusion.src.utils.utils import (
|
||||
fetch_and_update_base_model_id,
|
||||
)
|
||||
|
||||
if not base_model_id:
|
||||
if args.ckpt_loc != "":
|
||||
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
|
||||
else:
|
||||
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
|
||||
if base_model_id == "":
|
||||
base_model_id = args.hf_model_id
|
||||
|
||||
variant, version = get_variant_version(base_model_id)
|
||||
|
||||
if version == "inpaint_v1":
|
||||
version = "v1_4"
|
||||
elif version == "inpaint_v2":
|
||||
version = "v2_1base"
|
||||
|
||||
config_bucket = "gs://shark_tank/sd_tuned_configs/"
|
||||
|
||||
device, device_spec_args = get_device_args()
|
||||
spec = ""
|
||||
if device_spec_args:
|
||||
spec = device_spec_args[-1].split("=")[-1].strip()
|
||||
if device == "vulkan":
|
||||
spec = spec.split("-")[0]
|
||||
|
||||
if args.annotation_model == "vae":
|
||||
if not spec or spec in ["sm_80"]:
|
||||
config_name = (
|
||||
f"{args.annotation_model}_{args.precision}_{device}.json"
|
||||
)
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
|
||||
else:
|
||||
if not spec or spec in ["sm_80"]:
|
||||
if (
|
||||
version in ["v2_1", "v2_1base"]
|
||||
and args.height == 768
|
||||
and args.width == 768
|
||||
):
|
||||
config_name = f"{args.annotation_model}_v2_1_768_{args.precision}_{device}.json"
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
|
||||
elif spec in ["rdna3"] and version in [
|
||||
"v2_1",
|
||||
"v2_1base",
|
||||
"v1_4",
|
||||
"v1_5",
|
||||
]:
|
||||
config_name = (
|
||||
f"{args.annotation_model}_"
|
||||
f"{version}_"
|
||||
f"{args.max_length}_"
|
||||
f"{args.precision}_"
|
||||
f"{device}_"
|
||||
f"{spec}_"
|
||||
f"{args.width}x{args.height}.json"
|
||||
)
|
||||
elif spec in ["rdna2"] and version in ["v2_1", "v2_1base", "v1_4"]:
|
||||
config_name = (
|
||||
f"{args.annotation_model}_"
|
||||
f"{version}_"
|
||||
f"{args.precision}_"
|
||||
f"{device}_"
|
||||
f"{spec}_"
|
||||
f"{args.width}x{args.height}.json"
|
||||
)
|
||||
else:
|
||||
config_name = (
|
||||
f"{args.annotation_model}_"
|
||||
f"{version}_"
|
||||
f"{args.precision}_"
|
||||
f"{device}_"
|
||||
f"{spec}.json"
|
||||
)
|
||||
|
||||
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
|
||||
print("Loading lowering config file from ", lowering_config_dir)
|
||||
full_gs_url = config_bucket + config_name
|
||||
download_public_file(full_gs_url, lowering_config_dir, True)
|
||||
return lowering_config_dir
|
||||
|
||||
|
||||
# Annotate the model with Winograd attribute on selected conv ops
|
||||
def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
|
||||
with create_context() as ctx:
|
||||
winograd_model = model_annotation(
|
||||
ctx,
|
||||
input_contents=input_mlir,
|
||||
config_path=winograd_config_dir,
|
||||
search_op="conv",
|
||||
winograd=True,
|
||||
)
|
||||
|
||||
bytecode_stream = io.BytesIO()
|
||||
winograd_model.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
if args.save_annotation:
|
||||
if model_name.split("_")[-1] != "tuned":
|
||||
out_file_path = os.path.join(
|
||||
args.annotation_output, model_name + "_tuned_torch.mlir"
|
||||
)
|
||||
else:
|
||||
out_file_path = os.path.join(
|
||||
args.annotation_output, model_name + "_torch.mlir"
|
||||
)
|
||||
with open(out_file_path, "w") as f:
|
||||
f.write(str(winograd_model))
|
||||
f.close()
|
||||
|
||||
return bytecode
|
||||
|
||||
|
||||
def dump_after_mlir(input_mlir, use_winograd):
    import iree.compiler as ireec

    device, device_spec_args = get_device_args()
    if use_winograd:
        preprocess_flag = (
            "--iree-preprocessing-pass-pipeline=builtin.module"
            "(func.func(iree-global-opt-detach-elementwise-from-named-ops,"
            "iree-global-opt-convert-1x1-filter-conv2d-to-matmul,"
            "iree-preprocessing-convert-conv2d-to-img2col,"
            "iree-preprocessing-pad-linalg-ops{pad-size=32},"
            "iree-linalg-ext-convert-conv2d-to-winograd))"
        )
    else:
        preprocess_flag = (
            "--iree-preprocessing-pass-pipeline=builtin.module"
            "(func.func(iree-global-opt-detach-elementwise-from-named-ops,"
            "iree-global-opt-convert-1x1-filter-conv2d-to-matmul,"
            "iree-preprocessing-convert-conv2d-to-img2col,"
            "iree-preprocessing-pad-linalg-ops{pad-size=32}))"
        )

    dump_module = ireec.compile_str(
        input_mlir,
        target_backends=[iree_target_map(device)],
        extra_args=device_spec_args
        + [
            preprocess_flag,
            "--compile-to=preprocessing",
        ],
    )
    return dump_module


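# Illustration (assumption, not taken from the diff): the same preprocessing-only
# compile can be reproduced with the iree-compile CLI, e.g.
#     iree-compile unet.mlir --iree-hal-target-backends=vulkan-spirv \
#         --compile-to=preprocessing \
#         "--iree-preprocessing-pass-pipeline=builtin.module(func.func(...))"
# where unet.mlir is a placeholder input and the pass list is the one built above.
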
# For Unet, annotate the model with tuned lowering configs
def annotate_with_lower_configs(
    input_mlir, lowering_config_dir, model_name, use_winograd
):
    # Dump IR after padding/img2col/winograd passes
    dump_module = dump_after_mlir(input_mlir, use_winograd)
    print("Applying tuned configs on", model_name)

    # Annotate the model with lowering configs in the config file
    with create_context() as ctx:
        tuned_model = model_annotation(
            ctx,
            input_contents=dump_module,
            config_path=lowering_config_dir,
            search_op="all",
        )

        bytecode_stream = io.BytesIO()
        tuned_model.operation.write_bytecode(bytecode_stream)
        bytecode = bytecode_stream.getvalue()

    if args.save_annotation:
        if model_name.split("_")[-1] != "tuned":
            out_file_path = (
                f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
            )
        else:
            out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
        with open(out_file_path, "w") as f:
            f.write(str(tuned_model))
            f.close()

    return bytecode


def sd_model_annotation(mlir_model, model_name, base_model_id=None):
    device = get_device()
    if args.annotation_model == "unet" and device == "vulkan":
        use_winograd = True
        winograd_config_dir = load_winograd_configs()
        winograd_model = annotate_with_winograd(
            mlir_model, winograd_config_dir, model_name
        )
        lowering_config_dir = load_lower_configs(base_model_id)
        tuned_model = annotate_with_lower_configs(
            winograd_model, lowering_config_dir, model_name, use_winograd
        )
    elif args.annotation_model == "vae" and device == "vulkan":
        if "rdna2" not in args.iree_vulkan_target_triple.split("-")[0]:
            use_winograd = True
            winograd_config_dir = load_winograd_configs()
            tuned_model = annotate_with_winograd(
                mlir_model, winograd_config_dir, model_name
            )
        else:
            tuned_model = mlir_model
    else:
        use_winograd = False
        lowering_config_dir = load_lower_configs(base_model_id)
        tuned_model = annotate_with_lower_configs(
            mlir_model, lowering_config_dir, model_name, use_winograd
        )
    return tuned_model


if __name__ == "__main__":
    mlir_model, model_name = load_model_from_tank()
    sd_model_annotation(mlir_model, model_name)
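
# Usage sketch (not part of the original file): mirrors the __main__ entry
# point above but also persists the tuned bytecode that sd_model_annotation
# returns on the unet/default paths. The output file name is an assumption.
def _example_annotate_and_save():
    mlir_model, model_name = load_model_from_tank()
    tuned = sd_model_annotation(mlir_model, model_name)
    out_path = os.path.join(args.annotation_output, model_name + "_tuned.mlirbc")
    with open(out_path, "wb") as f:
        f.write(tuned)
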
@@ -1,771 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from apps.stable_diffusion.src.utils.resamplers import resampler_list
|
||||
|
||||
|
||||
def path_expand(s):
|
||||
return Path(s).expanduser().resolve()
|
||||
|
||||
|
||||
def is_valid_file(arg):
|
||||
if not os.path.exists(arg):
|
||||
return None
|
||||
else:
|
||||
return arg
|
||||
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# Stable Diffusion Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"-a",
|
||||
"--app",
|
||||
default="txt2img",
|
||||
help="Which app to use, one of: txt2img, img2img, outpaint, inpaint.",
|
||||
)
|
||||
p.add_argument(
|
||||
"-p",
|
||||
"--prompts",
|
||||
nargs="+",
|
||||
default=[
|
||||
"a photo taken of the front of a super-car drifting on a road near "
|
||||
"mountains at high speeds with smokes coming off the tires, front "
|
||||
"angle, front point of view, trees in the mountains of the "
|
||||
"background, ((sharp focus))"
|
||||
],
|
||||
help="Text of which images to be generated.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--negative_prompts",
|
||||
nargs="+",
|
||||
default=[
|
||||
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), "
|
||||
"blurry, ugly, blur, oversaturated, cropped"
|
||||
],
|
||||
help="Text you don't want to see in the generated image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--img_path",
|
||||
type=str,
|
||||
help="Path to the image input for img2img/inpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="The number of steps to do the sampling.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--seed",
|
||||
type=str,
|
||||
default=-1,
|
||||
help="The seed or list of seeds to use. -1 for a random one.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
choices=range(1, 4),
|
||||
help="The number of inferences to be made in a single `batch_count`.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--height",
|
||||
type=int,
|
||||
default=512,
|
||||
choices=range(128, 1025, 8),
|
||||
help="The height of the output image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--width",
|
||||
type=int,
|
||||
default=512,
|
||||
choices=range(128, 1025, 8),
|
||||
help="The width of the output image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--guidance_scale",
|
||||
type=float,
|
||||
default=7.5,
|
||||
help="The value to be used for guidance scaling.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--noise_level",
|
||||
type=int,
|
||||
default=20,
|
||||
help="The value to be used for noise level of upscaler.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--max_length",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Max length of the tokenizer output, options are 64 and 77.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--max_embeddings_multiples",
|
||||
type=int,
|
||||
default=5,
|
||||
help="The max multiple length of prompt embeddings compared to the max "
|
||||
"output length of text encoder.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--strength",
|
||||
type=float,
|
||||
default=0.8,
|
||||
help="The strength of change applied on the given input image for "
|
||||
"img2img.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_hiresfix",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Use Hires Fix to do higher resolution images, while trying to "
|
||||
"avoid the issues that come with it. This is accomplished by first "
|
||||
"generating an image using txt2img, then running it through img2img.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hiresfix_height",
|
||||
type=int,
|
||||
default=768,
|
||||
choices=range(128, 769, 8),
|
||||
help="The height of the Hires Fix image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hiresfix_width",
|
||||
type=int,
|
||||
default=768,
|
||||
choices=range(128, 769, 8),
|
||||
help="The width of the Hires Fix image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hiresfix_strength",
|
||||
type=float,
|
||||
default=0.6,
|
||||
help="The denoising strength to apply for the Hires Fix.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--resample_type",
|
||||
type=str,
|
||||
default="Nearest Neighbor",
|
||||
choices=resampler_list,
|
||||
help="The resample type to use when resizing an image before being run "
|
||||
"through stable diffusion.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# Stable Diffusion Training Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--lora_save_dir",
|
||||
type=str,
|
||||
default="models/lora/",
|
||||
help="Directory to save the lora fine tuned model.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--training_images_dir",
|
||||
type=str,
|
||||
default="models/lora/training_images/",
|
||||
help="Directory containing images that are an example of the prompt.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--training_steps",
|
||||
type=int,
|
||||
default=2000,
|
||||
help="The number of steps to train.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# Inpainting and Outpainting Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--mask_path",
|
||||
type=str,
|
||||
help="Path to the mask image input for inpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--inpaint_full_res",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If inpaint only masked area or whole picture.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--inpaint_full_res_padding",
|
||||
type=int,
|
||||
default=32,
|
||||
choices=range(0, 257, 4),
|
||||
help="Number of pixels for only masked padding.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--pixels",
|
||||
type=int,
|
||||
default=128,
|
||||
choices=range(8, 257, 8),
|
||||
help="Number of expended pixels for one direction for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--mask_blur",
|
||||
type=int,
|
||||
default=8,
|
||||
choices=range(0, 65),
|
||||
help="Number of blur pixels for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--left",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If extend left for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--right",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If extend right for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--up",
|
||||
"--top",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If extend top for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--down",
|
||||
"--bottom",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If extend bottom for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--noise_q",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Fall-off exponent for outpainting (lower=higher detail) "
|
||||
"(min=0.0, max=4.0).",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--color_variation",
|
||||
type=float,
|
||||
default=0.05,
|
||||
help="Color variation for outpainting (min=0.0, max=1.0).",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# Model Config and Usage Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--device", type=str, default="vulkan", help="Device to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--precision", type=str, default="fp16", help="Precision to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--import_mlir",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Imports the model from torch module to shark_module otherwise "
|
||||
"downloads the model from shark_tank.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--load_vmfb",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Attempts to load the model from a precompiled flat-buffer "
|
||||
"and compiles + saves it if not found.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--save_vmfb",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Saves the compiled flat-buffer to the local directory.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_tuned",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Download and use the tuned version of the model if available.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_base_vae",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Do conversion from the VAE output to pixel space on cpu.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--scheduler",
|
||||
type=str,
|
||||
default="SharkEulerDiscrete",
|
||||
help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, "
|
||||
"DPMSolverMultistep, DPMSolverMultistep++, DPMSolverMultistepKarras, "
|
||||
"DPMSolverMultistepKarras++, EulerDiscrete, EulerAncestralDiscrete, "
|
||||
"DEISMultistep, KDPM2AncestralDiscrete, DPMSolverSinglestep, DDPM, "
|
||||
"HeunDiscrete].",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_img_format",
|
||||
type=str,
|
||||
default="png",
|
||||
help="Specify the format in which output image is save. "
|
||||
"Supported options: jpg / png.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory path to save the output images and json.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--batch_count",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of batches to be generated with random seeds in "
|
||||
"single execution.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--repeatable_seeds",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="The seed of the first batch will be used as the rng seed to "
|
||||
"generate the subsequent seeds for subsequent batches in that run.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--ckpt_loc",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to SD's .ckpt file.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--custom_vae",
|
||||
type=str,
|
||||
default="",
|
||||
help="HuggingFace repo-id or path to SD model's checkpoint whose VAE "
|
||||
"needs to be plugged in.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hf_model_id",
|
||||
type=str,
|
||||
default="stabilityai/stable-diffusion-2-1-base",
|
||||
help="The repo-id of hugging face.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--low_cpu_mem_usage",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Use the accelerate package to reduce cpu memory consumption.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--attention_slicing",
|
||||
type=str,
|
||||
default="none",
|
||||
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', "
|
||||
"or an integer).",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_stencil",
|
||||
choices=["canny", "openpose", "scribble", "zoedepth"],
|
||||
help="Enable the stencil feature.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--control_mode",
|
||||
choices=["Prompt", "Balanced", "Controlnet"],
|
||||
default="Balanced",
|
||||
help="How Controlnet injection should be prioritized.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_lora",
|
||||
type=str,
|
||||
default="",
|
||||
help="Use standalone LoRA weight using a HF ID or a checkpoint "
|
||||
"file (~3 MB).",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_quantize",
|
||||
type=str,
|
||||
default="none",
|
||||
help="Runs the quantized version of stable diffusion model. "
|
||||
"This is currently in experimental phase. "
|
||||
"Currently, only runs the stable-diffusion-2-1-base model in "
|
||||
"int8 quantization.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--ondemand",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Load and unload models for low VRAM.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hf_auth_token",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specify your own huggingface authentication tokens for models like Llama2.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--device_allocator_heap_key",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify heap key for device caching allocator."
|
||||
"Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count"
|
||||
"Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--autogen",
|
||||
type=bool,
|
||||
default="False",
|
||||
help="Only used for a gradio workaround.",
|
||||
)
|
||||
##############################################################################
|
||||
# IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--iree_vulkan_target_triple",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify target triple for vulkan.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--iree_metal_target_platform",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify target triple for metal.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# Misc. Debug and Optimization flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--use_compiled_scheduler",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Use the default scheduler precompiled into the model if available.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--local_tank_cache",
|
||||
default="",
|
||||
help="Specify where to save downloaded shark_tank artifacts. "
|
||||
"If this is not set, the default is ~/.local/shark_tank/.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dump_isa",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="When enabled call amdllpc to get ISA dumps. "
|
||||
"Use with dispatch benchmarks.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dispatch_benchmarks",
|
||||
default=None,
|
||||
help="Dispatches to return benchmark data on. "
|
||||
'Use "All" for all, and None for none.',
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dispatch_benchmarks_dir",
|
||||
default="temp_dispatch_benchmarks",
|
||||
help="Directory where you want to store dispatch data "
|
||||
'generated with "--dispatch_benchmarks".',
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--enable_rgp",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for inserting debug frames between iterations "
|
||||
"for use with rgp.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hide_steps",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for hiding the details of iteration/sec for each step.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--warmup_count",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Flag setting warmup count for CLIP and VAE [>= 0].",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--clear_all",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag to clear all mlir and vmfb from common locations. "
|
||||
"Recompiling will take several minutes.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--save_metadata_to_json",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for whether or not to save a generation information "
|
||||
"json file with the image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--write_metadata_to_png",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for whether or not to save generation information in "
|
||||
"PNG chunk text to generated images.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--import_debug",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If import_mlir is True, saves mlir via the debug option "
|
||||
"in shark importer. Does nothing if import_mlir is false (the default).",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--compile_debug",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag to toggle debug assert/verify flags for imported IR in the"
|
||||
"iree-compiler. Default to false.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--iree_constant_folding",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Controls constant folding in iree-compile for all SD models.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--data_tiling",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Controls data tiling in iree-compile for all SD models.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# Web UI flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--progress_bar",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for removing the progress bar animation during "
|
||||
"image generation.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--ckpt_dir",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to directory where all .ckpts are stored in order to populate "
|
||||
"them in the web UI.",
|
||||
)
|
||||
# TODO: replace API flag when these can be run together
|
||||
p.add_argument(
|
||||
"--ui",
|
||||
type=str,
|
||||
default="app" if os.name == "nt" else "web",
|
||||
help="One of: [api, app, web].",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--share",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for generating a public URL.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--server_port",
|
||||
type=int,
|
||||
default=8080,
|
||||
help="Flag for setting server port.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--api",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for enabling rest API.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--api_accept_origin",
|
||||
action="append",
|
||||
type=str,
|
||||
help="An origin to be accepted by the REST api for Cross Origin"
|
||||
"Resource Sharing (CORS). Use multiple times for multiple origins, "
|
||||
'or use --api_accept_origin="*" to accept all origins. If no origins '
|
||||
"are set no CORS headers will be returned by the api. Use, for "
|
||||
"instance, if you need to access the REST api from Javascript running "
|
||||
"in a web browser.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--debug",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for enabling debugging log in WebUI.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_gallery",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for removing the output gallery tab, and avoid exposing "
|
||||
"images under --output_dir in the UI.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_gallery_followlinks",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for whether the output gallery tab in the UI should "
|
||||
"follow symlinks when listing subdirectories under --output_dir.",
|
||||
)
|
||||
|
||||
|
||||
##############################################################################
|
||||
# SD model auto-annotation flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--annotation_output",
|
||||
type=path_expand,
|
||||
default="./",
|
||||
help="Directory to save the annotated mlir file.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--annotation_model",
|
||||
type=str,
|
||||
default="unet",
|
||||
help="Options are unet and vae.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--save_annotation",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Save annotated mlir file.",
|
||||
)
|
||||
##############################################################################
|
||||
# SD model auto-tuner flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--tuned_config_dir",
|
||||
type=path_expand,
|
||||
default="./",
|
||||
help="Directory to save the tuned config file.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--num_iters",
|
||||
type=int,
|
||||
default=400,
|
||||
help="Number of iterations for tuning.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--search_op",
|
||||
type=str,
|
||||
default="all",
|
||||
help="Op to be optimized, options are matmul, bmm, conv and all.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# DocuChat Flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--run_docuchat_web",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Specifies whether the docuchat's web version is running or not.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# rocm Flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--iree_rocm_target_chip",
|
||||
type=str,
|
||||
default="",
|
||||
help="Add the rocm device architecture ex gfx1100, gfx90a, etc. Use `hipinfo` "
|
||||
"or `iree-run-module --dump_devices=rocm` or `hipinfo` to get desired arch name",
|
||||
)
|
||||
|
||||
args, unknown = p.parse_known_args()
|
||||
if args.import_debug:
|
||||
os.environ["IREE_SAVE_TEMPS"] = os.path.join(
|
||||
os.getcwd(), args.hf_model_id.replace("/", "_")
|
||||
)
|
||||
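# Illustration (not part of the original file): parse_known_args() is used so
# that extra flags injected by launchers do not trip argparse. A hypothetical
#     p.parse_known_args(["--precision", "fp32", "--some_ui_flag", "1"])
# yields args.precision == "fp32" and unknown == ["--some_ui_flag", "1"].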
@@ -1,3 +0,0 @@
from apps.stable_diffusion.src.utils.stencils.canny import CannyDetector
from apps.stable_diffusion.src.utils.stencils.openpose import OpenposeDetector
from apps.stable_diffusion.src.utils.stencils.zoe import ZoeDetector
@@ -1,6 +0,0 @@
import cv2


class CannyDetector:
    def __call__(self, img, low_threshold, high_threshold):
        return cv2.Canny(img, low_threshold, high_threshold)
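
# Usage sketch (illustrative, not part of the original file): the detector is a
# thin wrapper over cv2.Canny, taking a uint8 image plus the two hysteresis
# thresholds; "sample.png" is a placeholder path and cv2 is imported above.
edges = CannyDetector()(cv2.imread("sample.png"), low_threshold=100, high_threshold=200)
cv2.imwrite("sample_edges.png", edges)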
@@ -1,62 +0,0 @@
import requests
from pathlib import Path

import torch
import numpy as np

# from annotator.util import annotator_ckpts_path
from apps.stable_diffusion.src.utils.stencils.openpose.body import Body
from apps.stable_diffusion.src.utils.stencils.openpose.hand import Hand
from apps.stable_diffusion.src.utils.stencils.openpose.openpose_util import (
    draw_bodypose,
    draw_handpose,
    handDetect,
)


body_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/body_pose_model.pth"
hand_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/hand_pose_model.pth"


class OpenposeDetector:
    def __init__(self):
        cwd = Path.cwd()
        ckpt_path = Path(cwd, "stencil_annotator")
        ckpt_path.mkdir(parents=True, exist_ok=True)
        body_modelpath = ckpt_path / "body_pose_model.pth"
        hand_modelpath = ckpt_path / "hand_pose_model.pth"

        if not body_modelpath.is_file():
            r = requests.get(body_model_path, allow_redirects=True)
            open(body_modelpath, "wb").write(r.content)
        if not hand_modelpath.is_file():
            r = requests.get(hand_model_path, allow_redirects=True)
            open(hand_modelpath, "wb").write(r.content)

        self.body_estimation = Body(body_modelpath)
        self.hand_estimation = Hand(hand_modelpath)

    def __call__(self, oriImg, hand=False):
        oriImg = oriImg[:, :, ::-1].copy()
        with torch.no_grad():
            candidate, subset = self.body_estimation(oriImg)
            canvas = np.zeros_like(oriImg)
            canvas = draw_bodypose(canvas, candidate, subset)
            if hand:
                hands_list = handDetect(candidate, subset, oriImg)
                all_hand_peaks = []
                for x, y, w, is_left in hands_list:
                    peaks = self.hand_estimation(
                        oriImg[y : y + w, x : x + w, :]
                    )
                    peaks[:, 0] = np.where(
                        peaks[:, 0] == 0, peaks[:, 0], peaks[:, 0] + x
                    )
                    peaks[:, 1] = np.where(
                        peaks[:, 1] == 0, peaks[:, 1], peaks[:, 1] + y
                    )
                    all_hand_peaks.append(peaks)
                canvas = draw_handpose(canvas, all_hand_peaks)
            return canvas, dict(
                candidate=candidate.tolist(), subset=subset.tolist()
            )
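
# Usage sketch (illustrative, not part of the original file): the detector takes
# an HWC uint8 image (RGB order, flipped to BGR internally) and returns the
# rendered pose canvas plus the raw keypoints; "person.png" is a placeholder.
from PIL import Image

image = np.array(Image.open("person.png").convert("RGB"))
pose_canvas, keypoints = OpenposeDetector()(image, hand=True)
Image.fromarray(pose_canvas).save("person_pose.png")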
@@ -1,499 +0,0 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import math
|
||||
from scipy.ndimage.filters import gaussian_filter
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from collections import OrderedDict
|
||||
from apps.stable_diffusion.src.utils.stencils.openpose.openpose_util import (
|
||||
make_layers,
|
||||
transfer,
|
||||
padRightDownCorner,
|
||||
)
|
||||
|
||||
|
||||
class BodyPoseModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(BodyPoseModel, self).__init__()
|
||||
|
||||
# these layers have no relu layer
|
||||
no_relu_layers = [
|
||||
"conv5_5_CPM_L1",
|
||||
"conv5_5_CPM_L2",
|
||||
"Mconv7_stage2_L1",
|
||||
"Mconv7_stage2_L2",
|
||||
"Mconv7_stage3_L1",
|
||||
"Mconv7_stage3_L2",
|
||||
"Mconv7_stage4_L1",
|
||||
"Mconv7_stage4_L2",
|
||||
"Mconv7_stage5_L1",
|
||||
"Mconv7_stage5_L2",
|
||||
"Mconv7_stage6_L1",
|
||||
"Mconv7_stage6_L1",
|
||||
]
|
||||
blocks = {}
|
||||
block0 = OrderedDict(
|
||||
[
|
||||
("conv1_1", [3, 64, 3, 1, 1]),
|
||||
("conv1_2", [64, 64, 3, 1, 1]),
|
||||
("pool1_stage1", [2, 2, 0]),
|
||||
("conv2_1", [64, 128, 3, 1, 1]),
|
||||
("conv2_2", [128, 128, 3, 1, 1]),
|
||||
("pool2_stage1", [2, 2, 0]),
|
||||
("conv3_1", [128, 256, 3, 1, 1]),
|
||||
("conv3_2", [256, 256, 3, 1, 1]),
|
||||
("conv3_3", [256, 256, 3, 1, 1]),
|
||||
("conv3_4", [256, 256, 3, 1, 1]),
|
||||
("pool3_stage1", [2, 2, 0]),
|
||||
("conv4_1", [256, 512, 3, 1, 1]),
|
||||
("conv4_2", [512, 512, 3, 1, 1]),
|
||||
("conv4_3_CPM", [512, 256, 3, 1, 1]),
|
||||
("conv4_4_CPM", [256, 128, 3, 1, 1]),
|
||||
]
|
||||
)
|
||||
|
||||
# Stage 1
|
||||
block1_1 = OrderedDict(
|
||||
[
|
||||
("conv5_1_CPM_L1", [128, 128, 3, 1, 1]),
|
||||
("conv5_2_CPM_L1", [128, 128, 3, 1, 1]),
|
||||
("conv5_3_CPM_L1", [128, 128, 3, 1, 1]),
|
||||
("conv5_4_CPM_L1", [128, 512, 1, 1, 0]),
|
||||
("conv5_5_CPM_L1", [512, 38, 1, 1, 0]),
|
||||
]
|
||||
)
|
||||
|
||||
block1_2 = OrderedDict(
|
||||
[
|
||||
("conv5_1_CPM_L2", [128, 128, 3, 1, 1]),
|
||||
("conv5_2_CPM_L2", [128, 128, 3, 1, 1]),
|
||||
("conv5_3_CPM_L2", [128, 128, 3, 1, 1]),
|
||||
("conv5_4_CPM_L2", [128, 512, 1, 1, 0]),
|
||||
("conv5_5_CPM_L2", [512, 19, 1, 1, 0]),
|
||||
]
|
||||
)
|
||||
blocks["block1_1"] = block1_1
|
||||
blocks["block1_2"] = block1_2
|
||||
|
||||
self.model0 = make_layers(block0, no_relu_layers)
|
||||
|
||||
# Stages 2 - 6
|
||||
for i in range(2, 7):
|
||||
blocks["block%d_1" % i] = OrderedDict(
|
||||
[
|
||||
("Mconv1_stage%d_L1" % i, [185, 128, 7, 1, 3]),
|
||||
("Mconv2_stage%d_L1" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv3_stage%d_L1" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv4_stage%d_L1" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv5_stage%d_L1" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv6_stage%d_L1" % i, [128, 128, 1, 1, 0]),
|
||||
("Mconv7_stage%d_L1" % i, [128, 38, 1, 1, 0]),
|
||||
]
|
||||
)
|
||||
|
||||
blocks["block%d_2" % i] = OrderedDict(
|
||||
[
|
||||
("Mconv1_stage%d_L2" % i, [185, 128, 7, 1, 3]),
|
||||
("Mconv2_stage%d_L2" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv3_stage%d_L2" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv4_stage%d_L2" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv5_stage%d_L2" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv6_stage%d_L2" % i, [128, 128, 1, 1, 0]),
|
||||
("Mconv7_stage%d_L2" % i, [128, 19, 1, 1, 0]),
|
||||
]
|
||||
)
|
||||
|
||||
for k in blocks.keys():
|
||||
blocks[k] = make_layers(blocks[k], no_relu_layers)
|
||||
|
||||
self.model1_1 = blocks["block1_1"]
|
||||
self.model2_1 = blocks["block2_1"]
|
||||
self.model3_1 = blocks["block3_1"]
|
||||
self.model4_1 = blocks["block4_1"]
|
||||
self.model5_1 = blocks["block5_1"]
|
||||
self.model6_1 = blocks["block6_1"]
|
||||
|
||||
self.model1_2 = blocks["block1_2"]
|
||||
self.model2_2 = blocks["block2_2"]
|
||||
self.model3_2 = blocks["block3_2"]
|
||||
self.model4_2 = blocks["block4_2"]
|
||||
self.model5_2 = blocks["block5_2"]
|
||||
self.model6_2 = blocks["block6_2"]
|
||||
|
||||
def forward(self, x):
|
||||
out1 = self.model0(x)
|
||||
|
||||
out1_1 = self.model1_1(out1)
|
||||
out1_2 = self.model1_2(out1)
|
||||
out2 = torch.cat([out1_1, out1_2, out1], 1)
|
||||
|
||||
out2_1 = self.model2_1(out2)
|
||||
out2_2 = self.model2_2(out2)
|
||||
out3 = torch.cat([out2_1, out2_2, out1], 1)
|
||||
|
||||
out3_1 = self.model3_1(out3)
|
||||
out3_2 = self.model3_2(out3)
|
||||
out4 = torch.cat([out3_1, out3_2, out1], 1)
|
||||
|
||||
out4_1 = self.model4_1(out4)
|
||||
out4_2 = self.model4_2(out4)
|
||||
out5 = torch.cat([out4_1, out4_2, out1], 1)
|
||||
|
||||
out5_1 = self.model5_1(out5)
|
||||
out5_2 = self.model5_2(out5)
|
||||
out6 = torch.cat([out5_1, out5_2, out1], 1)
|
||||
|
||||
out6_1 = self.model6_1(out6)
|
||||
out6_2 = self.model6_2(out6)
|
||||
|
||||
return out6_1, out6_2
|
||||
|
||||
|
||||
class Body(object):
|
||||
def __init__(self, model_path):
|
||||
self.model = BodyPoseModel()
|
||||
if torch.cuda.is_available():
|
||||
self.model = self.model.cuda()
|
||||
model_dict = transfer(self.model, torch.load(model_path))
|
||||
self.model.load_state_dict(model_dict)
|
||||
self.model.eval()
|
||||
|
||||
def __call__(self, oriImg):
|
||||
scale_search = [0.5]
|
||||
boxsize = 368
|
||||
stride = 8
|
||||
padValue = 128
|
||||
thre1 = 0.1
|
||||
thre2 = 0.05
|
||||
multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
|
||||
heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
|
||||
paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
|
||||
|
||||
for m in range(len(multiplier)):
|
||||
scale = multiplier[m]
|
||||
imageToTest = cv2.resize(
|
||||
oriImg,
|
||||
(0, 0),
|
||||
fx=scale,
|
||||
fy=scale,
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
imageToTest_padded, pad = padRightDownCorner(
|
||||
imageToTest, stride, padValue
|
||||
)
|
||||
im = (
|
||||
np.transpose(
|
||||
np.float32(imageToTest_padded[:, :, :, np.newaxis]),
|
||||
(3, 2, 0, 1),
|
||||
)
|
||||
/ 256
|
||||
- 0.5
|
||||
)
|
||||
im = np.ascontiguousarray(im)
|
||||
|
||||
data = torch.from_numpy(im).float()
|
||||
if torch.cuda.is_available():
|
||||
data = data.cuda()
|
||||
with torch.no_grad():
|
||||
Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
|
||||
Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
|
||||
Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
|
||||
|
||||
# extract outputs, resize, and remove padding
|
||||
heatmap = np.transpose(
|
||||
np.squeeze(Mconv7_stage6_L2), (1, 2, 0)
|
||||
) # output 1 is heatmaps
|
||||
heatmap = cv2.resize(
|
||||
heatmap,
|
||||
(0, 0),
|
||||
fx=stride,
|
||||
fy=stride,
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
heatmap = heatmap[
|
||||
: imageToTest_padded.shape[0] - pad[2],
|
||||
: imageToTest_padded.shape[1] - pad[3],
|
||||
:,
|
||||
]
|
||||
heatmap = cv2.resize(
|
||||
heatmap,
|
||||
(oriImg.shape[1], oriImg.shape[0]),
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
|
||||
# paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs
|
||||
paf = np.transpose(
|
||||
np.squeeze(Mconv7_stage6_L1), (1, 2, 0)
|
||||
) # output 0 is PAFs
|
||||
paf = cv2.resize(
|
||||
paf,
|
||||
(0, 0),
|
||||
fx=stride,
|
||||
fy=stride,
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
paf = paf[
|
||||
: imageToTest_padded.shape[0] - pad[2],
|
||||
: imageToTest_padded.shape[1] - pad[3],
|
||||
:,
|
||||
]
|
||||
paf = cv2.resize(
|
||||
paf,
|
||||
(oriImg.shape[1], oriImg.shape[0]),
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
|
||||
heatmap_avg += heatmap_avg + heatmap / len(multiplier)
|
||||
paf_avg += +paf / len(multiplier)
|
||||
|
||||
all_peaks = []
|
||||
peak_counter = 0
|
||||
|
||||
for part in range(18):
|
||||
map_ori = heatmap_avg[:, :, part]
|
||||
one_heatmap = gaussian_filter(map_ori, sigma=3)
|
||||
|
||||
map_left = np.zeros(one_heatmap.shape)
|
||||
map_left[1:, :] = one_heatmap[:-1, :]
|
||||
map_right = np.zeros(one_heatmap.shape)
|
||||
map_right[:-1, :] = one_heatmap[1:, :]
|
||||
map_up = np.zeros(one_heatmap.shape)
|
||||
map_up[:, 1:] = one_heatmap[:, :-1]
|
||||
map_down = np.zeros(one_heatmap.shape)
|
||||
map_down[:, :-1] = one_heatmap[:, 1:]
|
||||
|
||||
peaks_binary = np.logical_and.reduce(
|
||||
(
|
||||
one_heatmap >= map_left,
|
||||
one_heatmap >= map_right,
|
||||
one_heatmap >= map_up,
|
||||
one_heatmap >= map_down,
|
||||
one_heatmap > thre1,
|
||||
)
|
||||
)
|
||||
peaks = list(
|
||||
zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])
|
||||
) # note reverse
|
||||
peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
|
||||
peak_id = range(peak_counter, peak_counter + len(peaks))
|
||||
peaks_with_score_and_id = [
|
||||
peaks_with_score[i] + (peak_id[i],)
|
||||
for i in range(len(peak_id))
|
||||
]
|
||||
|
||||
all_peaks.append(peaks_with_score_and_id)
|
||||
peak_counter += len(peaks)
|
||||
|
||||
# find connection in the specified sequence, center 29 is in the position 15
|
||||
limbSeq = [
|
||||
[2, 3],
|
||||
[2, 6],
|
||||
[3, 4],
|
||||
[4, 5],
|
||||
[6, 7],
|
||||
[7, 8],
|
||||
[2, 9],
|
||||
[9, 10],
|
||||
[10, 11],
|
||||
[2, 12],
|
||||
[12, 13],
|
||||
[13, 14],
|
||||
[2, 1],
|
||||
[1, 15],
|
||||
[15, 17],
|
||||
[1, 16],
|
||||
[16, 18],
|
||||
[3, 17],
|
||||
[6, 18],
|
||||
]
|
||||
# the middle joints heatmap correpondence
|
||||
mapIdx = [
|
||||
[31, 32],
|
||||
[39, 40],
|
||||
[33, 34],
|
||||
[35, 36],
|
||||
[41, 42],
|
||||
[43, 44],
|
||||
[19, 20],
|
||||
[21, 22],
|
||||
[23, 24],
|
||||
[25, 26],
|
||||
[27, 28],
|
||||
[29, 30],
|
||||
[47, 48],
|
||||
[49, 50],
|
||||
[53, 54],
|
||||
[51, 52],
|
||||
[55, 56],
|
||||
[37, 38],
|
||||
[45, 46],
|
||||
]
|
||||
|
||||
connection_all = []
|
||||
special_k = []
|
||||
mid_num = 10
|
||||
|
||||
for k in range(len(mapIdx)):
|
||||
score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
|
||||
candA = all_peaks[limbSeq[k][0] - 1]
|
||||
candB = all_peaks[limbSeq[k][1] - 1]
|
||||
nA = len(candA)
|
||||
nB = len(candB)
|
||||
indexA, indexB = limbSeq[k]
|
||||
if nA != 0 and nB != 0:
|
||||
connection_candidate = []
|
||||
for i in range(nA):
|
||||
for j in range(nB):
|
||||
vec = np.subtract(candB[j][:2], candA[i][:2])
|
||||
norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
|
||||
norm = max(0.001, norm)
|
||||
vec = np.divide(vec, norm)
|
||||
|
||||
startend = list(
|
||||
zip(
|
||||
np.linspace(
|
||||
candA[i][0], candB[j][0], num=mid_num
|
||||
),
|
||||
np.linspace(
|
||||
candA[i][1], candB[j][1], num=mid_num
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
vec_x = np.array(
|
||||
[
|
||||
score_mid[
|
||||
int(round(startend[I][1])),
|
||||
int(round(startend[I][0])),
|
||||
0,
|
||||
]
|
||||
for I in range(len(startend))
|
||||
]
|
||||
)
|
||||
vec_y = np.array(
|
||||
[
|
||||
score_mid[
|
||||
int(round(startend[I][1])),
|
||||
int(round(startend[I][0])),
|
||||
1,
|
||||
]
|
||||
for I in range(len(startend))
|
||||
]
|
||||
)
|
||||
|
||||
score_midpts = np.multiply(
|
||||
vec_x, vec[0]
|
||||
) + np.multiply(vec_y, vec[1])
|
||||
score_with_dist_prior = sum(score_midpts) / len(
|
||||
score_midpts
|
||||
) + min(0.5 * oriImg.shape[0] / norm - 1, 0)
|
||||
criterion1 = len(
|
||||
np.nonzero(score_midpts > thre2)[0]
|
||||
) > 0.8 * len(score_midpts)
|
||||
criterion2 = score_with_dist_prior > 0
|
||||
if criterion1 and criterion2:
|
||||
connection_candidate.append(
|
||||
[
|
||||
i,
|
||||
j,
|
||||
score_with_dist_prior,
|
||||
score_with_dist_prior
|
||||
+ candA[i][2]
|
||||
+ candB[j][2],
|
||||
]
|
||||
)
|
||||
|
||||
connection_candidate = sorted(
|
||||
connection_candidate, key=lambda x: x[2], reverse=True
|
||||
)
|
||||
connection = np.zeros((0, 5))
|
||||
for c in range(len(connection_candidate)):
|
||||
i, j, s = connection_candidate[c][0:3]
|
||||
if i not in connection[:, 3] and j not in connection[:, 4]:
|
||||
connection = np.vstack(
|
||||
[connection, [candA[i][3], candB[j][3], s, i, j]]
|
||||
)
|
||||
if len(connection) >= min(nA, nB):
|
||||
break
|
||||
|
||||
connection_all.append(connection)
|
||||
else:
|
||||
special_k.append(k)
|
||||
connection_all.append([])
|
||||
|
||||
# last number in each row is the total parts number of that person
|
||||
# the second last number in each row is the score of the overall configuration
|
||||
subset = -1 * np.ones((0, 20))
|
||||
candidate = np.array(
|
||||
[item for sublist in all_peaks for item in sublist]
|
||||
)
|
||||
|
||||
for k in range(len(mapIdx)):
|
||||
if k not in special_k:
|
||||
partAs = connection_all[k][:, 0]
|
||||
partBs = connection_all[k][:, 1]
|
||||
indexA, indexB = np.array(limbSeq[k]) - 1
|
||||
|
||||
for i in range(len(connection_all[k])): # = 1:size(temp,1)
|
||||
found = 0
|
||||
subset_idx = [-1, -1]
|
||||
for j in range(len(subset)): # 1:size(subset,1):
|
||||
if (
|
||||
subset[j][indexA] == partAs[i]
|
||||
or subset[j][indexB] == partBs[i]
|
||||
):
|
||||
subset_idx[found] = j
|
||||
found += 1
|
||||
|
||||
if found == 1:
|
||||
j = subset_idx[0]
|
||||
if subset[j][indexB] != partBs[i]:
|
||||
subset[j][indexB] = partBs[i]
|
||||
subset[j][-1] += 1
|
||||
subset[j][-2] += (
|
||||
candidate[partBs[i].astype(int), 2]
|
||||
+ connection_all[k][i][2]
|
||||
)
|
||||
elif found == 2: # if found 2 and disjoint, merge them
|
||||
j1, j2 = subset_idx
|
||||
membership = (
|
||||
(subset[j1] >= 0).astype(int)
|
||||
+ (subset[j2] >= 0).astype(int)
|
||||
)[:-2]
|
||||
if len(np.nonzero(membership == 2)[0]) == 0: # merge
|
||||
subset[j1][:-2] += subset[j2][:-2] + 1
|
||||
subset[j1][-2:] += subset[j2][-2:]
|
||||
subset[j1][-2] += connection_all[k][i][2]
|
||||
subset = np.delete(subset, j2, 0)
|
||||
else: # as like found == 1
|
||||
subset[j1][indexB] = partBs[i]
|
||||
subset[j1][-1] += 1
|
||||
subset[j1][-2] += (
|
||||
candidate[partBs[i].astype(int), 2]
|
||||
+ connection_all[k][i][2]
|
||||
)
|
||||
|
||||
# if find no partA in the subset, create a new subset
|
||||
elif not found and k < 17:
|
||||
row = -1 * np.ones(20)
|
||||
row[indexA] = partAs[i]
|
||||
row[indexB] = partBs[i]
|
||||
row[-1] = 2
|
||||
row[-2] = (
|
||||
sum(
|
||||
candidate[
|
||||
connection_all[k][i, :2].astype(int), 2
|
||||
]
|
||||
)
|
||||
+ connection_all[k][i][2]
|
||||
)
|
||||
subset = np.vstack([subset, row])
|
||||
# delete some rows of subset which has few parts occur
|
||||
deleteIdx = []
|
||||
for i in range(len(subset)):
|
||||
if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
|
||||
deleteIdx.append(i)
|
||||
subset = np.delete(subset, deleteIdx, axis=0)
|
||||
|
||||
# candidate: x, y, score, id
|
||||
return candidate, subset
|
||||
@@ -1,205 +0,0 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from scipy.ndimage.filters import gaussian_filter
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from skimage.measure import label
|
||||
from collections import OrderedDict
|
||||
from apps.stable_diffusion.src.utils.stencils.openpose.openpose_util import (
|
||||
make_layers,
|
||||
transfer,
|
||||
padRightDownCorner,
|
||||
npmax,
|
||||
)
|
||||
|
||||
|
||||
class HandPoseModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(HandPoseModel, self).__init__()
|
||||
|
||||
# these layers have no relu layer
|
||||
no_relu_layers = [
|
||||
"conv6_2_CPM",
|
||||
"Mconv7_stage2",
|
||||
"Mconv7_stage3",
|
||||
"Mconv7_stage4",
|
||||
"Mconv7_stage5",
|
||||
"Mconv7_stage6",
|
||||
]
|
||||
# stage 1
|
||||
block1_0 = OrderedDict(
|
||||
[
|
||||
("conv1_1", [3, 64, 3, 1, 1]),
|
||||
("conv1_2", [64, 64, 3, 1, 1]),
|
||||
("pool1_stage1", [2, 2, 0]),
|
||||
("conv2_1", [64, 128, 3, 1, 1]),
|
||||
("conv2_2", [128, 128, 3, 1, 1]),
|
||||
("pool2_stage1", [2, 2, 0]),
|
||||
("conv3_1", [128, 256, 3, 1, 1]),
|
||||
("conv3_2", [256, 256, 3, 1, 1]),
|
||||
("conv3_3", [256, 256, 3, 1, 1]),
|
||||
("conv3_4", [256, 256, 3, 1, 1]),
|
||||
("pool3_stage1", [2, 2, 0]),
|
||||
("conv4_1", [256, 512, 3, 1, 1]),
|
||||
("conv4_2", [512, 512, 3, 1, 1]),
|
||||
("conv4_3", [512, 512, 3, 1, 1]),
|
||||
("conv4_4", [512, 512, 3, 1, 1]),
|
||||
("conv5_1", [512, 512, 3, 1, 1]),
|
||||
("conv5_2", [512, 512, 3, 1, 1]),
|
||||
("conv5_3_CPM", [512, 128, 3, 1, 1]),
|
||||
]
|
||||
)
|
||||
|
||||
block1_1 = OrderedDict(
|
||||
[
|
||||
("conv6_1_CPM", [128, 512, 1, 1, 0]),
|
||||
("conv6_2_CPM", [512, 22, 1, 1, 0]),
|
||||
]
|
||||
)
|
||||
|
||||
blocks = {}
|
||||
blocks["block1_0"] = block1_0
|
||||
blocks["block1_1"] = block1_1
|
||||
|
||||
# stage 2-6
|
||||
for i in range(2, 7):
|
||||
blocks["block%d" % i] = OrderedDict(
|
||||
[
|
||||
("Mconv1_stage%d" % i, [150, 128, 7, 1, 3]),
|
||||
("Mconv2_stage%d" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv3_stage%d" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv4_stage%d" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv5_stage%d" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv6_stage%d" % i, [128, 128, 1, 1, 0]),
|
||||
("Mconv7_stage%d" % i, [128, 22, 1, 1, 0]),
|
||||
]
|
||||
)
|
||||
|
||||
for k in blocks.keys():
|
||||
blocks[k] = make_layers(blocks[k], no_relu_layers)
|
||||
|
||||
self.model1_0 = blocks["block1_0"]
|
||||
self.model1_1 = blocks["block1_1"]
|
||||
self.model2 = blocks["block2"]
|
||||
self.model3 = blocks["block3"]
|
||||
self.model4 = blocks["block4"]
|
||||
self.model5 = blocks["block5"]
|
||||
self.model6 = blocks["block6"]
|
||||
|
||||
def forward(self, x):
|
||||
out1_0 = self.model1_0(x)
|
||||
out1_1 = self.model1_1(out1_0)
|
||||
concat_stage2 = torch.cat([out1_1, out1_0], 1)
|
||||
out_stage2 = self.model2(concat_stage2)
|
||||
concat_stage3 = torch.cat([out_stage2, out1_0], 1)
|
||||
out_stage3 = self.model3(concat_stage3)
|
||||
concat_stage4 = torch.cat([out_stage3, out1_0], 1)
|
||||
out_stage4 = self.model4(concat_stage4)
|
||||
concat_stage5 = torch.cat([out_stage4, out1_0], 1)
|
||||
out_stage5 = self.model5(concat_stage5)
|
||||
concat_stage6 = torch.cat([out_stage5, out1_0], 1)
|
||||
out_stage6 = self.model6(concat_stage6)
|
||||
return out_stage6
|
||||
|
||||
|
||||
class Hand(object):
|
||||
def __init__(self, model_path):
|
||||
self.model = HandPoseModel()
|
||||
if torch.cuda.is_available():
|
||||
self.model = self.model.cuda()
|
||||
model_dict = transfer(self.model, torch.load(model_path))
|
||||
self.model.load_state_dict(model_dict)
|
||||
self.model.eval()
|
||||
|
||||
def __call__(self, oriImg):
|
||||
scale_search = [0.5, 1.0, 1.5, 2.0]
|
||||
# scale_search = [0.5]
|
||||
boxsize = 368
|
||||
stride = 8
|
||||
padValue = 128
|
||||
thre = 0.05
|
||||
multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
|
||||
heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 22))
|
||||
# paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
|
||||
|
||||
for m in range(len(multiplier)):
|
||||
scale = multiplier[m]
|
||||
imageToTest = cv2.resize(
|
||||
oriImg,
|
||||
(0, 0),
|
||||
fx=scale,
|
||||
fy=scale,
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
imageToTest_padded, pad = padRightDownCorner(
|
||||
imageToTest, stride, padValue
|
||||
)
|
||||
im = (
|
||||
np.transpose(
|
||||
np.float32(imageToTest_padded[:, :, :, np.newaxis]),
|
||||
(3, 2, 0, 1),
|
||||
)
|
||||
/ 256
|
||||
- 0.5
|
||||
)
|
||||
im = np.ascontiguousarray(im)
|
||||
|
||||
data = torch.from_numpy(im).float()
|
||||
if torch.cuda.is_available():
|
||||
data = data.cuda()
|
||||
# data = data.permute([2, 0, 1]).unsqueeze(0).float()
|
||||
with torch.no_grad():
|
||||
output = self.model(data).cpu().numpy()
|
||||
# output = self.model(data).numpy()
|
||||
|
||||
# extract outputs, resize, and remove padding
|
||||
heatmap = np.transpose(
|
||||
np.squeeze(output), (1, 2, 0)
|
||||
) # output 1 is heatmaps
|
||||
heatmap = cv2.resize(
|
||||
heatmap,
|
||||
(0, 0),
|
||||
fx=stride,
|
||||
fy=stride,
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
heatmap = heatmap[
|
||||
: imageToTest_padded.shape[0] - pad[2],
|
||||
: imageToTest_padded.shape[1] - pad[3],
|
||||
:,
|
||||
]
|
||||
heatmap = cv2.resize(
|
||||
heatmap,
|
||||
(oriImg.shape[1], oriImg.shape[0]),
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
|
||||
heatmap_avg += heatmap / len(multiplier)
|
||||
|
||||
all_peaks = []
|
||||
for part in range(21):
|
||||
map_ori = heatmap_avg[:, :, part]
|
||||
one_heatmap = gaussian_filter(map_ori, sigma=3)
|
||||
binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
|
||||
# all peaks are below the threshold
|
||||
if np.sum(binary) == 0:
|
||||
all_peaks.append([0, 0])
|
||||
continue
|
||||
label_img, label_numbers = label(
|
||||
binary, return_num=True, connectivity=binary.ndim
|
||||
)
|
||||
max_index = (
|
||||
np.argmax(
|
||||
[
|
||||
np.sum(map_ori[label_img == i])
|
||||
for i in range(1, label_numbers + 1)
|
||||
]
|
||||
)
|
||||
+ 1
|
||||
)
|
||||
label_img[label_img != max_index] = 0
|
||||
map_ori[label_img == 0] = 0
|
||||
|
||||
y, x = npmax(map_ori)
|
||||
all_peaks.append([x, y])
|
||||
return np.array(all_peaks)
|
||||
@@ -1,272 +0,0 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
import cv2
|
||||
from collections import OrderedDict
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def make_layers(block, no_relu_layers):
|
||||
layers = []
|
||||
for layer_name, v in block.items():
|
||||
if "pool" in layer_name:
|
||||
layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2])
|
||||
layers.append((layer_name, layer))
|
||||
else:
|
||||
conv2d = nn.Conv2d(
|
||||
in_channels=v[0],
|
||||
out_channels=v[1],
|
||||
kernel_size=v[2],
|
||||
stride=v[3],
|
||||
padding=v[4],
|
||||
)
|
||||
layers.append((layer_name, conv2d))
|
||||
if layer_name not in no_relu_layers:
|
||||
layers.append(("relu_" + layer_name, nn.ReLU(inplace=True)))
|
||||
|
||||
return nn.Sequential(OrderedDict(layers))
|
||||
|
||||
|
||||
def padRightDownCorner(img, stride, padValue):
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
|
||||
pad = 4 * [None]
|
||||
pad[0] = 0 # up
|
||||
pad[1] = 0 # left
|
||||
pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
|
||||
pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
|
||||
|
||||
img_padded = img
|
||||
pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
|
||||
img_padded = np.concatenate((pad_up, img_padded), axis=0)
|
||||
pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
|
||||
img_padded = np.concatenate((pad_left, img_padded), axis=1)
|
||||
pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
|
||||
img_padded = np.concatenate((img_padded, pad_down), axis=0)
|
||||
pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
|
||||
img_padded = np.concatenate((img_padded, pad_right), axis=1)
|
||||
|
||||
return img_padded, pad
|
||||
|
||||
|
||||
# transfer caffe model to pytorch which will match the layer name
|
||||
def transfer(model, model_weights):
|
||||
transfered_model_weights = {}
|
||||
for weights_name in model.state_dict().keys():
|
||||
transfered_model_weights[weights_name] = model_weights[
|
||||
".".join(weights_name.split(".")[1:])
|
||||
]
|
||||
return transfered_model_weights
|
||||
|
||||
|
||||
# draw the body keypoint and lims
|
||||
def draw_bodypose(canvas, candidate, subset):
|
||||
stickwidth = 4
|
||||
limbSeq = [
|
||||
[2, 3],
|
||||
[2, 6],
|
||||
[3, 4],
|
||||
[4, 5],
|
||||
[6, 7],
|
||||
[7, 8],
|
||||
[2, 9],
|
||||
[9, 10],
|
||||
[10, 11],
|
||||
[2, 12],
|
||||
[12, 13],
|
||||
[13, 14],
|
||||
[2, 1],
|
||||
[1, 15],
|
||||
[15, 17],
|
||||
[1, 16],
|
||||
[16, 18],
|
||||
[3, 17],
|
||||
[6, 18],
|
||||
]
|
||||
|
||||
colors = [
|
||||
[255, 0, 0],
|
||||
[255, 85, 0],
|
||||
[255, 170, 0],
|
||||
[255, 255, 0],
|
||||
[170, 255, 0],
|
||||
[85, 255, 0],
|
||||
[0, 255, 0],
|
||||
[0, 255, 85],
|
||||
[0, 255, 170],
|
||||
[0, 255, 255],
|
||||
[0, 170, 255],
|
||||
[0, 85, 255],
|
||||
[0, 0, 255],
|
||||
[85, 0, 255],
|
||||
[170, 0, 255],
|
||||
[255, 0, 255],
|
||||
[255, 0, 170],
|
||||
[255, 0, 85],
|
||||
]
|
||||
for i in range(18):
|
||||
for n in range(len(subset)):
|
||||
index = int(subset[n][i])
|
||||
if index == -1:
|
||||
continue
|
||||
x, y = candidate[index][0:2]
|
||||
cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
|
||||
for i in range(17):
|
||||
for n in range(len(subset)):
|
||||
index = subset[n][np.array(limbSeq[i]) - 1]
|
||||
if -1 in index:
|
||||
continue
|
||||
cur_canvas = canvas.copy()
|
||||
Y = candidate[index.astype(int), 0]
|
||||
X = candidate[index.astype(int), 1]
|
||||
mX = np.mean(X)
|
||||
mY = np.mean(Y)
|
||||
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
|
||||
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
|
||||
polygon = cv2.ellipse2Poly(
|
||||
(int(mY), int(mX)),
|
||||
(int(length / 2), stickwidth),
|
||||
int(angle),
|
||||
0,
|
||||
360,
|
||||
1,
|
||||
)
|
||||
cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
|
||||
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
|
||||
return canvas
|
||||
|
||||
|
||||
# image drawed by opencv is not good.
|
||||
def draw_handpose(canvas, all_hand_peaks, show_number=False):
|
||||
edges = [
|
||||
[0, 1],
|
||||
[1, 2],
|
||||
[2, 3],
|
||||
[3, 4],
|
||||
[0, 5],
|
||||
[5, 6],
|
||||
[6, 7],
|
||||
[7, 8],
|
||||
[0, 9],
|
||||
[9, 10],
|
||||
[10, 11],
|
||||
[11, 12],
|
||||
[0, 13],
|
||||
[13, 14],
|
||||
[14, 15],
|
||||
[15, 16],
|
||||
[0, 17],
|
||||
[17, 18],
|
||||
[18, 19],
|
||||
[19, 20],
|
||||
]
|
||||
|
||||
for peaks in all_hand_peaks:
|
||||
for ie, e in enumerate(edges):
|
||||
if np.sum(np.all(peaks[e], axis=1) == 0) == 0:
|
||||
x1, y1 = peaks[e[0]]
|
||||
x2, y2 = peaks[e[1]]
|
||||
cv2.line(
|
||||
canvas,
|
||||
(x1, y1),
|
||||
(x2, y2),
|
||||
matplotlib.colors.hsv_to_rgb(
|
||||
[ie / float(len(edges)), 1.0, 1.0]
|
||||
)
|
||||
* 255,
|
||||
thickness=2,
|
||||
)
|
||||
|
||||
for i, keypoint in enumerate(peaks):
x, y = keypoint
|
||||
cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
|
||||
if show_number:
|
||||
cv2.putText(
|
||||
canvas,
|
||||
str(i),
|
||||
(x, y),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.3,
|
||||
(0, 0, 0),
|
||||
lineType=cv2.LINE_AA,
|
||||
)
|
||||
return canvas
|
||||
|
||||
|
||||
# detect hand according to body pose keypoints
|
||||
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
|
||||
def handDetect(candidate, subset, oriImg):
|
||||
# right hand: wrist 4, elbow 3, shoulder 2
|
||||
# left hand: wrist 7, elbow 6, shoulder 5
|
||||
ratioWristElbow = 0.33
|
||||
detect_result = []
|
||||
image_height, image_width = oriImg.shape[0:2]
|
||||
for person in subset.astype(int):
|
||||
# if any of three not detected
|
||||
has_left = np.sum(person[[5, 6, 7]] == -1) == 0
|
||||
has_right = np.sum(person[[2, 3, 4]] == -1) == 0
|
||||
if not (has_left or has_right):
|
||||
continue
|
||||
hands = []
|
||||
# left hand
|
||||
if has_left:
|
||||
left_shoulder_index, left_elbow_index, left_wrist_index = person[
|
||||
[5, 6, 7]
|
||||
]
|
||||
x1, y1 = candidate[left_shoulder_index][:2]
|
||||
x2, y2 = candidate[left_elbow_index][:2]
|
||||
x3, y3 = candidate[left_wrist_index][:2]
|
||||
hands.append([x1, y1, x2, y2, x3, y3, True])
|
||||
# right hand
|
||||
if has_right:
|
||||
(
|
||||
right_shoulder_index,
|
||||
right_elbow_index,
|
||||
right_wrist_index,
|
||||
) = person[[2, 3, 4]]
|
||||
x1, y1 = candidate[right_shoulder_index][:2]
|
||||
x2, y2 = candidate[right_elbow_index][:2]
|
||||
x3, y3 = candidate[right_wrist_index][:2]
|
||||
hands.append([x1, y1, x2, y2, x3, y3, False])
|
||||
|
||||
        for x1, y1, x2, y2, x3, y3, is_left in hands:
            # extrapolate past the wrist to estimate the hand center
            x = x3 + ratioWristElbow * (x3 - x2)
            y = y3 + ratioWristElbow * (y3 - y2)
            distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
            distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
            width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
            # x, y refer to the center --> offset to the top-left point
            x -= width / 2
            y -= width / 2  # width = height
            # clip the box so it does not overflow the image
            if x < 0:
                x = 0
            if y < 0:
                y = 0
            width1 = width
            width2 = width
            if x + width > image_width:
                width1 = image_width - x
            if y + width > image_height:
                width2 = image_height - y
            width = min(width1, width2)
            # discard hand boxes smaller than 20 pixels
            if width >= 20:
                detect_result.append([int(x), int(y), int(width), is_left])

"""
|
||||
return value: [[x, y, w, True if left hand else False]].
|
||||
width=height since the network require squared input.
|
||||
x, y is the coordinate of top left
|
||||
"""
|
||||
return detect_result
|
||||
|
||||
|
||||
# get max index of 2d array
|
||||
def npmax(array):
|
||||
arrayindex = array.argmax(1)
|
||||
arrayvalue = array.max(1)
|
||||
i = arrayvalue.argmax()
|
||||
j = arrayindex[i]
|
||||
return (i,)
|
||||
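
As a hedged illustration (not part of the removed file), the return format documented above is typically consumed by cropping square patches around each detected hand; `oriImg`, `candidate`, and `subset` follow the conventions used by `handDetect`:

```python
# Illustrative sketch only: crop the square hand regions reported by handDetect.
# Assumes oriImg is an HxWx3 uint8 image and candidate/subset come from the
# body-pose estimator used elsewhere in this file.
def crop_hand_patches(candidate, subset, oriImg):
    patches = []
    for x, y, w, is_left in handDetect(candidate, subset, oriImg):
        patch = oriImg[y : y + w, x : x + w, :]  # width == height by construction
        patches.append((patch, is_left))
    return patches
```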
@@ -1,253 +0,0 @@
import numpy as np
from PIL import Image
import torch
import os
from pathlib import Path
import torchvision
import time
from apps.stable_diffusion.src.utils.stencils import (
    CannyDetector,
    OpenposeDetector,
    ZoeDetector,
)

# cache of instantiated stencil preprocessors, keyed by stencil name
stencil = {}


def save_img(img):
    from apps.stable_diffusion.src.utils import (
        get_generated_imgs_path,
        get_generated_imgs_todays_subdir,
    )

    subdir = Path(get_generated_imgs_path(), "preprocessed_control_hints")
    os.makedirs(subdir, exist_ok=True)
    if isinstance(img, Image.Image):
        img.save(
            os.path.join(
                subdir, "controlnet_" + str(int(time.time())) + ".png"
            )
        )
    elif isinstance(img, np.ndarray):
        img = Image.fromarray(img)
        img.save(os.path.join(subdir, str(int(time.time())) + ".png"))
    else:
        # assume a batched tensor: save each image in the batch separately
        converter = torchvision.transforms.ToPILImage()
        for i in img:
            converter(i).save(
                os.path.join(subdir, str(int(time.time())) + ".png")
            )


def HWC3(x):
    # normalize a uint8 image to 3-channel HWC layout
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    if C == 4:
        # composite RGBA over a white background
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        return y


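
# Illustrative usage sketch (not part of the original file): HWC3 normalizes any
# uint8 map to a 3-channel HWC image, e.g. a single-channel edge map. The 64x64
# random stand-in below is an assumption.
def _demo_hwc3():
    edges = (np.random.rand(64, 64) * 255).astype(np.uint8)
    rgb = HWC3(edges)
    print(rgb.shape)  # (64, 64, 3)

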
def controlnet_hint_reshaping(
    controlnet_hint, height, width, dtype, num_images_per_prompt=1
):
    channels = 3
    if isinstance(controlnet_hint, torch.Tensor):
        # torch.Tensor: acceptable shapes are chw, bchw (b == 1) or bchw (b == num_images_per_prompt)
        shape_chw = (channels, height, width)
        shape_bchw = (1, channels, height, width)
        shape_nchw = (num_images_per_prompt, channels, height, width)
        if controlnet_hint.shape in [shape_chw, shape_bchw, shape_nchw]:
            controlnet_hint = controlnet_hint.to(
                dtype=dtype, device=torch.device("cpu")
            )
            if controlnet_hint.shape != shape_nchw:
                controlnet_hint = controlnet_hint.repeat(
                    num_images_per_prompt, 1, 1, 1
                )
            return controlnet_hint
        else:
            return controlnet_hint_reshaping(
                Image.fromarray(controlnet_hint.detach().numpy()),
                height,
                width,
                dtype,
                num_images_per_prompt,
            )
    elif isinstance(controlnet_hint, np.ndarray):
        # np.ndarray: acceptable shapes are hw, hwc, bhwc (b == 1) or bhwc (b == num_images_per_prompt)
        # hwc is the opencv-compatible image format; color channels must be in BGR order
        if controlnet_hint.shape == (height, width):
            controlnet_hint = np.repeat(
                controlnet_hint[:, :, np.newaxis], channels, axis=2
            )  # hw -> hwc (c == 3)
        shape_hwc = (height, width, channels)
        shape_bhwc = (1, height, width, channels)
        shape_nhwc = (num_images_per_prompt, height, width, channels)
        if controlnet_hint.shape in [shape_hwc, shape_bhwc, shape_nhwc]:
            controlnet_hint = torch.from_numpy(controlnet_hint.copy())
            controlnet_hint = controlnet_hint.to(
                dtype=dtype, device=torch.device("cpu")
            )
            controlnet_hint /= 255.0
            if controlnet_hint.shape != shape_nhwc:
                controlnet_hint = controlnet_hint.repeat(
                    num_images_per_prompt, 1, 1, 1
                )
            controlnet_hint = controlnet_hint.permute(
                0, 3, 1, 2
            )  # b h w c -> b c h w
            return controlnet_hint
        else:
            return controlnet_hint_reshaping(
                Image.fromarray(controlnet_hint),
                height,
                width,
                dtype,
                num_images_per_prompt,
            )

    elif isinstance(controlnet_hint, Image.Image):
        controlnet_hint = controlnet_hint.convert(
            "RGB"
        )  # make sure it is in 3-channel RGB format
        if controlnet_hint.size == (width, height):
            controlnet_hint = np.array(controlnet_hint).astype(
                np.float16
            )  # to numpy
            controlnet_hint = controlnet_hint[:, :, ::-1]  # RGB -> BGR
            return controlnet_hint_reshaping(
                controlnet_hint, height, width, dtype, num_images_per_prompt
            )
        else:
            # center-crop horizontally, then resize to the requested size
            (hint_w, hint_h) = controlnet_hint.size
            left = int((hint_w - width) / 2)
            right = left + height
            controlnet_hint = controlnet_hint.crop((left, 0, right, hint_h))
            controlnet_hint = controlnet_hint.resize((width, height))
            return controlnet_hint_reshaping(
                controlnet_hint, height, width, dtype, num_images_per_prompt
            )
    else:
        raise ValueError(
            f"Acceptable controlnet input types are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
        )


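
# Illustrative usage sketch (not part of the original file): reshaping a PIL
# control image into the bchw tensor layout expected by the ControlNet UNet.
# The file name, 512x512 size, and float16 dtype are assumptions.
def _demo_hint_reshaping():
    hint = Image.open("pose.png")  # hypothetical input path
    batch = controlnet_hint_reshaping(
        hint, height=512, width=512, dtype=torch.float16, num_images_per_prompt=2
    )
    print(batch.shape)  # torch.Size([2, 3, 512, 512])

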
def controlnet_hint_conversion(
    image, use_stencil, height, width, dtype, num_images_per_prompt=1
):
    controlnet_hint = None
    match use_stencil:
        case "canny":
            print(
                "Converting controlnet hint to edge detection mask with canny preprocessor."
            )
            controlnet_hint = hint_canny(image)
        case "openpose":
            print(
                "Detecting human pose in controlnet hint with openpose preprocessor."
            )
            controlnet_hint = hint_openpose(image)
        case "scribble":
            print("Using your scribble as a controlnet hint.")
            controlnet_hint = hint_scribble(image)
        case "zoedepth":
            print(
                "Converting controlnet hint to a depth mapping with ZoeDepth."
            )
            controlnet_hint = hint_zoedepth(image)
        case _:
            return None
    controlnet_hint = controlnet_hint_reshaping(
        controlnet_hint, height, width, dtype, num_images_per_prompt
    )
    return controlnet_hint


stencil_to_model_id_map = {
    "canny": "lllyasviel/control_v11p_sd15_canny",
    "zoedepth": "lllyasviel/control_v11f1p_sd15_depth",
    "hed": "lllyasviel/sd-controlnet-hed",
    "mlsd": "lllyasviel/control_v11p_sd15_mlsd",
    "normal": "lllyasviel/control_v11p_sd15_normalbae",
    "openpose": "lllyasviel/control_v11p_sd15_openpose",
    "scribble": "lllyasviel/control_v11p_sd15_scribble",
    "seg": "lllyasviel/control_v11p_sd15_seg",
}


def get_stencil_model_id(use_stencil):
    if use_stencil in stencil_to_model_id_map:
        return stencil_to_model_id_map[use_stencil]
    return None


# Stencil 1. Canny
def hint_canny(
    image: Image.Image,
    low_threshold=100,
    high_threshold=200,
):
    with torch.no_grad():
        input_image = np.array(image)

        if "canny" not in stencil:
            stencil["canny"] = CannyDetector()
        detected_map = stencil["canny"](
            input_image, low_threshold, high_threshold
        )
        save_img(detected_map)
        detected_map = HWC3(detected_map)
        return detected_map


# Stencil 2. OpenPose.
def hint_openpose(
    image: Image.Image,
):
    with torch.no_grad():
        input_image = np.array(image)

        if "openpose" not in stencil:
            stencil["openpose"] = OpenposeDetector()

        detected_map, _ = stencil["openpose"](input_image)
        save_img(detected_map)
        detected_map = HWC3(detected_map)
        return detected_map


# Stencil 3. Scribble.
def hint_scribble(image: Image.Image):
    with torch.no_grad():
        input_image = np.array(image)

        # binarize: dark strokes become white hint pixels on a black background
        detected_map = np.zeros_like(input_image, dtype=np.uint8)
        detected_map[np.min(input_image, axis=2) < 127] = 255
        save_img(detected_map)
        return detected_map


# Stencil 4. Depth (Zoe preprocessing only)
def hint_zoedepth(image: Image.Image):
    with torch.no_grad():
        input_image = np.array(image)

        if "depth" not in stencil:
            stencil["depth"] = ZoeDetector()

        detected_map = stencil["depth"](input_image)
        save_img(detected_map)
        detected_map = HWC3(detected_map)
        return detected_map
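
A hedged end-to-end sketch (not part of the removed files) of how the helpers above fit together: a stencil name selects both a preprocessor and a ControlNet checkpoint id. The input path, the 512x512 size, and the float16 dtype are assumptions.

```python
# Illustrative sketch only: build a "canny" hint tensor and look up its checkpoint id.
import torch
from PIL import Image

image = Image.open("input.png")  # hypothetical input path
hint = controlnet_hint_conversion(
    image, "canny", height=512, width=512, dtype=torch.float16
)
model_id = get_stencil_model_id("canny")
print(hint.shape, model_id)  # torch.Size([1, 3, 512, 512]) lllyasviel/control_v11p_sd15_canny
```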
@@ -1,64 +0,0 @@
import numpy as np
import torch
from pathlib import Path
import requests


from einops import rearrange

remote_model_path = (
    "https://huggingface.co/lllyasviel/Annotators/resolve/main/ZoeD_M12_N.pt"
)


class ZoeDetector:
    def __init__(self):
        cwd = Path.cwd()
        ckpt_path = Path(cwd, "stencil_annotator")
        ckpt_path.mkdir(parents=True, exist_ok=True)
        modelpath = ckpt_path / "ZoeD_M12_N.pt"

        # download the ZoeDepth checkpoint
        with requests.get(remote_model_path, stream=True) as r:
            r.raise_for_status()
            with open(modelpath, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

        model = torch.hub.load(
            "monorimet/ZoeDepth:torch_update",
            "ZoeD_N",
            pretrained=False,
            force_reload=False,
        )

        # Hack to fix the ZoeDepth import issue: drop checkpoint keys the hub
        # model does not expect before loading the state dict.
        model_keys = model.state_dict().keys()
        loaded_dict = torch.load(modelpath, map_location=model.device)["model"]
        loaded_keys = loaded_dict.keys()
        for key in loaded_keys - model_keys:
            loaded_dict.pop(key)

        model.load_state_dict(loaded_dict)
        model.eval()
        self.model = model

    def __call__(self, input_image):
        assert input_image.ndim == 3
        image_depth = input_image
        with torch.no_grad():
            image_depth = torch.from_numpy(image_depth).float()
            image_depth = image_depth / 255.0
            image_depth = rearrange(image_depth, "h w c -> 1 c h w")
            depth = self.model.infer(image_depth)

            depth = depth[0, 0].cpu().numpy()

            # normalize to [0, 1] using robust percentiles, then invert
            vmin = np.percentile(depth, 2)
            vmax = np.percentile(depth, 85)

            depth -= vmin
            depth /= vmax - vmin
            depth = 1.0 - depth
            depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8)

            return depth_image
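
A hedged usage sketch (not part of the removed file) for the detector above; the image path is hypothetical and the checkpoint is downloaded when the detector is constructed:

```python
# Illustrative sketch only: produce a ZoeDepth hint image from an RGB photo.
import numpy as np
from PIL import Image

detector = ZoeDetector()  # downloads ZoeD_M12_N.pt into ./stencil_annotator
rgb = np.array(Image.open("photo.png").convert("RGB"))  # hypothetical input path
depth_hint = detector(rgb)  # HxW uint8 depth map
Image.fromarray(depth_hint).save("depth_hint.png")
```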
File diff suppressed because it is too large
Some files were not shown because too many files have changed in this diff