mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
167 Commits
20230807.8
...
20231113.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c2163488d8 | ||
|
|
54bff4611d | ||
|
|
11510d5111 | ||
|
|
32cab73a29 | ||
|
|
392bade0bf | ||
|
|
91df5f0613 | ||
|
|
df20cf9c8a | ||
|
|
c4a908c3ea | ||
|
|
6285430d8a | ||
|
|
51afe19e20 | ||
|
|
31005bcf73 | ||
|
|
f41ad87ef6 | ||
|
|
d811524a00 | ||
|
|
51e1bd1c5d | ||
|
|
db89b1bdc1 | ||
|
|
2754e2e257 | ||
|
|
ab0e870c43 | ||
|
|
fb30e8c226 | ||
|
|
a07d542400 | ||
|
|
ad55cb696f | ||
|
|
488a172292 | ||
|
|
500c4f2306 | ||
|
|
92b694db4d | ||
|
|
322874f7f9 | ||
|
|
5001db3415 | ||
|
|
71846344a2 | ||
|
|
72e27c96fc | ||
|
|
7963abb8ec | ||
|
|
98244232dd | ||
|
|
679a452139 | ||
|
|
72c0a8abc8 | ||
|
|
ea920f2955 | ||
|
|
486202377a | ||
|
|
0c38c33d0a | ||
|
|
841773fa32 | ||
|
|
0361db46f9 | ||
|
|
a012433ffd | ||
|
|
5061193da3 | ||
|
|
bff48924be | ||
|
|
825b36cbdd | ||
|
|
134441957d | ||
|
|
7cd14fdc47 | ||
|
|
e6cb5cef57 | ||
|
|
66abee8e5b | ||
|
|
4797bb89f5 | ||
|
|
205e57683a | ||
|
|
2866d665ee | ||
|
|
71d25ec5d8 | ||
|
|
202ffff67b | ||
|
|
0b77059628 | ||
|
|
a208302bb9 | ||
|
|
b83d32fafe | ||
|
|
0a618e1863 | ||
|
|
a731eb6ed4 | ||
|
|
2004d16945 | ||
|
|
6e409bfb77 | ||
|
|
77727d149c | ||
|
|
66f6e79d68 | ||
|
|
3b825579a7 | ||
|
|
9f0a421764 | ||
|
|
c28682110c | ||
|
|
caf6cc5d8f | ||
|
|
8614a18474 | ||
|
|
86c1c0c215 | ||
|
|
8bb364bcb8 | ||
|
|
7abddd01ec | ||
|
|
2a451fa0c7 | ||
|
|
9c4610b9da | ||
|
|
a38cc9d216 | ||
|
|
1c382449ec | ||
|
|
7cc9b3f8e8 | ||
|
|
e54517e967 | ||
|
|
326327a799 | ||
|
|
785b65c7b0 | ||
|
|
0d16c81687 | ||
|
|
8dd7850c69 | ||
|
|
e930ba85b4 | ||
|
|
cd732e7a38 | ||
|
|
8e0f8b3227 | ||
|
|
b8210ef796 | ||
|
|
94594542a9 | ||
|
|
82f833e87d | ||
|
|
c9d6870105 | ||
|
|
4fec03a6cc | ||
|
|
9a27f51378 | ||
|
|
ad1a0f35ff | ||
|
|
6773278ec2 | ||
|
|
9a0efffcca | ||
|
|
61c6f153d9 | ||
|
|
effd42e8f5 | ||
|
|
b5fbb1a8a0 | ||
|
|
ded74d09cd | ||
|
|
79267931c1 | ||
|
|
9eceba69b7 | ||
|
|
ca609afb6a | ||
|
|
11bdce9790 | ||
|
|
684943a4a6 | ||
|
|
b817bb8455 | ||
|
|
780f520f02 | ||
|
|
c61b6f8d65 | ||
|
|
c854208d49 | ||
|
|
c5dcfc1f13 | ||
|
|
bde63ee8ae | ||
|
|
9681d494eb | ||
|
|
ede6bf83e2 | ||
|
|
2c2693fb7d | ||
|
|
1d31b2b2c6 | ||
|
|
d2f64eefa3 | ||
|
|
87ae14b6ff | ||
|
|
1ccafa1fc1 | ||
|
|
4c3d8a0a7f | ||
|
|
3601dc7c3b | ||
|
|
671881cf87 | ||
|
|
4e9be6be59 | ||
|
|
9c8cbaf498 | ||
|
|
9e348a114e | ||
|
|
51f90a4d56 | ||
|
|
310d5d0a49 | ||
|
|
9697981004 | ||
|
|
450c231171 | ||
|
|
07f6f4a2f7 | ||
|
|
610813c72f | ||
|
|
8e3860c9e6 | ||
|
|
e37d6720eb | ||
|
|
16160d9a7d | ||
|
|
79075a1a07 | ||
|
|
db990826d3 | ||
|
|
7ee3e4ba5d | ||
|
|
05889a8fe1 | ||
|
|
b87efe7686 | ||
|
|
82b462de3a | ||
|
|
d8f0f7bade | ||
|
|
79bd0b84a1 | ||
|
|
8738571d1e | ||
|
|
a4c354ce54 | ||
|
|
cc53efa89f | ||
|
|
9ae8bc921e | ||
|
|
32eb78f0f9 | ||
|
|
cb509343d9 | ||
|
|
6da391c9b1 | ||
|
|
9dee7ae652 | ||
|
|
343dfd901c | ||
|
|
57260b9c37 | ||
|
|
18e7d2d061 | ||
|
|
51a1009796 | ||
|
|
045c3c3852 | ||
|
|
0139dd58d9 | ||
|
|
c96571855a | ||
|
|
4f61d69d86 | ||
|
|
531d447768 | ||
|
|
16f46f8de9 | ||
|
|
c4723f469f | ||
|
|
d804f45a61 | ||
|
|
d22177f936 | ||
|
|
75e68f02f4 | ||
|
|
4dc9c59611 | ||
|
|
18801dcabc | ||
|
|
3c577f7168 | ||
|
|
f5e4fa6ffe | ||
|
|
48de445325 | ||
|
|
8e90f1b81a | ||
|
|
e8c1203be2 | ||
|
|
e4d7abb519 | ||
|
|
96185c9dc1 | ||
|
|
bc22a81925 | ||
|
|
5203679f1f | ||
|
|
bf073f8f37 |
8
.github/workflows/nightly.yml
vendored
8
.github/workflows/nightly.yml
vendored
@@ -51,11 +51,11 @@ jobs:
|
||||
run: |
|
||||
./setup_venv.ps1
|
||||
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
|
||||
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
|
||||
python process_skipfiles.py
|
||||
pyinstaller .\apps\stable_diffusion\shark_sd.spec
|
||||
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
|
||||
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
|
||||
signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
|
||||
|
||||
- name: Upload Release Assets
|
||||
id: upload-release-assets
|
||||
@@ -104,7 +104,7 @@ jobs:
|
||||
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install flake8 pytest toml
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html; fi
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html; fi
|
||||
- name: Lint with flake8
|
||||
run: |
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
@@ -144,7 +144,7 @@ jobs:
|
||||
source shark.venv/bin/activate
|
||||
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
|
||||
SHARK_PACKAGE_VERSION=${package_version} \
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
|
||||
# Install the built wheel
|
||||
pip install ./wheelhouse/nodai*
|
||||
# Validate the Models
|
||||
|
||||
3
.github/workflows/test-models.yml
vendored
3
.github/workflows/test-models.yml
vendored
@@ -137,7 +137,8 @@ jobs:
|
||||
source shark.venv/bin/activate
|
||||
echo $PATH
|
||||
pip list | grep -E "torch|iree"
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
|
||||
# disabled due to a low-visibility memory issue with pytest on macos.
|
||||
# pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
|
||||
|
||||
- name: Validate Vulkan Models (a100)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
|
||||
|
||||
11
.gitignore
vendored
11
.gitignore
vendored
@@ -182,7 +182,7 @@ generated_imgs/
|
||||
|
||||
# Custom model related artefacts
|
||||
variants.json
|
||||
models/
|
||||
/models/
|
||||
|
||||
# models folder
|
||||
apps/stable_diffusion/web/models/
|
||||
@@ -193,3 +193,12 @@ stencil_annotator/
|
||||
# For DocuChat
|
||||
apps/language_models/langchain/user_path/
|
||||
db_dir_UserData
|
||||
|
||||
# Embeded browser cache and other
|
||||
apps/stable_diffusion/web/EBWebView/
|
||||
|
||||
# Llama2 tokenizer configs
|
||||
llama2_tokenizer_configs/
|
||||
|
||||
# Webview2 runtime artefacts
|
||||
EBWebView/
|
||||
|
||||
2
.gitmodules
vendored
2
.gitmodules
vendored
@@ -1,4 +1,4 @@
|
||||
[submodule "inference/thirdparty/shark-runtime"]
|
||||
path = inference/thirdparty/shark-runtime
|
||||
url =https://github.com/nod-ai/SHARK-Runtime.git
|
||||
url =https://github.com/nod-ai/SRT.git
|
||||
branch = shark-06032022
|
||||
|
||||
14
README.md
14
README.md
@@ -10,7 +10,7 @@ High Performance Machine Learning Distribution
|
||||
<summary>Prerequisites - Drivers </summary>
|
||||
|
||||
#### Install your Windows hardware drivers
|
||||
* [AMD RDNA Users] Download the latest driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-2-1).
|
||||
* [AMD RDNA Users] Download the latest driver (23.2.1 is the oldest supported) [here](https://www.amd.com/en/support).
|
||||
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
|
||||
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
|
||||
|
||||
@@ -170,7 +170,7 @@ python -m pip install --upgrade pip
|
||||
This step pip installs SHARK and related packages on Linux Python 3.8, 3.10 and 3.11 and macOS / Windows Python 3.11
|
||||
|
||||
```shell
|
||||
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
```
|
||||
|
||||
### Run shark tank model tests.
|
||||
@@ -254,7 +254,6 @@ if you want to instead incorporate this into a python script, you can pass the `
|
||||
```
|
||||
shark_module = SharkInference(
|
||||
mlir_model,
|
||||
func_name,
|
||||
device=args.device,
|
||||
mlir_dialect="tm_tensor",
|
||||
dispatch_benchmarks="all",
|
||||
@@ -297,7 +296,7 @@ torch_mlir, func_name = mlir_importer.import_mlir(tracing_required=True)
|
||||
# SharkInference accepts mlir in linalg, mhlo, and tosa dialect.
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
shark_module = SharkInference(torch_mlir, func_name, device="cpu", mlir_dialect="linalg")
|
||||
shark_module = SharkInference(torch_mlir, device="cpu", mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((input))
|
||||
|
||||
@@ -320,12 +319,17 @@ mhlo_ir = r"""builtin.module {
|
||||
|
||||
arg0 = np.ones((1, 4)).astype(np.float32)
|
||||
arg1 = np.ones((4, 1)).astype(np.float32)
|
||||
shark_module = SharkInference(mhlo_ir, func_name="forward", device="cpu", mlir_dialect="mhlo")
|
||||
shark_module = SharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((arg0, arg1))
|
||||
```
|
||||
</details>
|
||||
|
||||
## Examples Using the REST API
|
||||
|
||||
* [Setting up SHARK for use with Blender](./docs/shark_sd_blender.md)
|
||||
* [Setting up SHARK for use with Koboldcpp](./docs/shark_sd_koboldcpp.md)
|
||||
|
||||
## Supported and Validated Models
|
||||
|
||||
SHARK is maintained to support the latest innovations in ML Models:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
"""Load question answering chains."""
|
||||
from __future__ import annotations
|
||||
from typing import (
|
||||
Any,
|
||||
@@ -11,23 +10,34 @@ from typing import (
|
||||
Union,
|
||||
Protocol,
|
||||
)
|
||||
import inspect
|
||||
import json
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
from abc import ABC, abstractmethod
|
||||
import langchain
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.question_answering import stuff_prompt
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.docstore.document import Document
|
||||
from abc import ABC, abstractmethod
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManager,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
|
||||
from langchain.input import get_colored_text
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import LLMResult, PromptValue
|
||||
from pydantic import Extra, Field, root_validator
|
||||
from pydantic import Extra, Field, root_validator, validator
|
||||
|
||||
|
||||
def _get_verbosity() -> bool:
|
||||
return langchain.verbose
|
||||
|
||||
|
||||
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
|
||||
@@ -48,6 +58,413 @@ def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
|
||||
return prompt.format(**document_info)
|
||||
|
||||
|
||||
class Chain(Serializable, ABC):
|
||||
"""Base interface that all chains should implement."""
|
||||
|
||||
memory: Optional[BaseMemory] = None
|
||||
callbacks: Callbacks = Field(default=None, exclude=True)
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(
|
||||
default=None, exclude=True
|
||||
)
|
||||
verbose: bool = Field(
|
||||
default_factory=_get_verbosity
|
||||
) # Whether to print the response text
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
raise NotImplementedError("Saving not supported for this chain type.")
|
||||
|
||||
@root_validator()
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
"""Raise deprecation warning if callback_manager is used."""
|
||||
if values.get("callback_manager") is not None:
|
||||
warnings.warn(
|
||||
"callback_manager is deprecated. Please use callbacks instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
values["callbacks"] = values.pop("callback_manager", None)
|
||||
return values
|
||||
|
||||
@validator("verbose", pre=True, always=True)
|
||||
def set_verbose(cls, verbose: Optional[bool]) -> bool:
|
||||
"""If verbose is None, set it.
|
||||
|
||||
This allows users to pass in None as verbose to access the global setting.
|
||||
"""
|
||||
if verbose is None:
|
||||
return _get_verbosity()
|
||||
else:
|
||||
return verbose
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Input keys this chain expects."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Output keys this chain expects."""
|
||||
|
||||
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||
"""Check that all inputs are present."""
|
||||
missing_keys = set(self.input_keys).difference(inputs)
|
||||
if missing_keys:
|
||||
raise ValueError(f"Missing some input keys: {missing_keys}")
|
||||
|
||||
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
|
||||
missing_keys = set(self.output_keys).difference(outputs)
|
||||
if missing_keys:
|
||||
raise ValueError(f"Missing some output keys: {missing_keys}")
|
||||
|
||||
@abstractmethod
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and return the output."""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
return_only_outputs: bool = False,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and add to output if desired.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of inputs, or single input if chain expects
|
||||
only one param.
|
||||
return_only_outputs: boolean for whether to return only outputs in the
|
||||
response. If True, only new keys generated by this chain will be
|
||||
returned. If False, both input keys and new keys generated by this
|
||||
chain will be returned. Defaults to False.
|
||||
callbacks: Callbacks to use for this chain run. If not provided, will
|
||||
use the callbacks provided to the chain.
|
||||
include_run_info: Whether to include run info in the response. Defaults
|
||||
to False.
|
||||
"""
|
||||
input_docs = inputs["input_documents"]
|
||||
missing_keys = set(self.input_keys).difference(inputs)
|
||||
if missing_keys:
|
||||
raise ValueError(f"Missing some input keys: {missing_keys}")
|
||||
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose, tags, self.tags
|
||||
)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
inputs,
|
||||
)
|
||||
|
||||
if "is_first" in inputs.keys() and not inputs["is_first"]:
|
||||
run_manager_ = run_manager
|
||||
input_list = [inputs]
|
||||
stop = None
|
||||
prompts = []
|
||||
for inputs in input_list:
|
||||
selected_inputs = {
|
||||
k: inputs[k] for k in self.prompt.input_variables
|
||||
}
|
||||
prompt = self.prompt.format_prompt(**selected_inputs)
|
||||
_colored_text = get_colored_text(prompt.to_string(), "green")
|
||||
_text = "Prompt after formatting:\n" + _colored_text
|
||||
if run_manager_:
|
||||
run_manager_.on_text(_text, end="\n", verbose=self.verbose)
|
||||
if "stop" in inputs and inputs["stop"] != stop:
|
||||
raise ValueError(
|
||||
"If `stop` is present in any inputs, should be present in all."
|
||||
)
|
||||
prompts.append(prompt)
|
||||
|
||||
prompt_strings = [p.to_string() for p in prompts]
|
||||
prompts = prompt_strings
|
||||
callbacks = run_manager_.get_child() if run_manager_ else None
|
||||
tags = None
|
||||
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
# If string is passed in directly no errors will be raised but outputs will
|
||||
# not make sense.
|
||||
if not isinstance(prompts, list):
|
||||
raise ValueError(
|
||||
"Argument 'prompts' is expected to be of type List[str], received"
|
||||
f" argument of type {type(prompts)}."
|
||||
)
|
||||
params = self.llm.dict()
|
||||
params["stop"] = stop
|
||||
options = {"stop": stop}
|
||||
disregard_cache = self.llm.cache is not None and not self.llm.cache
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks,
|
||||
self.llm.callbacks,
|
||||
self.llm.verbose,
|
||||
tags,
|
||||
self.llm.tags,
|
||||
)
|
||||
if langchain.llm_cache is None or disregard_cache:
|
||||
# This happens when langchain.cache is None, but self.cache is True
|
||||
if self.llm.cache is not None and self.cache:
|
||||
raise ValueError(
|
||||
"Asked to cache, but no cache found at `langchain.cache`."
|
||||
)
|
||||
run_manager_ = callback_manager.on_llm_start(
|
||||
dumpd(self),
|
||||
prompts,
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
)
|
||||
|
||||
generations = []
|
||||
for prompt in prompts:
|
||||
inputs_ = prompt
|
||||
num_workers = None
|
||||
batch_size = None
|
||||
|
||||
if num_workers is None:
|
||||
if self.llm.pipeline._num_workers is None:
|
||||
num_workers = 0
|
||||
else:
|
||||
num_workers = self.llm.pipeline._num_workers
|
||||
if batch_size is None:
|
||||
if self.llm.pipeline._batch_size is None:
|
||||
batch_size = 1
|
||||
else:
|
||||
batch_size = self.llm.pipeline._batch_size
|
||||
|
||||
preprocess_params = {}
|
||||
generate_kwargs = {}
|
||||
preprocess_params.update(generate_kwargs)
|
||||
forward_params = generate_kwargs
|
||||
postprocess_params = {}
|
||||
# Fuse __init__ params and __call__ params without modifying the __init__ ones.
|
||||
preprocess_params = {
|
||||
**self.llm.pipeline._preprocess_params,
|
||||
**preprocess_params,
|
||||
}
|
||||
forward_params = {
|
||||
**self.llm.pipeline._forward_params,
|
||||
**forward_params,
|
||||
}
|
||||
postprocess_params = {
|
||||
**self.llm.pipeline._postprocess_params,
|
||||
**postprocess_params,
|
||||
}
|
||||
|
||||
self.llm.pipeline.call_count += 1
|
||||
if (
|
||||
self.llm.pipeline.call_count > 10
|
||||
and self.llm.pipeline.framework == "pt"
|
||||
and self.llm.pipeline.device.type == "cuda"
|
||||
):
|
||||
warnings.warn(
|
||||
"You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a"
|
||||
" dataset",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
model_inputs = self.llm.pipeline.preprocess(
|
||||
inputs_, **preprocess_params
|
||||
)
|
||||
model_outputs = self.llm.pipeline.forward(
|
||||
model_inputs, **forward_params
|
||||
)
|
||||
model_outputs["process"] = False
|
||||
return model_outputs
|
||||
output = LLMResult(generations=generations)
|
||||
run_manager_.on_llm_end(output)
|
||||
if run_manager_:
|
||||
output.run = RunInfo(run_id=run_manager_.run_id)
|
||||
response = output
|
||||
|
||||
outputs = [
|
||||
# Get the text of the top generated string.
|
||||
{self.output_key: generation[0].text}
|
||||
for generation in response.generations
|
||||
][0]
|
||||
run_manager.on_chain_end(outputs)
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs
|
||||
)
|
||||
if include_run_info:
|
||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||
return final_outputs
|
||||
else:
|
||||
_run_manager = (
|
||||
run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
)
|
||||
docs = inputs[self.input_key]
|
||||
# Other keys are assumed to be needed for LLM prediction
|
||||
other_keys = {
|
||||
k: v for k, v in inputs.items() if k != self.input_key
|
||||
}
|
||||
doc_strings = [
|
||||
format_document(doc, self.document_prompt) for doc in docs
|
||||
]
|
||||
# Join the documents together to put them in the prompt.
|
||||
inputs = {
|
||||
k: v
|
||||
for k, v in other_keys.items()
|
||||
if k in self.llm_chain.prompt.input_variables
|
||||
}
|
||||
inputs[self.document_variable_name] = self.document_separator.join(
|
||||
doc_strings
|
||||
)
|
||||
inputs["is_first"] = False
|
||||
inputs["input_documents"] = input_docs
|
||||
|
||||
# Call predict on the LLM.
|
||||
output = self.llm_chain(inputs, callbacks=_run_manager.get_child())
|
||||
if "process" in output.keys() and not output["process"]:
|
||||
return output
|
||||
output = output[self.llm_chain.output_key]
|
||||
extra_return_dict = {}
|
||||
extra_return_dict[self.output_key] = output
|
||||
outputs = extra_return_dict
|
||||
run_manager.on_chain_end(outputs)
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs
|
||||
)
|
||||
if include_run_info:
|
||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||
return final_outputs
|
||||
|
||||
def prep_outputs(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
outputs: Dict[str, str],
|
||||
return_only_outputs: bool = False,
|
||||
) -> Dict[str, str]:
|
||||
"""Validate and prep outputs."""
|
||||
self._validate_outputs(outputs)
|
||||
if self.memory is not None:
|
||||
self.memory.save_context(inputs, outputs)
|
||||
if return_only_outputs:
|
||||
return outputs
|
||||
else:
|
||||
return {**inputs, **outputs}
|
||||
|
||||
def prep_inputs(
|
||||
self, inputs: Union[Dict[str, Any], Any]
|
||||
) -> Dict[str, str]:
|
||||
"""Validate and prep inputs."""
|
||||
if not isinstance(inputs, dict):
|
||||
_input_keys = set(self.input_keys)
|
||||
if self.memory is not None:
|
||||
# If there are multiple input keys, but some get set by memory so that
|
||||
# only one is not set, we can still figure out which key it is.
|
||||
_input_keys = _input_keys.difference(
|
||||
self.memory.memory_variables
|
||||
)
|
||||
if len(_input_keys) != 1:
|
||||
raise ValueError(
|
||||
f"A single string input was passed in, but this chain expects "
|
||||
f"multiple inputs ({_input_keys}). When a chain expects "
|
||||
f"multiple inputs, please call it by passing in a dictionary, "
|
||||
"eg `chain({'foo': 1, 'bar': 2})`"
|
||||
)
|
||||
inputs = {list(_input_keys)[0]: inputs}
|
||||
if self.memory is not None:
|
||||
external_context = self.memory.load_memory_variables(inputs)
|
||||
inputs = dict(inputs, **external_context)
|
||||
self._validate_inputs(inputs)
|
||||
return inputs
|
||||
|
||||
def apply(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Call the chain on all inputs in the list."""
|
||||
return [self(inputs, callbacks=callbacks) for inputs in input_list]
|
||||
|
||||
def run(
|
||||
self,
|
||||
*args: Any,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run the chain as text in, text out or multiple variables, text out."""
|
||||
if len(self.output_keys) != 1:
|
||||
raise ValueError(
|
||||
f"`run` not supported when there is not exactly "
|
||||
f"one output key. Got {self.output_keys}."
|
||||
)
|
||||
|
||||
if args and not kwargs:
|
||||
if len(args) != 1:
|
||||
raise ValueError(
|
||||
"`run` supports only one positional argument."
|
||||
)
|
||||
return self(args[0], callbacks=callbacks, tags=tags)[
|
||||
self.output_keys[0]
|
||||
]
|
||||
|
||||
if kwargs and not args:
|
||||
return self(kwargs, callbacks=callbacks, tags=tags)[
|
||||
self.output_keys[0]
|
||||
]
|
||||
|
||||
if not kwargs and not args:
|
||||
raise ValueError(
|
||||
"`run` supported with either positional arguments or keyword arguments,"
|
||||
" but none were provided."
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"`run` supported with either positional arguments or keyword arguments"
|
||||
f" but not both. Got args: {args} and kwargs: {kwargs}."
|
||||
)
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of chain."""
|
||||
if self.memory is not None:
|
||||
raise ValueError("Saving of memory is not yet supported.")
|
||||
_dict = super().dict()
|
||||
_dict["_type"] = self._chain_type
|
||||
return _dict
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
"""Save the chain.
|
||||
|
||||
Args:
|
||||
file_path: Path to file to save the chain to.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
chain.save(file_path="path/chain.yaml")
|
||||
"""
|
||||
# Convert file to Path object.
|
||||
if isinstance(file_path, str):
|
||||
save_path = Path(file_path)
|
||||
else:
|
||||
save_path = file_path
|
||||
|
||||
directory_path = save_path.parent
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Fetch dictionary to save
|
||||
chain_dict = self.dict()
|
||||
|
||||
if save_path.suffix == ".json":
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(chain_dict, f, indent=4)
|
||||
elif save_path.suffix == ".yaml":
|
||||
with open(file_path, "w") as f:
|
||||
yaml.dump(chain_dict, f, default_flow_style=False)
|
||||
else:
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
||||
|
||||
class BaseCombineDocumentsChain(Chain, ABC):
|
||||
"""Base interface for chains combining documents."""
|
||||
|
||||
@@ -79,12 +496,6 @@ class BaseCombineDocumentsChain(Chain, ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def combine_docs(
|
||||
self, docs: List[Document], **kwargs: Any
|
||||
) -> Tuple[str, dict]:
|
||||
"""Combine documents into a single string."""
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, List[Document]],
|
||||
@@ -96,13 +507,49 @@ class BaseCombineDocumentsChain(Chain, ABC):
|
||||
docs = inputs[self.input_key]
|
||||
# Other keys are assumed to be needed for LLM prediction
|
||||
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
|
||||
output, extra_return_dict = self.combine_docs(
|
||||
docs, callbacks=_run_manager.get_child(), **other_keys
|
||||
doc_strings = [
|
||||
format_document(doc, self.document_prompt) for doc in docs
|
||||
]
|
||||
# Join the documents together to put them in the prompt.
|
||||
inputs = {
|
||||
k: v
|
||||
for k, v in other_keys.items()
|
||||
if k in self.llm_chain.prompt.input_variables
|
||||
}
|
||||
inputs[self.document_variable_name] = self.document_separator.join(
|
||||
doc_strings
|
||||
)
|
||||
|
||||
# Call predict on the LLM.
|
||||
output, extra_return_dict = (
|
||||
self.llm_chain(inputs, callbacks=_run_manager.get_child())[
|
||||
self.llm_chain.output_key
|
||||
],
|
||||
{},
|
||||
)
|
||||
|
||||
extra_return_dict[self.output_key] = output
|
||||
return extra_return_dict
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Generation(Serializable):
|
||||
"""Output of a single generation."""
|
||||
|
||||
text: str
|
||||
"""Generated text output."""
|
||||
|
||||
generation_info: Optional[Dict[str, Any]] = None
|
||||
"""Raw generation info response from the provider"""
|
||||
"""May include things like reason for finishing (e.g. in OpenAI)"""
|
||||
# TODO: add log probs
|
||||
|
||||
|
||||
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
|
||||
|
||||
|
||||
class LLMChain(Chain):
|
||||
"""Chain to run queries against LLMs.
|
||||
|
||||
@@ -153,21 +600,13 @@ class LLMChain(Chain):
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
response = self.generate([inputs], run_manager=run_manager)
|
||||
return self.create_outputs(response)[0]
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> LLMResult:
|
||||
"""Generate LLM result from inputs."""
|
||||
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
|
||||
return self.llm.generate_prompt(
|
||||
prompts, stop = self.prep_prompts([inputs], run_manager=run_manager)
|
||||
response = self.llm.generate_prompt(
|
||||
prompts,
|
||||
stop,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
)
|
||||
return self.create_outputs(response)[0]
|
||||
|
||||
def prep_prompts(
|
||||
self,
|
||||
@@ -223,23 +662,6 @@ class LLMChain(Chain):
|
||||
for generation in response.generations
|
||||
]
|
||||
|
||||
def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
|
||||
"""Format prompt with kwargs and pass to LLM.
|
||||
|
||||
Args:
|
||||
callbacks: Callbacks to pass to LLMChain
|
||||
**kwargs: Keys to pass to prompt template.
|
||||
|
||||
Returns:
|
||||
Completion from LLM.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
completion = llm.predict(adjective="funny")
|
||||
"""
|
||||
return self(kwargs, callbacks=callbacks)[self.output_key]
|
||||
|
||||
def predict_and_parse(
|
||||
self, callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> Union[str, List[str], Dict[str, Any]]:
|
||||
@@ -350,14 +772,6 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
|
||||
prompt = self.llm_chain.prompt.format(**inputs)
|
||||
return self.llm_chain.llm.get_num_tokens(prompt)
|
||||
|
||||
def combine_docs(
|
||||
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> Tuple[str, dict]:
|
||||
"""Stuff all documents into one prompt and pass to LLM."""
|
||||
inputs = self._get_inputs(docs, **kwargs)
|
||||
# Call predict on the LLM.
|
||||
return self.llm_chain.predict(callbacks=callbacks, **inputs), {}
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "stuff_documents_chain"
|
||||
|
||||
@@ -1129,7 +1129,7 @@ class Langchain:
|
||||
max_time=max_time,
|
||||
num_return_sequences=num_return_sequences,
|
||||
)
|
||||
outr, extra = run_qa_db(
|
||||
out = run_qa_db(
|
||||
query=instruction,
|
||||
iinput=iinput,
|
||||
context=context,
|
||||
@@ -1171,14 +1171,7 @@ class Langchain:
|
||||
max_chunks=max_chunks,
|
||||
device=self.device,
|
||||
)
|
||||
response = dict(response=outr, sources=extra)
|
||||
if outr or base_model in non_hf_types:
|
||||
# if got no response (e.g. not showing sources and got no sources,
|
||||
# so nothing to give to LLM), then slip through and ask LLM
|
||||
# Or if llama/gptj, then just return since they had no response and can't go down below code path
|
||||
# clear before return, since .then() never done if from API
|
||||
clear_torch_cache()
|
||||
return response
|
||||
return out
|
||||
|
||||
inputs_list_names = list(inspect.signature(evaluate).parameters)
|
||||
global inputs_kwargs_list
|
||||
|
||||
@@ -2554,22 +2554,7 @@ def _run_qa_db(
|
||||
)
|
||||
with context_class_cast(args.device):
|
||||
answer = chain()
|
||||
|
||||
if not use_context:
|
||||
ret = answer["output_text"]
|
||||
extra = ""
|
||||
return ret, extra
|
||||
elif answer is not None:
|
||||
ret, extra = get_sources_answer(
|
||||
query,
|
||||
answer,
|
||||
scores,
|
||||
show_rank,
|
||||
answer_with_sources,
|
||||
verbose=verbose,
|
||||
)
|
||||
return ret, extra
|
||||
return
|
||||
return answer
|
||||
|
||||
|
||||
def get_similarity_chain(
|
||||
|
||||
@@ -3,13 +3,11 @@ from apps.stable_diffusion.src.utils.utils import _compile_module
|
||||
from io import BytesIO
|
||||
import torch_mlir
|
||||
|
||||
from transformers import TextGenerationPipeline
|
||||
from transformers.pipelines.text_generation import ReturnType
|
||||
|
||||
from stopping import get_stopping
|
||||
from prompter import Prompter, PromptType
|
||||
|
||||
|
||||
from transformers import TextGenerationPipeline
|
||||
from transformers.pipelines.text_generation import ReturnType
|
||||
from transformers.generation import (
|
||||
GenerationConfig,
|
||||
LogitsProcessorList,
|
||||
@@ -22,7 +20,7 @@ import gc
|
||||
from pathlib import Path
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
# Brevitas
|
||||
@@ -31,14 +29,8 @@ from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡shape(
|
||||
lhs: List[int],
|
||||
rhs: List[int],
|
||||
rhs_scale: List[int],
|
||||
rhs_zero_point: List[int],
|
||||
rhs_bit_width: int,
|
||||
rhs_group_size: int,
|
||||
) -> List[int]:
|
||||
# fmt: off
|
||||
def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
|
||||
if len(lhs) == 3 and len(rhs) == 2:
|
||||
return [lhs[0], lhs[1], rhs[0]]
|
||||
elif len(lhs) == 2 and len(rhs) == 2:
|
||||
@@ -47,30 +39,21 @@ def brevitas〇matmul_rhs_group_quant〡shape(
|
||||
raise ValueError("Input shapes not supported.")
|
||||
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡dtype(
|
||||
lhs_rank_dtype: Tuple[int, int],
|
||||
rhs_rank_dtype: Tuple[int, int],
|
||||
rhs_scale_rank_dtype: Tuple[int, int],
|
||||
rhs_zero_point_rank_dtype: Tuple[int, int],
|
||||
rhs_bit_width: int,
|
||||
rhs_group_size: int,
|
||||
) -> int:
|
||||
def quant〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
|
||||
# output dtype is the dtype of the lhs float input
|
||||
lhs_rank, lhs_dtype = lhs_rank_dtype
|
||||
return lhs_dtype
|
||||
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡has_value_semantics(
|
||||
lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size
|
||||
) -> None:
|
||||
def quant〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
|
||||
return
|
||||
|
||||
|
||||
brevitas_matmul_rhs_group_quant_library = [
|
||||
brevitas〇matmul_rhs_group_quant〡shape,
|
||||
brevitas〇matmul_rhs_group_quant〡dtype,
|
||||
brevitas〇matmul_rhs_group_quant〡has_value_semantics,
|
||||
]
|
||||
quant〇matmul_rhs_group_quant〡shape,
|
||||
quant〇matmul_rhs_group_quant〡dtype,
|
||||
quant〇matmul_rhs_group_quant〡has_value_semantics]
|
||||
# fmt: on
|
||||
|
||||
global_device = "cuda"
|
||||
global_precision = "fp16"
|
||||
@@ -246,7 +229,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
|
||||
ts_graph,
|
||||
[*h2ogptCompileInput],
|
||||
output_type=torch_mlir.OutputType.TORCH,
|
||||
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
|
||||
backend_legal_ops=["quant.matmul_rhs_group_quant"],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
@@ -254,7 +237,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
module,
|
||||
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
else:
|
||||
@@ -273,6 +256,11 @@ class H2OGPTSHARKModel(torch.nn.Module):
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
del module
|
||||
|
||||
bytecode = save_mlir(
|
||||
bytecode,
|
||||
model_name=f"h2ogpt_{precision}",
|
||||
frontend="torch",
|
||||
)
|
||||
return bytecode
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
@@ -285,7 +273,215 @@ class H2OGPTSHARKModel(torch.nn.Module):
|
||||
return result
|
||||
|
||||
|
||||
h2ogpt_model = H2OGPTSHARKModel()
|
||||
def decode_tokens(tokenizer, res_tokens):
|
||||
for i in range(len(res_tokens)):
|
||||
if type(res_tokens[i]) != int:
|
||||
res_tokens[i] = int(res_tokens[i][0])
|
||||
|
||||
res_str = tokenizer.decode(res_tokens, skip_special_tokens=True)
|
||||
return res_str
|
||||
|
||||
|
||||
def generate_token(h2ogpt_shark_model, model, tokenizer, **generate_kwargs):
|
||||
del generate_kwargs["max_time"]
|
||||
generate_kwargs["input_ids"] = generate_kwargs["input_ids"].to(
|
||||
device=tensor_device
|
||||
)
|
||||
generate_kwargs["attention_mask"] = generate_kwargs["attention_mask"].to(
|
||||
device=tensor_device
|
||||
)
|
||||
truncated_input_ids = []
|
||||
stopping_criteria = generate_kwargs["stopping_criteria"]
|
||||
|
||||
generation_config_ = GenerationConfig.from_model_config(model.config)
|
||||
generation_config = copy.deepcopy(generation_config_)
|
||||
model_kwargs = generation_config.update(**generate_kwargs)
|
||||
|
||||
logits_processor = LogitsProcessorList()
|
||||
stopping_criteria = (
|
||||
stopping_criteria
|
||||
if stopping_criteria is not None
|
||||
else StoppingCriteriaList()
|
||||
)
|
||||
|
||||
eos_token_id = generation_config.eos_token_id
|
||||
generation_config.pad_token_id = eos_token_id
|
||||
|
||||
(
|
||||
inputs_tensor,
|
||||
model_input_name,
|
||||
model_kwargs,
|
||||
) = model._prepare_model_inputs(
|
||||
None, generation_config.bos_token_id, model_kwargs
|
||||
)
|
||||
|
||||
model_kwargs["output_attentions"] = generation_config.output_attentions
|
||||
model_kwargs[
|
||||
"output_hidden_states"
|
||||
] = generation_config.output_hidden_states
|
||||
model_kwargs["use_cache"] = generation_config.use_cache
|
||||
|
||||
input_ids = (
|
||||
inputs_tensor
|
||||
if model_input_name == "input_ids"
|
||||
else model_kwargs.pop("input_ids")
|
||||
)
|
||||
|
||||
input_ids_seq_length = input_ids.shape[-1]
|
||||
|
||||
generation_config.max_length = (
|
||||
generation_config.max_new_tokens + input_ids_seq_length
|
||||
)
|
||||
|
||||
logits_processor = model._get_logits_processor(
|
||||
generation_config=generation_config,
|
||||
input_ids_seq_length=input_ids_seq_length,
|
||||
encoder_input_ids=inputs_tensor,
|
||||
prefix_allowed_tokens_fn=None,
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
|
||||
stopping_criteria = model._get_stopping_criteria(
|
||||
generation_config=generation_config,
|
||||
stopping_criteria=stopping_criteria,
|
||||
)
|
||||
|
||||
logits_warper = model._get_logits_warper(generation_config)
|
||||
|
||||
(
|
||||
input_ids,
|
||||
model_kwargs,
|
||||
) = model._expand_inputs_for_generation(
|
||||
input_ids=input_ids,
|
||||
expand_size=generation_config.num_return_sequences, # 1
|
||||
is_encoder_decoder=model.config.is_encoder_decoder, # False
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
eos_token_id_tensor = (
|
||||
torch.tensor(eos_token_id).to(device=tensor_device)
|
||||
if eos_token_id is not None
|
||||
else None
|
||||
)
|
||||
|
||||
pad_token_id = generation_config.pad_token_id
|
||||
eos_token_id = eos_token_id
|
||||
|
||||
output_scores = generation_config.output_scores # False
|
||||
return_dict_in_generate = (
|
||||
generation_config.return_dict_in_generate # False
|
||||
)
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
scores = () if (return_dict_in_generate and output_scores) else None
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
unfinished_sequences = torch.ones(
|
||||
input_ids.shape[0],
|
||||
dtype=torch.long,
|
||||
device=input_ids.device,
|
||||
)
|
||||
|
||||
timesRan = 0
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
print("\n")
|
||||
|
||||
res_tokens = []
|
||||
while True:
|
||||
model_inputs = model.prepare_inputs_for_generation(
|
||||
input_ids, **model_kwargs
|
||||
)
|
||||
|
||||
outputs = h2ogpt_shark_model.forward(
|
||||
model_inputs["input_ids"], model_inputs["attention_mask"]
|
||||
)
|
||||
|
||||
if args.precision == "fp16":
|
||||
outputs = outputs.to(dtype=torch.float32)
|
||||
next_token_logits = outputs
|
||||
|
||||
# pre-process distribution
|
||||
next_token_scores = logits_processor(input_ids, next_token_logits)
|
||||
next_token_scores = logits_warper(input_ids, next_token_scores)
|
||||
|
||||
# sample
|
||||
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
|
||||
|
||||
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
|
||||
# finished sentences should have their next token be a padding token
|
||||
if eos_token_id is not None:
|
||||
if pad_token_id is None:
|
||||
raise ValueError(
|
||||
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
|
||||
)
|
||||
next_token = next_token * unfinished_sequences + pad_token_id * (
|
||||
1 - unfinished_sequences
|
||||
)
|
||||
|
||||
input_ids = torch.cat([input_ids, next_token[:, None]], dim=-1)
|
||||
|
||||
model_kwargs["past_key_values"] = None
|
||||
if "attention_mask" in model_kwargs:
|
||||
attention_mask = model_kwargs["attention_mask"]
|
||||
model_kwargs["attention_mask"] = torch.cat(
|
||||
[
|
||||
attention_mask,
|
||||
attention_mask.new_ones((attention_mask.shape[0], 1)),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
truncated_input_ids.append(input_ids[:, 0])
|
||||
input_ids = input_ids[:, 1:]
|
||||
model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, 1:]
|
||||
|
||||
new_word = tokenizer.decode(
|
||||
next_token.cpu().numpy(),
|
||||
add_special_tokens=False,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
|
||||
res_tokens.append(next_token)
|
||||
if new_word == "<0x0A>":
|
||||
print("\n", end="", flush=True)
|
||||
else:
|
||||
print(f"{new_word}", end=" ", flush=True)
|
||||
|
||||
part_str = decode_tokens(tokenizer, res_tokens)
|
||||
yield part_str
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if eos_token_id_tensor is not None:
|
||||
unfinished_sequences = unfinished_sequences.mul(
|
||||
next_token.tile(eos_token_id_tensor.shape[0], 1)
|
||||
.ne(eos_token_id_tensor.unsqueeze(1))
|
||||
.prod(dim=0)
|
||||
)
|
||||
# stop when each sentence is finished
|
||||
if unfinished_sequences.max() == 0 or stopping_criteria(
|
||||
input_ids, scores
|
||||
):
|
||||
break
|
||||
timesRan = timesRan + 1
|
||||
|
||||
end = time.time()
|
||||
print(
|
||||
"\n\nTime taken is {:.2f} seconds/token\n".format(
|
||||
(end - start) / timesRan
|
||||
)
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
res_str = decode_tokens(tokenizer, res_tokens)
|
||||
yield res_str
|
||||
|
||||
|
||||
def pad_or_truncate_inputs(
|
||||
@@ -498,233 +694,6 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
||||
)
|
||||
return records
|
||||
|
||||
def generate_new_token(self):
|
||||
model_inputs = self.model.prepare_inputs_for_generation(
|
||||
self.input_ids, **self.model_kwargs
|
||||
)
|
||||
|
||||
outputs = h2ogpt_model.forward(
|
||||
model_inputs["input_ids"], model_inputs["attention_mask"]
|
||||
)
|
||||
|
||||
if args.precision == "fp16":
|
||||
outputs = outputs.to(dtype=torch.float32)
|
||||
next_token_logits = outputs
|
||||
|
||||
# pre-process distribution
|
||||
next_token_scores = self.logits_processor(
|
||||
self.input_ids, next_token_logits
|
||||
)
|
||||
next_token_scores = self.logits_warper(
|
||||
self.input_ids, next_token_scores
|
||||
)
|
||||
|
||||
# sample
|
||||
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
|
||||
|
||||
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
|
||||
# finished sentences should have their next token be a padding token
|
||||
if self.eos_token_id is not None:
|
||||
if self.pad_token_id is None:
|
||||
raise ValueError(
|
||||
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
|
||||
)
|
||||
next_token = (
|
||||
next_token * self.unfinished_sequences
|
||||
+ self.pad_token_id * (1 - self.unfinished_sequences)
|
||||
)
|
||||
|
||||
self.input_ids = torch.cat(
|
||||
[self.input_ids, next_token[:, None]], dim=-1
|
||||
)
|
||||
|
||||
self.model_kwargs["past_key_values"] = None
|
||||
if "attention_mask" in self.model_kwargs:
|
||||
attention_mask = self.model_kwargs["attention_mask"]
|
||||
self.model_kwargs["attention_mask"] = torch.cat(
|
||||
[
|
||||
attention_mask,
|
||||
attention_mask.new_ones((attention_mask.shape[0], 1)),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
self.truncated_input_ids.append(self.input_ids[:, 0])
|
||||
self.input_ids = self.input_ids[:, 1:]
|
||||
self.model_kwargs["attention_mask"] = self.model_kwargs[
|
||||
"attention_mask"
|
||||
][:, 1:]
|
||||
|
||||
return next_token
|
||||
|
||||
def generate_token(self, **generate_kwargs):
|
||||
del generate_kwargs["max_time"]
|
||||
self.truncated_input_ids = []
|
||||
|
||||
generation_config_ = GenerationConfig.from_model_config(
|
||||
self.model.config
|
||||
)
|
||||
generation_config = copy.deepcopy(generation_config_)
|
||||
self.model_kwargs = generation_config.update(**generate_kwargs)
|
||||
|
||||
logits_processor = LogitsProcessorList()
|
||||
self.stopping_criteria = (
|
||||
self.stopping_criteria
|
||||
if self.stopping_criteria is not None
|
||||
else StoppingCriteriaList()
|
||||
)
|
||||
|
||||
eos_token_id = generation_config.eos_token_id
|
||||
generation_config.pad_token_id = eos_token_id
|
||||
|
||||
(
|
||||
inputs_tensor,
|
||||
model_input_name,
|
||||
self.model_kwargs,
|
||||
) = self.model._prepare_model_inputs(
|
||||
None, generation_config.bos_token_id, self.model_kwargs
|
||||
)
|
||||
batch_size = inputs_tensor.shape[0]
|
||||
|
||||
self.model_kwargs[
|
||||
"output_attentions"
|
||||
] = generation_config.output_attentions
|
||||
self.model_kwargs[
|
||||
"output_hidden_states"
|
||||
] = generation_config.output_hidden_states
|
||||
self.model_kwargs["use_cache"] = generation_config.use_cache
|
||||
|
||||
self.input_ids = (
|
||||
inputs_tensor
|
||||
if model_input_name == "input_ids"
|
||||
else self.model_kwargs.pop("input_ids")
|
||||
)
|
||||
|
||||
input_ids_seq_length = self.input_ids.shape[-1]
|
||||
|
||||
generation_config.max_length = (
|
||||
generation_config.max_new_tokens + input_ids_seq_length
|
||||
)
|
||||
|
||||
self.logits_processor = self.model._get_logits_processor(
|
||||
generation_config=generation_config,
|
||||
input_ids_seq_length=input_ids_seq_length,
|
||||
encoder_input_ids=inputs_tensor,
|
||||
prefix_allowed_tokens_fn=None,
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
|
||||
self.stopping_criteria = self.model._get_stopping_criteria(
|
||||
generation_config=generation_config,
|
||||
stopping_criteria=self.stopping_criteria,
|
||||
)
|
||||
|
||||
self.logits_warper = self.model._get_logits_warper(generation_config)
|
||||
|
||||
(
|
||||
self.input_ids,
|
||||
self.model_kwargs,
|
||||
) = self.model._expand_inputs_for_generation(
|
||||
input_ids=self.input_ids,
|
||||
expand_size=generation_config.num_return_sequences, # 1
|
||||
is_encoder_decoder=self.model.config.is_encoder_decoder, # False
|
||||
**self.model_kwargs,
|
||||
)
|
||||
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
self.eos_token_id_tensor = (
|
||||
torch.tensor(eos_token_id).to(device=tensor_device)
|
||||
if eos_token_id is not None
|
||||
else None
|
||||
)
|
||||
|
||||
self.pad_token_id = generation_config.pad_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
output_scores = generation_config.output_scores # False
|
||||
output_attentions = generation_config.output_attentions # False
|
||||
output_hidden_states = generation_config.output_hidden_states # False
|
||||
return_dict_in_generate = (
|
||||
generation_config.return_dict_in_generate # False
|
||||
)
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
self.scores = (
|
||||
() if (return_dict_in_generate and output_scores) else None
|
||||
)
|
||||
decoder_attentions = (
|
||||
() if (return_dict_in_generate and output_attentions) else None
|
||||
)
|
||||
cross_attentions = (
|
||||
() if (return_dict_in_generate and output_attentions) else None
|
||||
)
|
||||
decoder_hidden_states = (
|
||||
() if (return_dict_in_generate and output_hidden_states) else None
|
||||
)
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
self.unfinished_sequences = torch.ones(
|
||||
self.input_ids.shape[0],
|
||||
dtype=torch.long,
|
||||
device=self.input_ids.device,
|
||||
)
|
||||
|
||||
timesRan = 0
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
print("\n")
|
||||
|
||||
while True:
|
||||
next_token = self.generate_new_token()
|
||||
new_word = self.tokenizer.decode(
|
||||
next_token.cpu().numpy(),
|
||||
add_special_tokens=False,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
|
||||
print(f"{new_word}", end="", flush=True)
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if self.eos_token_id_tensor is not None:
|
||||
self.unfinished_sequences = self.unfinished_sequences.mul(
|
||||
next_token.tile(self.eos_token_id_tensor.shape[0], 1)
|
||||
.ne(self.eos_token_id_tensor.unsqueeze(1))
|
||||
.prod(dim=0)
|
||||
)
|
||||
# stop when each sentence is finished
|
||||
if (
|
||||
self.unfinished_sequences.max() == 0
|
||||
or self.stopping_criteria(self.input_ids, self.scores)
|
||||
):
|
||||
break
|
||||
timesRan = timesRan + 1
|
||||
|
||||
end = time.time()
|
||||
print(
|
||||
"\n\nTime taken is {:.2f} seconds/token\n".format(
|
||||
(end - start) / timesRan
|
||||
)
|
||||
)
|
||||
|
||||
self.input_ids = torch.cat(
|
||||
[
|
||||
torch.tensor(self.truncated_input_ids)
|
||||
.to(device=tensor_device)
|
||||
.unsqueeze(dim=0),
|
||||
self.input_ids,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
return self.input_ids
|
||||
|
||||
def _forward(self, model_inputs, **generate_kwargs):
|
||||
if self.can_stop:
|
||||
stopping_criteria = get_stopping(
|
||||
@@ -784,19 +753,13 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
||||
input_ids, attention_mask = pad_or_truncate_inputs(
|
||||
input_ids, attention_mask, max_padding_length=max_padding_length
|
||||
)
|
||||
self.stopping_criteria = generate_kwargs["stopping_criteria"]
|
||||
|
||||
generated_sequence = self.generate_token(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
**generate_kwargs,
|
||||
)
|
||||
out_b = generated_sequence.shape[0]
|
||||
generated_sequence = generated_sequence.reshape(
|
||||
in_b, out_b // in_b, *generated_sequence.shape[1:]
|
||||
)
|
||||
return {
|
||||
"generated_sequence": generated_sequence,
|
||||
return_dict = {
|
||||
"model": self.model,
|
||||
"tokenizer": self.tokenizer,
|
||||
"input_ids": input_ids,
|
||||
"prompt_text": prompt_text,
|
||||
"attention_mask": attention_mask,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
return_dict = {**return_dict, **generate_kwargs}
|
||||
return return_dict
|
||||
|
||||
@@ -65,8 +65,8 @@ tiktoken==0.4.0
|
||||
openai==0.27.8
|
||||
|
||||
# optional for chat with PDF
|
||||
langchain==0.0.202
|
||||
pypdf==3.12.2
|
||||
langchain==0.0.329
|
||||
pypdf==3.17.0
|
||||
# avoid textract, requires old six
|
||||
#textract==1.6.5
|
||||
|
||||
|
||||
442
apps/language_models/scripts/llama_ir_conversion_utils.py
Normal file
442
apps/language_models/scripts/llama_ir_conversion_utils.py
Normal file
@@ -0,0 +1,442 @@
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
from argparse import RawTextHelpFormatter
|
||||
import re, gc
|
||||
|
||||
"""
|
||||
This script can be used as a standalone utility to convert IRs to dynamic + combine them.
|
||||
Following are the various ways this script can be used :-
|
||||
a. To convert a single Linalg IR to dynamic IR:
|
||||
--dynamic --first_ir_path=<PATH TO FIRST IR>
|
||||
b. To convert two Linalg IRs to dynamic IR:
|
||||
--dynamic --first_ir_path=<PATH TO SECOND IR> --first_ir_path=<PATH TO SECOND IR>
|
||||
c. To combine two Linalg IRs into one:
|
||||
--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
|
||||
d. To convert both IRs into dynamic as well as combine the IRs:
|
||||
--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
|
||||
|
||||
NOTE: For dynamic you'll also need to provide the following set of flags:-
|
||||
i. For First Llama : --dynamic_input_size (DEFAULT: 19)
|
||||
ii. For Second Llama: --model_name (DEFAULT: llama2_7b)
|
||||
--precision (DEFAULT: 'int4')
|
||||
You may use --save_dynamic to also save the dynamic IR in option d above.
|
||||
Else for option a. and b. the dynamic IR(s) will get saved by default.
|
||||
"""
|
||||
|
||||
|
||||
def combine_mlir_scripts(
|
||||
first_vicuna_mlir,
|
||||
second_vicuna_mlir,
|
||||
output_name,
|
||||
return_ir=True,
|
||||
):
|
||||
print(f"[DEBUG] combining first and second mlir")
|
||||
print(f"[DEBUG] output_name = {output_name}")
|
||||
maps1 = []
|
||||
maps2 = []
|
||||
constants = set()
|
||||
f1 = []
|
||||
f2 = []
|
||||
|
||||
print(f"[DEBUG] processing first vicuna mlir")
|
||||
first_vicuna_mlir = first_vicuna_mlir.splitlines()
|
||||
while first_vicuna_mlir:
|
||||
line = first_vicuna_mlir.pop(0)
|
||||
if re.search("#map\d*\s*=", line):
|
||||
maps1.append(line)
|
||||
elif re.search("arith.constant", line):
|
||||
constants.add(line)
|
||||
elif not re.search("module", line):
|
||||
line = re.sub("forward", "first_vicuna_forward", line)
|
||||
f1.append(line)
|
||||
f1 = f1[:-1]
|
||||
del first_vicuna_mlir
|
||||
gc.collect()
|
||||
|
||||
for i, map_line in enumerate(maps1):
|
||||
map_var = map_line.split(" ")[0]
|
||||
map_line = re.sub(f"{map_var}(?!\d)", map_var + "_0", map_line)
|
||||
maps1[i] = map_line
|
||||
f1 = [
|
||||
re.sub(f"{map_var}(?!\d)", map_var + "_0", func_line)
|
||||
for func_line in f1
|
||||
]
|
||||
|
||||
print(f"[DEBUG] processing second vicuna mlir")
|
||||
second_vicuna_mlir = second_vicuna_mlir.splitlines()
|
||||
while second_vicuna_mlir:
|
||||
line = second_vicuna_mlir.pop(0)
|
||||
if re.search("#map\d*\s*=", line):
|
||||
maps2.append(line)
|
||||
elif "global_seed" in line:
|
||||
continue
|
||||
elif re.search("arith.constant", line):
|
||||
constants.add(line)
|
||||
elif not re.search("module", line):
|
||||
line = re.sub("forward", "second_vicuna_forward", line)
|
||||
f2.append(line)
|
||||
f2 = f2[:-1]
|
||||
del second_vicuna_mlir
|
||||
gc.collect()
|
||||
|
||||
for i, map_line in enumerate(maps2):
|
||||
map_var = map_line.split(" ")[0]
|
||||
map_line = re.sub(f"{map_var}(?!\d)", map_var + "_1", map_line)
|
||||
maps2[i] = map_line
|
||||
f2 = [
|
||||
re.sub(f"{map_var}(?!\d)", map_var + "_1", func_line)
|
||||
for func_line in f2
|
||||
]
|
||||
|
||||
module_start = 'module attributes {torch.debug_module_name = "_lambda"} {'
|
||||
module_end = "}"
|
||||
|
||||
global_vars = []
|
||||
vnames = []
|
||||
global_var_loading1 = []
|
||||
global_var_loading2 = []
|
||||
|
||||
print(f"[DEBUG] processing constants")
|
||||
counter = 0
|
||||
constants = list(constants)
|
||||
while constants:
|
||||
constant = constants.pop(0)
|
||||
vname, vbody = constant.split("=")
|
||||
vname = re.sub("%", "", vname)
|
||||
vname = vname.strip()
|
||||
vbody = re.sub("arith.constant", "", vbody)
|
||||
vbody = vbody.strip()
|
||||
if len(vbody.split(":")) < 2:
|
||||
print(constant)
|
||||
vdtype = vbody.split(":")[-1].strip()
|
||||
fixed_vdtype = vdtype
|
||||
if "c1_i64" in vname:
|
||||
print(constant)
|
||||
counter += 1
|
||||
if counter == 2:
|
||||
counter = 0
|
||||
print("detected duplicate")
|
||||
continue
|
||||
vnames.append(vname)
|
||||
if "true" not in vname:
|
||||
global_vars.append(
|
||||
f"ml_program.global private @{vname}({vbody}) : {fixed_vdtype}"
|
||||
)
|
||||
global_var_loading1.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
|
||||
)
|
||||
global_var_loading2.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
|
||||
)
|
||||
else:
|
||||
global_vars.append(
|
||||
f"ml_program.global private @{vname}({vbody}) : i1"
|
||||
)
|
||||
global_var_loading1.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
|
||||
)
|
||||
global_var_loading2.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
|
||||
)
|
||||
|
||||
new_f1, new_f2 = [], []
|
||||
|
||||
print(f"[DEBUG] processing f1")
|
||||
for line in f1:
|
||||
if "func.func" in line:
|
||||
new_f1.append(line)
|
||||
for global_var in global_var_loading1:
|
||||
new_f1.append(global_var)
|
||||
else:
|
||||
new_f1.append(line)
|
||||
|
||||
print(f"[DEBUG] processing f2")
|
||||
for line in f2:
|
||||
if "func.func" in line:
|
||||
new_f2.append(line)
|
||||
for global_var in global_var_loading2:
|
||||
if (
|
||||
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
|
||||
in global_var
|
||||
):
|
||||
print(global_var)
|
||||
new_f2.append(global_var)
|
||||
else:
|
||||
new_f2.append(line)
|
||||
|
||||
f1 = new_f1
|
||||
f2 = new_f2
|
||||
|
||||
del new_f1
|
||||
del new_f2
|
||||
gc.collect()
|
||||
|
||||
print(
|
||||
[
|
||||
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in x
|
||||
for x in [maps1, maps2, global_vars, f1, f2]
|
||||
]
|
||||
)
|
||||
|
||||
# doing it this way rather than assembling the whole string
|
||||
# to prevent OOM with 64GiB RAM when encoding the file.
|
||||
|
||||
print(f"[DEBUG] Saving mlir to {output_name}")
|
||||
with open(output_name, "w+") as f_:
|
||||
f_.writelines(line + "\n" for line in maps1)
|
||||
f_.writelines(line + "\n" for line in maps2)
|
||||
f_.writelines(line + "\n" for line in [module_start])
|
||||
f_.writelines(line + "\n" for line in global_vars)
|
||||
f_.writelines(line + "\n" for line in f1)
|
||||
f_.writelines(line + "\n" for line in f2)
|
||||
f_.writelines(line + "\n" for line in [module_end])
|
||||
|
||||
del maps1
|
||||
del maps2
|
||||
del module_start
|
||||
del global_vars
|
||||
del f1
|
||||
del f2
|
||||
del module_end
|
||||
gc.collect()
|
||||
|
||||
if return_ir:
|
||||
print(f"[DEBUG] Reading combined mlir back in")
|
||||
with open(output_name, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def write_in_dynamic_inputs0(module, dynamic_input_size):
|
||||
print("[DEBUG] writing dynamic inputs to first vicuna")
|
||||
# Current solution for ensuring mlir files support dynamic inputs
|
||||
# TODO: find a more elegant way to implement this
|
||||
new_lines = []
|
||||
module = module.splitlines()
|
||||
while module:
|
||||
line = module.pop(0)
|
||||
line = re.sub(f"{dynamic_input_size}x", "?x", line)
|
||||
if "?x" in line:
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
|
||||
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub(f"c{dynamic_input_size}", "dim", line)
|
||||
if "%0 = tensor.empty(%dim) : tensor<?xi64>" in line:
|
||||
new_lines.append("%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>")
|
||||
if "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" in line:
|
||||
continue
|
||||
|
||||
new_lines.append(line)
|
||||
return "\n".join(new_lines)
|
||||
|
||||
|
||||
def write_in_dynamic_inputs1(module, model_name, precision):
|
||||
print("[DEBUG] writing dynamic inputs to second vicuna")
|
||||
|
||||
def remove_constant_dim(line):
|
||||
if "c19_i64" in line:
|
||||
line = re.sub("c19_i64", "dim_i64", line)
|
||||
if "19x" in line:
|
||||
line = re.sub("19x", "?x", line)
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)",
|
||||
"tensor.empty(%dim, %dim)",
|
||||
line,
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub("c19", "dim", line)
|
||||
if " 19," in line:
|
||||
line = re.sub(" 19,", " %dim,", line)
|
||||
if "x20x" in line or "<20x" in line:
|
||||
line = re.sub("20x", "?x", line)
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dimp1)", line)
|
||||
if " 20," in line:
|
||||
line = re.sub(" 20,", " %dimp1,", line)
|
||||
return line
|
||||
|
||||
module = module.splitlines()
|
||||
new_lines = []
|
||||
|
||||
# Using a while loop and the pop method to avoid creating a copy of module
|
||||
if "llama2_13b" in model_name:
|
||||
pkv_tensor_shape = "tensor<1x40x?x128x"
|
||||
elif "llama2_70b" in model_name:
|
||||
pkv_tensor_shape = "tensor<1x8x?x128x"
|
||||
else:
|
||||
pkv_tensor_shape = "tensor<1x32x?x128x"
|
||||
if precision in ["fp16", "int4", "int8"]:
|
||||
pkv_tensor_shape += "f16>"
|
||||
else:
|
||||
pkv_tensor_shape += "f32>"
|
||||
|
||||
while module:
|
||||
line = module.pop(0)
|
||||
if "%c19_i64 = arith.constant 19 : i64" in line:
|
||||
new_lines.append("%c2 = arith.constant 2 : index")
|
||||
new_lines.append(
|
||||
f"%dim_4_int = tensor.dim %arg1, %c2 : {pkv_tensor_shape}"
|
||||
)
|
||||
new_lines.append(
|
||||
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
|
||||
)
|
||||
continue
|
||||
if "%c2 = arith.constant 2 : index" in line:
|
||||
continue
|
||||
if "%c20_i64 = arith.constant 20 : i64" in line:
|
||||
new_lines.append("%c1_i64 = arith.constant 1 : i64")
|
||||
new_lines.append("%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64")
|
||||
new_lines.append(
|
||||
"%dimp1 = arith.index_cast %c20_i64 : i64 to index"
|
||||
)
|
||||
continue
|
||||
line = remove_constant_dim(line)
|
||||
new_lines.append(line)
|
||||
|
||||
return "\n".join(new_lines)
|
||||
|
||||
|
||||
def save_dynamic_ir(ir_to_save, output_file):
|
||||
if not ir_to_save:
|
||||
return
|
||||
# We only get string output from the dynamic conversion utility.
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
with open(output_file, "w") as f:
|
||||
with redirect_stdout(f):
|
||||
print(ir_to_save)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="llama ir utility",
|
||||
description="\tThis script can be used as a standalone utility to convert IRs to dynamic + combine them.\n"
|
||||
+ "\tFollowing are the various ways this script can be used :-\n"
|
||||
+ "\t\ta. To convert a single Linalg IR to dynamic IR:\n"
|
||||
+ "\t\t\t--dynamic --first_ir_path=<PATH TO FIRST IR>\n"
|
||||
+ "\t\tb. To convert two Linalg IRs to dynamic IR:\n"
|
||||
+ "\t\t\t--dynamic --first_ir_path=<PATH TO SECOND IR> --first_ir_path=<PATH TO SECOND IR>\n"
|
||||
+ "\t\tc. To combine two Linalg IRs into one:\n"
|
||||
+ "\t\t\t--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n"
|
||||
+ "\t\td. To convert both IRs into dynamic as well as combine the IRs:\n"
|
||||
+ "\t\t\t--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n\n"
|
||||
+ "\tNOTE: For dynamic you'll also need to provide the following set of flags:-\n"
|
||||
+ "\t\t i. For First Llama : --dynamic_input_size (DEFAULT: 19)\n"
|
||||
+ "\t\tii. For Second Llama: --model_name (DEFAULT: llama2_7b)\n"
|
||||
+ "\t\t\t--precision (DEFAULT: 'int4')\n"
|
||||
+ "\t You may use --save_dynamic to also save the dynamic IR in option d above.\n"
|
||||
+ "\t Else for option a. and b. the dynamic IR(s) will get saved by default.\n",
|
||||
formatter_class=RawTextHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
"-p",
|
||||
default="int4",
|
||||
choices=["fp32", "fp16", "int8", "int4"],
|
||||
help="Precision of the concerned IR",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default="llama2_7b",
|
||||
choices=["vicuna", "llama2_7b", "llama2_13b", "llama2_70b"],
|
||||
help="Specify which model to run.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--first_ir_path",
|
||||
default=None,
|
||||
help="path to first llama mlir file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--second_ir_path",
|
||||
default=None,
|
||||
help="path to second llama mlir file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dynamic_input_size",
|
||||
type=int,
|
||||
default=19,
|
||||
help="Specify the static input size to replace with dynamic dim.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dynamic",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Converts the IR(s) to dynamic",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_dynamic",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Save the individual IR(s) after converting to dynamic",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--combine",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Converts the IR(s) to dynamic",
|
||||
)
|
||||
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
dynamic = args.dynamic
|
||||
combine = args.combine
|
||||
assert (
|
||||
dynamic or combine
|
||||
), "neither `dynamic` nor `combine` flag is turned on"
|
||||
first_ir_path = args.first_ir_path
|
||||
second_ir_path = args.second_ir_path
|
||||
assert first_ir_path or second_ir_path, "no input ir has been provided"
|
||||
if combine:
|
||||
assert (
|
||||
first_ir_path and second_ir_path
|
||||
), "you will need to provide both IRs to combine"
|
||||
precision = args.precision
|
||||
model_name = args.model_name
|
||||
dynamic_input_size = args.dynamic_input_size
|
||||
save_dynamic = args.save_dynamic
|
||||
|
||||
print(f"Dynamic conversion utility is turned {'ON' if dynamic else 'OFF'}")
|
||||
print(f"Combining IR utility is turned {'ON' if combine else 'OFF'}")
|
||||
|
||||
if dynamic and not combine:
|
||||
save_dynamic = True
|
||||
|
||||
first_ir = None
|
||||
first_dynamic_ir_name = None
|
||||
second_ir = None
|
||||
second_dynamic_ir_name = None
|
||||
if first_ir_path:
|
||||
first_dynamic_ir_name = f"{Path(first_ir_path).stem}_dynamic"
|
||||
with open(first_ir_path, "r") as f:
|
||||
first_ir = f.read()
|
||||
if second_ir_path:
|
||||
second_dynamic_ir_name = f"{Path(second_ir_path).stem}_dynamic"
|
||||
with open(second_ir_path, "r") as f:
|
||||
second_ir = f.read()
|
||||
if dynamic:
|
||||
first_ir = (
|
||||
write_in_dynamic_inputs0(first_ir, dynamic_input_size)
|
||||
if first_ir
|
||||
else None
|
||||
)
|
||||
second_ir = (
|
||||
write_in_dynamic_inputs1(second_ir, model_name, precision)
|
||||
if second_ir
|
||||
else None
|
||||
)
|
||||
if save_dynamic:
|
||||
save_dynamic_ir(first_ir, f"{first_dynamic_ir_name}.mlir")
|
||||
save_dynamic_ir(second_ir, f"{second_dynamic_ir_name}.mlir")
|
||||
|
||||
if combine:
|
||||
combine_mlir_scripts(
|
||||
first_ir,
|
||||
second_ir,
|
||||
f"{model_name}_{precision}.mlir",
|
||||
return_ir=False,
|
||||
)
|
||||
@@ -46,6 +46,7 @@ def compile_stableLM(
|
||||
model_vmfb_name,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
debug=False,
|
||||
):
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
@@ -92,7 +93,7 @@ def compile_stableLM(
|
||||
shark_module.compile()
|
||||
|
||||
path = shark_module.save_module(
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem, debug=debug
|
||||
)
|
||||
print("Saved vmfb at ", str(path))
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
94
apps/language_models/shark_llama_cli.spec
Normal file
94
apps/language_models/shark_llama_cli.spec
Normal file
@@ -0,0 +1,94 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
from PyInstaller.utils.hooks import collect_data_files
|
||||
from PyInstaller.utils.hooks import collect_submodules
|
||||
from PyInstaller.utils.hooks import copy_metadata
|
||||
|
||||
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
|
||||
|
||||
datas = []
|
||||
datas += collect_data_files('torch')
|
||||
datas += copy_metadata('torch')
|
||||
datas += copy_metadata('tqdm')
|
||||
datas += copy_metadata('regex')
|
||||
datas += copy_metadata('requests')
|
||||
datas += copy_metadata('packaging')
|
||||
datas += copy_metadata('filelock')
|
||||
datas += copy_metadata('numpy')
|
||||
datas += copy_metadata('tokenizers')
|
||||
datas += copy_metadata('importlib_metadata')
|
||||
datas += copy_metadata('torch-mlir')
|
||||
datas += copy_metadata('omegaconf')
|
||||
datas += copy_metadata('safetensors')
|
||||
datas += copy_metadata('huggingface-hub')
|
||||
datas += copy_metadata('sentencepiece')
|
||||
datas += copy_metadata("pyyaml")
|
||||
datas += collect_data_files("tokenizers")
|
||||
datas += collect_data_files("tiktoken")
|
||||
datas += collect_data_files("accelerate")
|
||||
datas += collect_data_files('diffusers')
|
||||
datas += collect_data_files('transformers')
|
||||
datas += collect_data_files('opencv-python')
|
||||
datas += collect_data_files('pytorch_lightning')
|
||||
datas += collect_data_files('skimage')
|
||||
datas += collect_data_files('gradio')
|
||||
datas += collect_data_files('gradio_client')
|
||||
datas += collect_data_files('iree')
|
||||
datas += collect_data_files('google-cloud-storage')
|
||||
datas += collect_data_files('py-cpuinfo')
|
||||
datas += collect_data_files("shark", include_py_files=True)
|
||||
datas += collect_data_files("timm", include_py_files=True)
|
||||
datas += collect_data_files("tqdm")
|
||||
datas += collect_data_files("tkinter")
|
||||
datas += collect_data_files("webview")
|
||||
datas += collect_data_files("sentencepiece")
|
||||
datas += collect_data_files("jsonschema")
|
||||
datas += collect_data_files("jsonschema_specifications")
|
||||
datas += collect_data_files("cpuinfo")
|
||||
datas += collect_data_files("langchain")
|
||||
|
||||
binaries = []
|
||||
|
||||
block_cipher = None
|
||||
|
||||
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
|
||||
|
||||
a = Analysis(
|
||||
['scripts/vicuna.py'],
|
||||
pathex=['.'],
|
||||
binaries=binaries,
|
||||
datas=datas,
|
||||
hiddenimports=hiddenimports,
|
||||
hookspath=[],
|
||||
hooksconfig={},
|
||||
runtime_hooks=[],
|
||||
excludes=[],
|
||||
win_no_prefer_redirects=False,
|
||||
win_private_assemblies=False,
|
||||
cipher=block_cipher,
|
||||
noarchive=False,
|
||||
)
|
||||
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||
|
||||
exe = EXE(
|
||||
pyz,
|
||||
a.scripts,
|
||||
a.binaries,
|
||||
a.zipfiles,
|
||||
a.datas,
|
||||
[],
|
||||
name='shark_llama_cli',
|
||||
debug=False,
|
||||
bootloader_ignore_signals=False,
|
||||
strip=False,
|
||||
upx=True,
|
||||
upx_exclude=[],
|
||||
runtime_tmpdir=None,
|
||||
console=True,
|
||||
disable_windowed_traceback=False,
|
||||
argv_emulation=False,
|
||||
target_arch=None,
|
||||
codesign_identity=None,
|
||||
entitlements_file=None,
|
||||
)
|
||||
598
apps/language_models/src/model_wrappers/falcon_sharded_model.py
Normal file
598
apps/language_models/src/model_wrappers/falcon_sharded_model.py
Normal file
@@ -0,0 +1,598 @@
|
||||
import torch
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
class WordEmbeddingsLayer(torch.nn.Module):
|
||||
def __init__(self, word_embedding_layer):
|
||||
super().__init__()
|
||||
self.model = word_embedding_layer
|
||||
|
||||
def forward(self, input_ids):
|
||||
output = self.model.forward(input=input_ids)
|
||||
return output
|
||||
|
||||
|
||||
class CompiledWordEmbeddingsLayer(torch.nn.Module):
|
||||
def __init__(self, compiled_word_embedding_layer):
|
||||
super().__init__()
|
||||
self.model = compiled_word_embedding_layer
|
||||
|
||||
def forward(self, input_ids):
|
||||
input_ids = input_ids.detach().numpy()
|
||||
new_input_ids = self.model("forward", input_ids)
|
||||
new_input_ids = new_input_ids.reshape(
|
||||
[1, new_input_ids.shape[0], new_input_ids.shape[1]]
|
||||
)
|
||||
return torch.tensor(new_input_ids)
|
||||
|
||||
|
||||
class LNFEmbeddingLayer(torch.nn.Module):
|
||||
def __init__(self, ln_f):
|
||||
super().__init__()
|
||||
self.model = ln_f
|
||||
|
||||
def forward(self, hidden_states):
|
||||
output = self.model.forward(input=hidden_states)
|
||||
return output
|
||||
|
||||
|
||||
class CompiledLNFEmbeddingLayer(torch.nn.Module):
|
||||
def __init__(self, ln_f):
|
||||
super().__init__()
|
||||
self.model = ln_f
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = hidden_states.detach().numpy()
|
||||
new_hidden_states = self.model("forward", (hidden_states,))
|
||||
|
||||
return torch.tensor(new_hidden_states)
|
||||
|
||||
|
||||
class LMHeadEmbeddingLayer(torch.nn.Module):
|
||||
def __init__(self, embedding_layer):
|
||||
super().__init__()
|
||||
self.model = embedding_layer
|
||||
|
||||
def forward(self, hidden_states):
|
||||
output = self.model.forward(input=hidden_states)
|
||||
return output
|
||||
|
||||
|
||||
class CompiledLMHeadEmbeddingLayer(torch.nn.Module):
|
||||
def __init__(self, lm_head):
|
||||
super().__init__()
|
||||
self.model = lm_head
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = hidden_states.detach().numpy()
|
||||
new_hidden_states = self.model("forward", (hidden_states,))
|
||||
return torch.tensor(new_hidden_states)
|
||||
|
||||
|
||||
class DecoderLayer(torch.nn.Module):
|
||||
def __init__(self, decoder_layer_model, falcon_variant):
|
||||
super().__init__()
|
||||
self.model = decoder_layer_model
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
output = self.model.forward(
|
||||
hidden_states=hidden_states,
|
||||
alibi=None,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=True,
|
||||
)
|
||||
return (output[0], output[1][0], output[1][1])
|
||||
|
||||
|
||||
class CompiledDecoderLayer(torch.nn.Module):
|
||||
def __init__(
|
||||
self, layer_id, device_idx, falcon_variant, device, precision
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
self.device_index = device_idx
|
||||
self.falcon_variant = falcon_variant
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
alibi: torch.Tensor = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
import gc
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
from pathlib import Path
|
||||
from apps.language_models.utils import get_vmfb_from_path
|
||||
|
||||
self.falcon_vmfb_path = Path(
|
||||
f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
|
||||
)
|
||||
print("vmfb path for layer: ", self.falcon_vmfb_path)
|
||||
self.model = get_vmfb_from_path(
|
||||
self.falcon_vmfb_path,
|
||||
self.device,
|
||||
"linalg",
|
||||
device_id=self.device_index,
|
||||
)
|
||||
if self.model is None:
|
||||
raise ValueError("Layer vmfb not found")
|
||||
|
||||
hidden_states = hidden_states.to(torch.float32).detach().numpy()
|
||||
attention_mask = attention_mask.detach().numpy()
|
||||
|
||||
if alibi is not None or layer_past is not None:
|
||||
raise ValueError("Past Key Values and alibi should be None")
|
||||
else:
|
||||
new_hidden_states, pkv1, pkv2 = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
),
|
||||
)
|
||||
del self.model
|
||||
|
||||
return tuple(
|
||||
[
|
||||
torch.tensor(new_hidden_states),
|
||||
tuple(
|
||||
[
|
||||
torch.tensor(pkv1),
|
||||
torch.tensor(pkv2),
|
||||
]
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class EightDecoderLayer(torch.nn.Module):
|
||||
def __init__(self, decoder_layer_model, falcon_variant):
|
||||
super().__init__()
|
||||
self.model = decoder_layer_model
|
||||
self.falcon_variant = falcon_variant
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
new_pkvs = []
|
||||
for layer in self.model:
|
||||
outputs = layer(
|
||||
hidden_states=hidden_states,
|
||||
alibi=None,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=True,
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
new_pkvs.append(
|
||||
(
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
)
|
||||
if self.falcon_variant == "7b":
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
) = new_pkvs
|
||||
result = (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
)
|
||||
elif self.falcon_variant == "40b":
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
(new_pkv80, new_pkv81),
|
||||
(new_pkv90, new_pkv91),
|
||||
(new_pkv100, new_pkv101),
|
||||
(new_pkv110, new_pkv111),
|
||||
(new_pkv120, new_pkv121),
|
||||
(new_pkv130, new_pkv131),
|
||||
(new_pkv140, new_pkv141),
|
||||
) = new_pkvs
|
||||
result = (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
new_pkv80,
|
||||
new_pkv81,
|
||||
new_pkv90,
|
||||
new_pkv91,
|
||||
new_pkv100,
|
||||
new_pkv101,
|
||||
new_pkv110,
|
||||
new_pkv111,
|
||||
new_pkv120,
|
||||
new_pkv121,
|
||||
new_pkv130,
|
||||
new_pkv131,
|
||||
new_pkv140,
|
||||
new_pkv141,
|
||||
)
|
||||
elif self.falcon_variant == "180b":
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
(new_pkv80, new_pkv81),
|
||||
(new_pkv90, new_pkv91),
|
||||
(new_pkv100, new_pkv101),
|
||||
(new_pkv110, new_pkv111),
|
||||
(new_pkv120, new_pkv121),
|
||||
(new_pkv130, new_pkv131),
|
||||
(new_pkv140, new_pkv141),
|
||||
(new_pkv150, new_pkv151),
|
||||
(new_pkv160, new_pkv161),
|
||||
(new_pkv170, new_pkv171),
|
||||
(new_pkv180, new_pkv181),
|
||||
(new_pkv190, new_pkv191),
|
||||
) = new_pkvs
|
||||
result = (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
new_pkv80,
|
||||
new_pkv81,
|
||||
new_pkv90,
|
||||
new_pkv91,
|
||||
new_pkv100,
|
||||
new_pkv101,
|
||||
new_pkv110,
|
||||
new_pkv111,
|
||||
new_pkv120,
|
||||
new_pkv121,
|
||||
new_pkv130,
|
||||
new_pkv131,
|
||||
new_pkv140,
|
||||
new_pkv141,
|
||||
new_pkv150,
|
||||
new_pkv151,
|
||||
new_pkv160,
|
||||
new_pkv161,
|
||||
new_pkv170,
|
||||
new_pkv171,
|
||||
new_pkv180,
|
||||
new_pkv181,
|
||||
new_pkv190,
|
||||
new_pkv191,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported Falcon variant: ", self.falcon_variant
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class CompiledEightDecoderLayer(torch.nn.Module):
|
||||
def __init__(
|
||||
self, layer_id, device_idx, falcon_variant, device, precision
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
self.device_index = device_idx
|
||||
self.falcon_variant = falcon_variant
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
alibi: torch.Tensor = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
import gc
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
from pathlib import Path
|
||||
from apps.language_models.utils import get_vmfb_from_path
|
||||
|
||||
self.falcon_vmfb_path = Path(
|
||||
f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
|
||||
)
|
||||
print("vmfb path for layer: ", self.falcon_vmfb_path)
|
||||
self.model = get_vmfb_from_path(
|
||||
self.falcon_vmfb_path,
|
||||
self.device,
|
||||
"linalg",
|
||||
device_id=self.device_index,
|
||||
)
|
||||
if self.model is None:
|
||||
raise ValueError("Layer vmfb not found")
|
||||
|
||||
hidden_states = hidden_states.to(torch.float32).detach().numpy()
|
||||
attention_mask = attention_mask.detach().numpy()
|
||||
|
||||
if alibi is not None or layer_past is not None:
|
||||
raise ValueError("Past Key Values and alibi should be None")
|
||||
else:
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
),
|
||||
)
|
||||
del self.model
|
||||
|
||||
if self.falcon_variant == "7b":
|
||||
result = (
|
||||
torch.tensor(output[0]),
|
||||
(
|
||||
torch.tensor(output[1]),
|
||||
torch.tensor(output[2]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[3]),
|
||||
torch.tensor(output[4]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[5]),
|
||||
torch.tensor(output[6]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[7]),
|
||||
torch.tensor(output[8]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[9]),
|
||||
torch.tensor(output[10]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[11]),
|
||||
torch.tensor(output[12]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[13]),
|
||||
torch.tensor(output[14]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[15]),
|
||||
torch.tensor(output[16]),
|
||||
),
|
||||
)
|
||||
elif self.falcon_variant == "40b":
|
||||
result = (
|
||||
torch.tensor(output[0]),
|
||||
(
|
||||
torch.tensor(output[1]),
|
||||
torch.tensor(output[2]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[3]),
|
||||
torch.tensor(output[4]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[5]),
|
||||
torch.tensor(output[6]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[7]),
|
||||
torch.tensor(output[8]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[9]),
|
||||
torch.tensor(output[10]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[11]),
|
||||
torch.tensor(output[12]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[13]),
|
||||
torch.tensor(output[14]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[15]),
|
||||
torch.tensor(output[16]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[17]),
|
||||
torch.tensor(output[18]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[19]),
|
||||
torch.tensor(output[20]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[21]),
|
||||
torch.tensor(output[22]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[23]),
|
||||
torch.tensor(output[24]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[25]),
|
||||
torch.tensor(output[26]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[27]),
|
||||
torch.tensor(output[28]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[29]),
|
||||
torch.tensor(output[30]),
|
||||
),
|
||||
)
|
||||
elif self.falcon_variant == "180b":
|
||||
result = (
|
||||
torch.tensor(output[0]),
|
||||
(
|
||||
torch.tensor(output[1]),
|
||||
torch.tensor(output[2]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[3]),
|
||||
torch.tensor(output[4]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[5]),
|
||||
torch.tensor(output[6]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[7]),
|
||||
torch.tensor(output[8]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[9]),
|
||||
torch.tensor(output[10]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[11]),
|
||||
torch.tensor(output[12]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[13]),
|
||||
torch.tensor(output[14]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[15]),
|
||||
torch.tensor(output[16]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[17]),
|
||||
torch.tensor(output[18]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[19]),
|
||||
torch.tensor(output[20]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[21]),
|
||||
torch.tensor(output[22]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[23]),
|
||||
torch.tensor(output[24]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[25]),
|
||||
torch.tensor(output[26]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[27]),
|
||||
torch.tensor(output[28]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[29]),
|
||||
torch.tensor(output[30]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[31]),
|
||||
torch.tensor(output[32]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[33]),
|
||||
torch.tensor(output[34]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[35]),
|
||||
torch.tensor(output[36]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[37]),
|
||||
torch.tensor(output[38]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[39]),
|
||||
torch.tensor(output[40]),
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported Falcon variant: ", self.falcon_variant
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class ShardedFalconModel:
|
||||
def __init__(self, model, layers, word_embeddings, ln_f, lm_head):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.model.transformer.h = torch.nn.modules.container.ModuleList(
|
||||
layers
|
||||
)
|
||||
self.model.transformer.word_embeddings = word_embeddings
|
||||
self.model.transformer.ln_f = ln_f
|
||||
self.model.lm_head = lm_head
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
):
|
||||
return self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
).logits[:, -1, :]
|
||||
@@ -47,18 +47,15 @@ from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
|
||||
)
|
||||
from apps.language_models.src.model_wrappers.vicuna_model import (
|
||||
FirstVicuna,
|
||||
SecondVicuna,
|
||||
SecondVicuna7B,
|
||||
)
|
||||
from apps.language_models.utils import (
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import get_f16_inputs
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaDecoderLayer,
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
class FirstVicuna(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_path,
|
||||
precision="fp32",
|
||||
accumulates="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
hf_auth_token: str = None,
|
||||
@@ -18,15 +16,24 @@ class FirstVicuna(torch.nn.Module):
|
||||
kwargs = {"torch_dtype": torch.float32}
|
||||
if "llama2" in model_name:
|
||||
kwargs["use_auth_token"] = hf_auth_token
|
||||
self.accumulates = (
|
||||
torch.float32 if accumulates == "fp32" else torch.float16
|
||||
)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
get_model_impl,
|
||||
)
|
||||
|
||||
print("First Vicuna applying weight quantization..")
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
get_model_impl(self.model).layers,
|
||||
dtype=torch.float32,
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
@@ -40,7 +47,9 @@ class FirstVicuna(torch.nn.Module):
|
||||
def forward(self, input_ids):
|
||||
op = self.model(input_ids=input_ids, use_cache=True)
|
||||
return_vals = []
|
||||
return_vals.append(op.logits)
|
||||
token = torch.argmax(op.logits[:, -1, :], dim=1)
|
||||
return_vals.append(token)
|
||||
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
@@ -48,11 +57,12 @@ class FirstVicuna(torch.nn.Module):
|
||||
return tuple(return_vals)
|
||||
|
||||
|
||||
class SecondVicuna(torch.nn.Module):
|
||||
class SecondVicuna7B(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_path,
|
||||
precision="fp32",
|
||||
accumulates="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
hf_auth_token: str = None,
|
||||
@@ -64,12 +74,21 @@ class SecondVicuna(torch.nn.Module):
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
self.accumulates = (
|
||||
torch.float32 if accumulates == "fp32" else torch.float16
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
get_model_impl,
|
||||
)
|
||||
|
||||
print("Second Vicuna applying weight quantization..")
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
get_model_impl(self.model).layers,
|
||||
dtype=torch.float32,
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
@@ -148,8 +167,6 @@ class SecondVicuna(torch.nn.Module):
|
||||
i63,
|
||||
i64,
|
||||
):
|
||||
# input_ids = input_tuple[0]
|
||||
# input_tuple = torch.unbind(pkv, dim=0)
|
||||
token = i0
|
||||
past_key_values = (
|
||||
(i1, i2),
|
||||
@@ -282,6 +299,842 @@ class SecondVicuna(torch.nn.Module):
|
||||
input_ids=token, use_cache=True, past_key_values=past_key_values
|
||||
)
|
||||
return_vals = []
|
||||
token = torch.argmax(op.logits[:, -1, :], dim=1)
|
||||
return_vals.append(token)
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
return_vals.append(item[1])
|
||||
return tuple(return_vals)
|
||||
|
||||
|
||||
class SecondVicuna13B(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_path,
|
||||
precision="int8",
|
||||
accumulates="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
hf_auth_token: str = None,
|
||||
):
|
||||
super().__init__()
|
||||
kwargs = {"torch_dtype": torch.float32}
|
||||
if "llama2" in model_name:
|
||||
kwargs["use_auth_token"] = hf_auth_token
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
self.accumulates = (
|
||||
torch.float32 if accumulates == "fp32" else torch.float16
|
||||
)
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
get_model_impl,
|
||||
)
|
||||
|
||||
print("Second Vicuna applying weight quantization..")
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
get_model_impl(self.model).layers,
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
quantize_weight_zero_point=False,
|
||||
)
|
||||
print("Weight quantization applied.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
i0,
|
||||
i1,
|
||||
i2,
|
||||
i3,
|
||||
i4,
|
||||
i5,
|
||||
i6,
|
||||
i7,
|
||||
i8,
|
||||
i9,
|
||||
i10,
|
||||
i11,
|
||||
i12,
|
||||
i13,
|
||||
i14,
|
||||
i15,
|
||||
i16,
|
||||
i17,
|
||||
i18,
|
||||
i19,
|
||||
i20,
|
||||
i21,
|
||||
i22,
|
||||
i23,
|
||||
i24,
|
||||
i25,
|
||||
i26,
|
||||
i27,
|
||||
i28,
|
||||
i29,
|
||||
i30,
|
||||
i31,
|
||||
i32,
|
||||
i33,
|
||||
i34,
|
||||
i35,
|
||||
i36,
|
||||
i37,
|
||||
i38,
|
||||
i39,
|
||||
i40,
|
||||
i41,
|
||||
i42,
|
||||
i43,
|
||||
i44,
|
||||
i45,
|
||||
i46,
|
||||
i47,
|
||||
i48,
|
||||
i49,
|
||||
i50,
|
||||
i51,
|
||||
i52,
|
||||
i53,
|
||||
i54,
|
||||
i55,
|
||||
i56,
|
||||
i57,
|
||||
i58,
|
||||
i59,
|
||||
i60,
|
||||
i61,
|
||||
i62,
|
||||
i63,
|
||||
i64,
|
||||
i65,
|
||||
i66,
|
||||
i67,
|
||||
i68,
|
||||
i69,
|
||||
i70,
|
||||
i71,
|
||||
i72,
|
||||
i73,
|
||||
i74,
|
||||
i75,
|
||||
i76,
|
||||
i77,
|
||||
i78,
|
||||
i79,
|
||||
i80,
|
||||
):
|
||||
token = i0
|
||||
past_key_values = (
|
||||
(i1, i2),
|
||||
(
|
||||
i3,
|
||||
i4,
|
||||
),
|
||||
(
|
||||
i5,
|
||||
i6,
|
||||
),
|
||||
(
|
||||
i7,
|
||||
i8,
|
||||
),
|
||||
(
|
||||
i9,
|
||||
i10,
|
||||
),
|
||||
(
|
||||
i11,
|
||||
i12,
|
||||
),
|
||||
(
|
||||
i13,
|
||||
i14,
|
||||
),
|
||||
(
|
||||
i15,
|
||||
i16,
|
||||
),
|
||||
(
|
||||
i17,
|
||||
i18,
|
||||
),
|
||||
(
|
||||
i19,
|
||||
i20,
|
||||
),
|
||||
(
|
||||
i21,
|
||||
i22,
|
||||
),
|
||||
(
|
||||
i23,
|
||||
i24,
|
||||
),
|
||||
(
|
||||
i25,
|
||||
i26,
|
||||
),
|
||||
(
|
||||
i27,
|
||||
i28,
|
||||
),
|
||||
(
|
||||
i29,
|
||||
i30,
|
||||
),
|
||||
(
|
||||
i31,
|
||||
i32,
|
||||
),
|
||||
(
|
||||
i33,
|
||||
i34,
|
||||
),
|
||||
(
|
||||
i35,
|
||||
i36,
|
||||
),
|
||||
(
|
||||
i37,
|
||||
i38,
|
||||
),
|
||||
(
|
||||
i39,
|
||||
i40,
|
||||
),
|
||||
(
|
||||
i41,
|
||||
i42,
|
||||
),
|
||||
(
|
||||
i43,
|
||||
i44,
|
||||
),
|
||||
(
|
||||
i45,
|
||||
i46,
|
||||
),
|
||||
(
|
||||
i47,
|
||||
i48,
|
||||
),
|
||||
(
|
||||
i49,
|
||||
i50,
|
||||
),
|
||||
(
|
||||
i51,
|
||||
i52,
|
||||
),
|
||||
(
|
||||
i53,
|
||||
i54,
|
||||
),
|
||||
(
|
||||
i55,
|
||||
i56,
|
||||
),
|
||||
(
|
||||
i57,
|
||||
i58,
|
||||
),
|
||||
(
|
||||
i59,
|
||||
i60,
|
||||
),
|
||||
(
|
||||
i61,
|
||||
i62,
|
||||
),
|
||||
(
|
||||
i63,
|
||||
i64,
|
||||
),
|
||||
(
|
||||
i65,
|
||||
i66,
|
||||
),
|
||||
(
|
||||
i67,
|
||||
i68,
|
||||
),
|
||||
(
|
||||
i69,
|
||||
i70,
|
||||
),
|
||||
(
|
||||
i71,
|
||||
i72,
|
||||
),
|
||||
(
|
||||
i73,
|
||||
i74,
|
||||
),
|
||||
(
|
||||
i75,
|
||||
i76,
|
||||
),
|
||||
(
|
||||
i77,
|
||||
i78,
|
||||
),
|
||||
(
|
||||
i79,
|
||||
i80,
|
||||
),
|
||||
)
|
||||
op = self.model(
|
||||
input_ids=token, use_cache=True, past_key_values=past_key_values
|
||||
)
|
||||
return_vals = []
|
||||
return_vals.append(op.logits)
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
return_vals.append(item[1])
|
||||
return tuple(return_vals)
|
||||
|
||||
|
||||
class SecondVicuna70B(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_path,
|
||||
precision="fp32",
|
||||
accumulates="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
hf_auth_token: str = None,
|
||||
):
|
||||
super().__init__()
|
||||
kwargs = {"torch_dtype": torch.float32}
|
||||
if "llama2" in model_name:
|
||||
kwargs["use_auth_token"] = hf_auth_token
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
self.accumulates = (
|
||||
torch.float32 if accumulates == "fp32" else torch.float16
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
get_model_impl,
|
||||
)
|
||||
|
||||
print("Second Vicuna applying weight quantization..")
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
get_model_impl(self.model).layers,
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
quantize_weight_zero_point=False,
|
||||
)
|
||||
print("Weight quantization applied.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
i0,
|
||||
i1,
|
||||
i2,
|
||||
i3,
|
||||
i4,
|
||||
i5,
|
||||
i6,
|
||||
i7,
|
||||
i8,
|
||||
i9,
|
||||
i10,
|
||||
i11,
|
||||
i12,
|
||||
i13,
|
||||
i14,
|
||||
i15,
|
||||
i16,
|
||||
i17,
|
||||
i18,
|
||||
i19,
|
||||
i20,
|
||||
i21,
|
||||
i22,
|
||||
i23,
|
||||
i24,
|
||||
i25,
|
||||
i26,
|
||||
i27,
|
||||
i28,
|
||||
i29,
|
||||
i30,
|
||||
i31,
|
||||
i32,
|
||||
i33,
|
||||
i34,
|
||||
i35,
|
||||
i36,
|
||||
i37,
|
||||
i38,
|
||||
i39,
|
||||
i40,
|
||||
i41,
|
||||
i42,
|
||||
i43,
|
||||
i44,
|
||||
i45,
|
||||
i46,
|
||||
i47,
|
||||
i48,
|
||||
i49,
|
||||
i50,
|
||||
i51,
|
||||
i52,
|
||||
i53,
|
||||
i54,
|
||||
i55,
|
||||
i56,
|
||||
i57,
|
||||
i58,
|
||||
i59,
|
||||
i60,
|
||||
i61,
|
||||
i62,
|
||||
i63,
|
||||
i64,
|
||||
i65,
|
||||
i66,
|
||||
i67,
|
||||
i68,
|
||||
i69,
|
||||
i70,
|
||||
i71,
|
||||
i72,
|
||||
i73,
|
||||
i74,
|
||||
i75,
|
||||
i76,
|
||||
i77,
|
||||
i78,
|
||||
i79,
|
||||
i80,
|
||||
i81,
|
||||
i82,
|
||||
i83,
|
||||
i84,
|
||||
i85,
|
||||
i86,
|
||||
i87,
|
||||
i88,
|
||||
i89,
|
||||
i90,
|
||||
i91,
|
||||
i92,
|
||||
i93,
|
||||
i94,
|
||||
i95,
|
||||
i96,
|
||||
i97,
|
||||
i98,
|
||||
i99,
|
||||
i100,
|
||||
i101,
|
||||
i102,
|
||||
i103,
|
||||
i104,
|
||||
i105,
|
||||
i106,
|
||||
i107,
|
||||
i108,
|
||||
i109,
|
||||
i110,
|
||||
i111,
|
||||
i112,
|
||||
i113,
|
||||
i114,
|
||||
i115,
|
||||
i116,
|
||||
i117,
|
||||
i118,
|
||||
i119,
|
||||
i120,
|
||||
i121,
|
||||
i122,
|
||||
i123,
|
||||
i124,
|
||||
i125,
|
||||
i126,
|
||||
i127,
|
||||
i128,
|
||||
i129,
|
||||
i130,
|
||||
i131,
|
||||
i132,
|
||||
i133,
|
||||
i134,
|
||||
i135,
|
||||
i136,
|
||||
i137,
|
||||
i138,
|
||||
i139,
|
||||
i140,
|
||||
i141,
|
||||
i142,
|
||||
i143,
|
||||
i144,
|
||||
i145,
|
||||
i146,
|
||||
i147,
|
||||
i148,
|
||||
i149,
|
||||
i150,
|
||||
i151,
|
||||
i152,
|
||||
i153,
|
||||
i154,
|
||||
i155,
|
||||
i156,
|
||||
i157,
|
||||
i158,
|
||||
i159,
|
||||
i160,
|
||||
):
|
||||
token = i0
|
||||
past_key_values = (
|
||||
(i1, i2),
|
||||
(
|
||||
i3,
|
||||
i4,
|
||||
),
|
||||
(
|
||||
i5,
|
||||
i6,
|
||||
),
|
||||
(
|
||||
i7,
|
||||
i8,
|
||||
),
|
||||
(
|
||||
i9,
|
||||
i10,
|
||||
),
|
||||
(
|
||||
i11,
|
||||
i12,
|
||||
),
|
||||
(
|
||||
i13,
|
||||
i14,
|
||||
),
|
||||
(
|
||||
i15,
|
||||
i16,
|
||||
),
|
||||
(
|
||||
i17,
|
||||
i18,
|
||||
),
|
||||
(
|
||||
i19,
|
||||
i20,
|
||||
),
|
||||
(
|
||||
i21,
|
||||
i22,
|
||||
),
|
||||
(
|
||||
i23,
|
||||
i24,
|
||||
),
|
||||
(
|
||||
i25,
|
||||
i26,
|
||||
),
|
||||
(
|
||||
i27,
|
||||
i28,
|
||||
),
|
||||
(
|
||||
i29,
|
||||
i30,
|
||||
),
|
||||
(
|
||||
i31,
|
||||
i32,
|
||||
),
|
||||
(
|
||||
i33,
|
||||
i34,
|
||||
),
|
||||
(
|
||||
i35,
|
||||
i36,
|
||||
),
|
||||
(
|
||||
i37,
|
||||
i38,
|
||||
),
|
||||
(
|
||||
i39,
|
||||
i40,
|
||||
),
|
||||
(
|
||||
i41,
|
||||
i42,
|
||||
),
|
||||
(
|
||||
i43,
|
||||
i44,
|
||||
),
|
||||
(
|
||||
i45,
|
||||
i46,
|
||||
),
|
||||
(
|
||||
i47,
|
||||
i48,
|
||||
),
|
||||
(
|
||||
i49,
|
||||
i50,
|
||||
),
|
||||
(
|
||||
i51,
|
||||
i52,
|
||||
),
|
||||
(
|
||||
i53,
|
||||
i54,
|
||||
),
|
||||
(
|
||||
i55,
|
||||
i56,
|
||||
),
|
||||
(
|
||||
i57,
|
||||
i58,
|
||||
),
|
||||
(
|
||||
i59,
|
||||
i60,
|
||||
),
|
||||
(
|
||||
i61,
|
||||
i62,
|
||||
),
|
||||
(
|
||||
i63,
|
||||
i64,
|
||||
),
|
||||
(
|
||||
i65,
|
||||
i66,
|
||||
),
|
||||
(
|
||||
i67,
|
||||
i68,
|
||||
),
|
||||
(
|
||||
i69,
|
||||
i70,
|
||||
),
|
||||
(
|
||||
i71,
|
||||
i72,
|
||||
),
|
||||
(
|
||||
i73,
|
||||
i74,
|
||||
),
|
||||
(
|
||||
i75,
|
||||
i76,
|
||||
),
|
||||
(
|
||||
i77,
|
||||
i78,
|
||||
),
|
||||
(
|
||||
i79,
|
||||
i80,
|
||||
),
|
||||
(
|
||||
i81,
|
||||
i82,
|
||||
),
|
||||
(
|
||||
i83,
|
||||
i84,
|
||||
),
|
||||
(
|
||||
i85,
|
||||
i86,
|
||||
),
|
||||
(
|
||||
i87,
|
||||
i88,
|
||||
),
|
||||
(
|
||||
i89,
|
||||
i90,
|
||||
),
|
||||
(
|
||||
i91,
|
||||
i92,
|
||||
),
|
||||
(
|
||||
i93,
|
||||
i94,
|
||||
),
|
||||
(
|
||||
i95,
|
||||
i96,
|
||||
),
|
||||
(
|
||||
i97,
|
||||
i98,
|
||||
),
|
||||
(
|
||||
i99,
|
||||
i100,
|
||||
),
|
||||
(
|
||||
i101,
|
||||
i102,
|
||||
),
|
||||
(
|
||||
i103,
|
||||
i104,
|
||||
),
|
||||
(
|
||||
i105,
|
||||
i106,
|
||||
),
|
||||
(
|
||||
i107,
|
||||
i108,
|
||||
),
|
||||
(
|
||||
i109,
|
||||
i110,
|
||||
),
|
||||
(
|
||||
i111,
|
||||
i112,
|
||||
),
|
||||
(
|
||||
i113,
|
||||
i114,
|
||||
),
|
||||
(
|
||||
i115,
|
||||
i116,
|
||||
),
|
||||
(
|
||||
i117,
|
||||
i118,
|
||||
),
|
||||
(
|
||||
i119,
|
||||
i120,
|
||||
),
|
||||
(
|
||||
i121,
|
||||
i122,
|
||||
),
|
||||
(
|
||||
i123,
|
||||
i124,
|
||||
),
|
||||
(
|
||||
i125,
|
||||
i126,
|
||||
),
|
||||
(
|
||||
i127,
|
||||
i128,
|
||||
),
|
||||
(
|
||||
i129,
|
||||
i130,
|
||||
),
|
||||
(
|
||||
i131,
|
||||
i132,
|
||||
),
|
||||
(
|
||||
i133,
|
||||
i134,
|
||||
),
|
||||
(
|
||||
i135,
|
||||
i136,
|
||||
),
|
||||
(
|
||||
i137,
|
||||
i138,
|
||||
),
|
||||
(
|
||||
i139,
|
||||
i140,
|
||||
),
|
||||
(
|
||||
i141,
|
||||
i142,
|
||||
),
|
||||
(
|
||||
i143,
|
||||
i144,
|
||||
),
|
||||
(
|
||||
i145,
|
||||
i146,
|
||||
),
|
||||
(
|
||||
i147,
|
||||
i148,
|
||||
),
|
||||
(
|
||||
i149,
|
||||
i150,
|
||||
),
|
||||
(
|
||||
i151,
|
||||
i152,
|
||||
),
|
||||
(
|
||||
i153,
|
||||
i154,
|
||||
),
|
||||
(
|
||||
i155,
|
||||
i156,
|
||||
),
|
||||
(
|
||||
i157,
|
||||
i158,
|
||||
),
|
||||
(
|
||||
i159,
|
||||
i160,
|
||||
),
|
||||
)
|
||||
op = self.model(
|
||||
input_ids=token, use_cache=True, past_key_values=past_key_values
|
||||
)
|
||||
return_vals = []
|
||||
return_vals.append(op.logits)
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
@@ -298,7 +1151,8 @@ class CombinedModel(torch.nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.first_vicuna = FirstVicuna(first_vicuna_model_path)
|
||||
self.second_vicuna = SecondVicuna(second_vicuna_model_path)
|
||||
# NOT using this path for 13B currently, hence using `SecondVicuna7B`.
|
||||
self.second_vicuna = SecondVicuna7B(second_vicuna_model_path)
|
||||
|
||||
def forward(self, input_ids):
|
||||
first_output = self.first_vicuna(input_ids=input_ids)
|
||||
|
||||
1165
apps/language_models/src/model_wrappers/vicuna_model_gpu.py
Normal file
1165
apps/language_models/src/model_wrappers/vicuna_model_gpu.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -3,7 +3,10 @@ from abc import ABC, abstractmethod
|
||||
|
||||
class SharkLLMBase(ABC):
|
||||
def __init__(
|
||||
self, model_name, hf_model_path=None, max_num_tokens=512
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path=None,
|
||||
max_num_tokens=512,
|
||||
) -> None:
|
||||
self.model_name = model_name
|
||||
self.hf_model_path = hf_model_path
|
||||
|
||||
@@ -1,4 +1,17 @@
|
||||
from apps.language_models.src.model_wrappers.falcon_model import FalconModel
|
||||
from apps.language_models.src.model_wrappers.falcon_sharded_model import (
|
||||
WordEmbeddingsLayer,
|
||||
CompiledWordEmbeddingsLayer,
|
||||
LNFEmbeddingLayer,
|
||||
CompiledLNFEmbeddingLayer,
|
||||
LMHeadEmbeddingLayer,
|
||||
CompiledLMHeadEmbeddingLayer,
|
||||
DecoderLayer,
|
||||
EightDecoderLayer,
|
||||
CompiledDecoderLayer,
|
||||
CompiledEightDecoderLayer,
|
||||
ShardedFalconModel,
|
||||
)
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from apps.language_models.utils import (
|
||||
get_vmfb_from_path,
|
||||
@@ -7,30 +20,39 @@ from io import BytesIO
|
||||
from pathlib import Path
|
||||
from contextlib import redirect_stdout
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
from shark.shark_inference import SharkInference
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, GPTQConfig
|
||||
from transformers.generation import (
|
||||
GenerationConfig,
|
||||
LogitsProcessorList,
|
||||
StoppingCriteriaList,
|
||||
)
|
||||
import copy
|
||||
|
||||
import time
|
||||
import re
|
||||
import torch
|
||||
import torch_mlir
|
||||
import os
|
||||
import argparse
|
||||
import gc
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="falcon runner",
|
||||
description="runs a falcon model",
|
||||
)
|
||||
|
||||
parser.add_argument("--falcon_variant_to_use", default="7b", help="7b, 40b")
|
||||
parser.add_argument(
|
||||
"--precision", "-p", default="fp16", help="fp32, fp16, int8, int4"
|
||||
"--falcon_variant_to_use", default="7b", help="7b, 40b, 180b"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compressed",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Do the compression of sharded layers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"]
|
||||
)
|
||||
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
|
||||
parser.add_argument(
|
||||
@@ -49,7 +71,7 @@ parser.add_argument(
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_mlir_from_shark_tank",
|
||||
default=False,
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="download precompile mlir from shark tank",
|
||||
)
|
||||
@@ -59,32 +81,63 @@ parser.add_argument(
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Run model in cli mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf_auth_token",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specify your own huggingface authentication token for falcon-180B model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--sharded",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Run model as sharded",
|
||||
)
|
||||
|
||||
|
||||
class Falcon(SharkLLMBase):
|
||||
class ShardedFalcon(SharkLLMBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path,
|
||||
hf_model_path="tiiuae/falcon-7b-instruct",
|
||||
hf_auth_token: str = None,
|
||||
max_num_tokens=150,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
falcon_mlir_path=None,
|
||||
falcon_vmfb_path=None,
|
||||
debug=False,
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
print("hf_model_path: ", self.hf_model_path)
|
||||
|
||||
if (
|
||||
"180b" in self.model_name
|
||||
and precision != "int4"
|
||||
and hf_auth_token == None
|
||||
):
|
||||
raise ValueError(
|
||||
""" HF auth token required for falcon-180b. Pass it using
|
||||
--hf_auth_token flag. You can ask for the access to the model
|
||||
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
|
||||
)
|
||||
self.hf_auth_token = hf_auth_token
|
||||
self.max_padding_length = 100
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.falcon_vmfb_path = falcon_vmfb_path
|
||||
self.falcon_mlir_path = falcon_mlir_path
|
||||
self.debug = debug
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.shark_model = self.compile()
|
||||
self.src_model = self.get_src_model()
|
||||
self.shark_model = self.compile(compressed=args.compressed)
|
||||
|
||||
def get_tokenizer(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path, trust_remote_code=True
|
||||
self.hf_model_path,
|
||||
trust_remote_code=True,
|
||||
token=self.hf_auth_token,
|
||||
)
|
||||
tokenizer.padding_side = "left"
|
||||
tokenizer.pad_token_id = 11
|
||||
@@ -92,13 +145,546 @@ class Falcon(SharkLLMBase):
|
||||
|
||||
def get_src_model(self):
|
||||
print("Loading src model: ", self.model_name)
|
||||
kwargs = {"torch_dtype": torch.float, "trust_remote_code": True}
|
||||
kwargs = {
|
||||
"torch_dtype": torch.float,
|
||||
"trust_remote_code": True,
|
||||
"token": self.hf_auth_token,
|
||||
}
|
||||
if self.precision == "int4":
|
||||
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
|
||||
kwargs["quantization_config"] = quantization_config
|
||||
kwargs["load_gptq_on_cpu"] = True
|
||||
kwargs["device_map"] = "cpu"
|
||||
falcon_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
)
|
||||
if self.precision == "int4":
|
||||
falcon_model = falcon_model.to(torch.float32)
|
||||
return falcon_model
|
||||
|
||||
def compile_falcon(self):
|
||||
def compile_layer(
|
||||
self, layer, falconCompileInput, layer_id, device_idx=None
|
||||
):
|
||||
self.falcon_mlir_path = Path(
|
||||
f"falcon_{args.falcon_variant_to_use}_layer_{layer_id}_{self.precision}.mlir"
|
||||
)
|
||||
self.falcon_vmfb_path = Path(
|
||||
f"falcon_{args.falcon_variant_to_use}_layer_{layer_id}_{self.precision}_{self.device}.vmfb"
|
||||
)
|
||||
|
||||
if args.use_precompiled_model:
|
||||
if not self.falcon_vmfb_path.exists():
|
||||
# Downloading VMFB from shark_tank
|
||||
print(f"[DEBUG] Trying to download vmfb from shark_tank")
|
||||
download_public_file(
|
||||
f"gs://shark_tank/falcon/sharded/falcon_{args.falcon_variant_to_use}/vmfb/"
|
||||
+ str(self.falcon_vmfb_path),
|
||||
self.falcon_vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
vmfb = get_vmfb_from_path(
|
||||
self.falcon_vmfb_path,
|
||||
self.device,
|
||||
"linalg",
|
||||
device_id=device_idx,
|
||||
)
|
||||
if vmfb is not None:
|
||||
return vmfb, device_idx
|
||||
|
||||
print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
|
||||
if self.falcon_mlir_path.exists():
|
||||
print(f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}")
|
||||
with open(self.falcon_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
mlir_generated = False
|
||||
print(
|
||||
f"[DEBUG] mlir not found at {self.falcon_mlir_path.absolute()}"
|
||||
)
|
||||
if args.load_mlir_from_shark_tank:
|
||||
# Downloading MLIR from shark_tank
|
||||
print(f"[DEBUG] Trying to download mlir from shark_tank")
|
||||
download_public_file(
|
||||
f"gs://shark_tank/falcon/sharded/falcon_{args.falcon_variant_to_use}/mlir/"
|
||||
+ str(self.falcon_mlir_path),
|
||||
self.falcon_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if self.falcon_mlir_path.exists():
|
||||
print(
|
||||
f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}"
|
||||
)
|
||||
with open(self.falcon_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
mlir_generated = True
|
||||
|
||||
if not mlir_generated:
|
||||
print(f"[DEBUG] generating MLIR locally")
|
||||
if layer_id == "word_embeddings":
|
||||
f16_input_mask = [False]
|
||||
elif layer_id in ["ln_f", "lm_head"]:
|
||||
f16_input_mask = [True]
|
||||
elif "_" in layer_id or type(layer_id) == int:
|
||||
f16_input_mask = [True, False]
|
||||
else:
|
||||
raise ValueError("Unsupported layer: ", layer_id)
|
||||
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
ts_graph = import_with_fx(
|
||||
layer,
|
||||
falconCompileInput,
|
||||
is_f16=True,
|
||||
f16_input_mask=f16_input_mask,
|
||||
mlir_type="torchscript",
|
||||
is_gptq=True,
|
||||
)
|
||||
del layer
|
||||
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
falconCompileInput,
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
del ts_graph
|
||||
|
||||
print(f"[DEBUG] converting to bytecode")
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
del module
|
||||
|
||||
f_ = open(self.falcon_mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
print("Saved falcon mlir at ", str(self.falcon_mlir_path))
|
||||
f_.close()
|
||||
del bytecode
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=self.falcon_mlir_path,
|
||||
device=self.device,
|
||||
mlir_dialect="linalg",
|
||||
device_idx=device_idx,
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
self.falcon_vmfb_path.parent.absolute(),
|
||||
self.falcon_vmfb_path.stem,
|
||||
extra_args=[
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
]
|
||||
+ [
|
||||
"--iree-llvmcpu-use-fast-min-max-ops",
|
||||
]
|
||||
if self.precision == "int4"
|
||||
else [],
|
||||
debug=self.debug,
|
||||
)
|
||||
print("Saved falcon vmfb at ", str(path))
|
||||
shark_module.load_module(path)
|
||||
|
||||
return shark_module, device_idx
|
||||
|
||||
def compile(self, compressed=False):
|
||||
sample_input_ids = torch.zeros([100], dtype=torch.int64)
|
||||
sample_attention_mask = torch.zeros(
|
||||
[1, 1, 100, 100], dtype=torch.float32
|
||||
)
|
||||
num_group_layers = 1
|
||||
if "7b" in self.model_name:
|
||||
num_in_features = 4544
|
||||
if compressed:
|
||||
num_group_layers = 8
|
||||
elif "40b" in self.model_name:
|
||||
num_in_features = 8192
|
||||
if compressed:
|
||||
num_group_layers = 15
|
||||
else:
|
||||
num_in_features = 14848
|
||||
sample_attention_mask = sample_attention_mask.to(dtype=torch.bool)
|
||||
if compressed:
|
||||
num_group_layers = 20
|
||||
|
||||
sample_hidden_states = torch.zeros(
|
||||
[1, 100, num_in_features], dtype=torch.float32
|
||||
)
|
||||
|
||||
# Determine number of available devices
|
||||
num_devices = 1
|
||||
if self.device == "rocm":
|
||||
import iree.runtime as ireert
|
||||
|
||||
haldriver = ireert.get_driver(self.device)
|
||||
num_devices = len(haldriver.query_available_devices())
|
||||
|
||||
lm_head = LMHeadEmbeddingLayer(self.src_model.lm_head)
|
||||
print("Compiling Layer lm_head")
|
||||
shark_lm_head, _ = self.compile_layer(
|
||||
lm_head,
|
||||
[sample_hidden_states],
|
||||
"lm_head",
|
||||
device_idx=0 % num_devices if self.device == "rocm" else None,
|
||||
)
|
||||
shark_lm_head = CompiledLMHeadEmbeddingLayer(shark_lm_head)
|
||||
|
||||
word_embedding = WordEmbeddingsLayer(
|
||||
self.src_model.transformer.word_embeddings
|
||||
)
|
||||
print("Compiling Layer word_embeddings")
|
||||
shark_word_embedding, _ = self.compile_layer(
|
||||
word_embedding,
|
||||
[sample_input_ids],
|
||||
"word_embeddings",
|
||||
device_idx=1 % num_devices if self.device == "rocm" else None,
|
||||
)
|
||||
shark_word_embedding = CompiledWordEmbeddingsLayer(
|
||||
shark_word_embedding
|
||||
)
|
||||
|
||||
ln_f = LNFEmbeddingLayer(self.src_model.transformer.ln_f)
|
||||
print("Compiling Layer ln_f")
|
||||
shark_ln_f, _ = self.compile_layer(
|
||||
ln_f,
|
||||
[sample_hidden_states],
|
||||
"ln_f",
|
||||
device_idx=2 % num_devices if self.device == "rocm" else None,
|
||||
)
|
||||
shark_ln_f = CompiledLNFEmbeddingLayer(shark_ln_f)
|
||||
|
||||
shark_layers = []
|
||||
for i in range(
|
||||
int(len(self.src_model.transformer.h) / num_group_layers)
|
||||
):
|
||||
device_idx = i % num_devices if self.device == "rocm" else None
|
||||
layer_id = i
|
||||
pytorch_class = DecoderLayer
|
||||
compiled_class = CompiledDecoderLayer
|
||||
if compressed:
|
||||
layer_id = (
|
||||
str(i * num_group_layers)
|
||||
+ "_"
|
||||
+ str((i + 1) * num_group_layers)
|
||||
)
|
||||
pytorch_class = EightDecoderLayer
|
||||
compiled_class = CompiledEightDecoderLayer
|
||||
|
||||
print("Compiling Layer {}".format(layer_id))
|
||||
if compressed:
|
||||
layer_i = self.src_model.transformer.h[
|
||||
i * num_group_layers : (i + 1) * num_group_layers
|
||||
]
|
||||
else:
|
||||
layer_i = self.src_model.transformer.h[i]
|
||||
|
||||
pytorch_layer_i = pytorch_class(
|
||||
layer_i, args.falcon_variant_to_use
|
||||
)
|
||||
shark_module, device_idx = self.compile_layer(
|
||||
pytorch_layer_i,
|
||||
[sample_hidden_states, sample_attention_mask],
|
||||
layer_id,
|
||||
device_idx=device_idx,
|
||||
)
|
||||
del shark_module
|
||||
shark_layer_i = compiled_class(
|
||||
layer_id,
|
||||
device_idx,
|
||||
args.falcon_variant_to_use,
|
||||
self.device,
|
||||
self.precision,
|
||||
)
|
||||
shark_layers.append(shark_layer_i)
|
||||
|
||||
sharded_model = ShardedFalconModel(
|
||||
self.src_model,
|
||||
shark_layers,
|
||||
shark_word_embedding,
|
||||
shark_ln_f,
|
||||
shark_lm_head,
|
||||
)
|
||||
return sharded_model
|
||||
|
||||
def generate(self, prompt):
|
||||
model_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.max_padding_length,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
model_inputs["prompt_text"] = prompt
|
||||
|
||||
input_ids = model_inputs["input_ids"]
|
||||
attention_mask = model_inputs.get("attention_mask", None)
|
||||
|
||||
# Allow empty prompts
|
||||
if input_ids.shape[1] == 0:
|
||||
input_ids = None
|
||||
attention_mask = None
|
||||
|
||||
generate_kwargs = {
|
||||
"max_length": self.max_num_tokens,
|
||||
"do_sample": True,
|
||||
"top_k": 10,
|
||||
"num_return_sequences": 1,
|
||||
"eos_token_id": 11,
|
||||
}
|
||||
generate_kwargs["input_ids"] = input_ids
|
||||
generate_kwargs["attention_mask"] = attention_mask
|
||||
generation_config_ = GenerationConfig.from_model_config(
|
||||
self.src_model.config
|
||||
)
|
||||
generation_config = copy.deepcopy(generation_config_)
|
||||
model_kwargs = generation_config.update(**generate_kwargs)
|
||||
|
||||
logits_processor = LogitsProcessorList()
|
||||
stopping_criteria = StoppingCriteriaList()
|
||||
|
||||
eos_token_id = generation_config.eos_token_id
|
||||
generation_config.pad_token_id = eos_token_id
|
||||
|
||||
(
|
||||
inputs_tensor,
|
||||
model_input_name,
|
||||
model_kwargs,
|
||||
) = self.src_model._prepare_model_inputs(
|
||||
None, generation_config.bos_token_id, model_kwargs
|
||||
)
|
||||
|
||||
model_kwargs["output_attentions"] = generation_config.output_attentions
|
||||
model_kwargs[
|
||||
"output_hidden_states"
|
||||
] = generation_config.output_hidden_states
|
||||
model_kwargs["use_cache"] = generation_config.use_cache
|
||||
|
||||
input_ids = (
|
||||
inputs_tensor
|
||||
if model_input_name == "input_ids"
|
||||
else model_kwargs.pop("input_ids")
|
||||
)
|
||||
|
||||
self.logits_processor = self.src_model._get_logits_processor(
|
||||
generation_config=generation_config,
|
||||
input_ids_seq_length=input_ids.shape[-1],
|
||||
encoder_input_ids=inputs_tensor,
|
||||
prefix_allowed_tokens_fn=None,
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
|
||||
self.stopping_criteria = self.src_model._get_stopping_criteria(
|
||||
generation_config=generation_config,
|
||||
stopping_criteria=stopping_criteria,
|
||||
)
|
||||
|
||||
self.logits_warper = self.src_model._get_logits_warper(
|
||||
generation_config
|
||||
)
|
||||
|
||||
(
|
||||
self.input_ids,
|
||||
self.model_kwargs,
|
||||
) = self.src_model._expand_inputs_for_generation(
|
||||
input_ids=input_ids,
|
||||
expand_size=generation_config.num_return_sequences, # 1
|
||||
is_encoder_decoder=self.src_model.config.is_encoder_decoder, # False
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
self.eos_token_id_tensor = (
|
||||
torch.tensor(eos_token_id) if eos_token_id is not None else None
|
||||
)
|
||||
|
||||
self.pad_token_id = generation_config.pad_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
output_scores = generation_config.output_scores # False
|
||||
return_dict_in_generate = (
|
||||
generation_config.return_dict_in_generate # False
|
||||
)
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
self.scores = (
|
||||
() if (return_dict_in_generate and output_scores) else None
|
||||
)
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
self.unfinished_sequences = torch.ones(
|
||||
input_ids.shape[0], dtype=torch.long, device=input_ids.device
|
||||
)
|
||||
|
||||
all_text = prompt
|
||||
|
||||
start = time.time()
|
||||
count = 0
|
||||
for i in range(self.max_num_tokens - 1):
|
||||
count = count + 1
|
||||
|
||||
next_token = self.generate_new_token()
|
||||
new_word = self.tokenizer.decode(
|
||||
next_token.cpu().numpy(),
|
||||
add_special_tokens=False,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
|
||||
all_text = all_text + new_word
|
||||
|
||||
print(f"{new_word}", end="", flush=True)
|
||||
print(f"{all_text}", end="", flush=True)
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if self.eos_token_id_tensor is not None:
|
||||
self.unfinished_sequences = self.unfinished_sequences.mul(
|
||||
next_token.tile(self.eos_token_id_tensor.shape[0], 1)
|
||||
.ne(self.eos_token_id_tensor.unsqueeze(1))
|
||||
.prod(dim=0)
|
||||
)
|
||||
# stop when each sentence is finished
|
||||
if (
|
||||
self.unfinished_sequences.max() == 0
|
||||
or self.stopping_criteria(input_ids, self.scores)
|
||||
):
|
||||
break
|
||||
|
||||
end = time.time()
|
||||
print(
|
||||
"\n\nTime taken is {:.2f} seconds/token\n".format(
|
||||
(end - start) / count
|
||||
)
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
return all_text
|
||||
|
||||
def generate_new_token(self):
|
||||
model_inputs = self.src_model.prepare_inputs_for_generation(
|
||||
self.input_ids, **self.model_kwargs
|
||||
)
|
||||
outputs = self.shark_model.forward(
|
||||
input_ids=model_inputs["input_ids"],
|
||||
attention_mask=model_inputs["attention_mask"],
|
||||
)
|
||||
if self.precision in ["fp16", "int4"]:
|
||||
outputs = outputs.to(dtype=torch.float32)
|
||||
next_token_logits = outputs
|
||||
|
||||
# pre-process distribution
|
||||
next_token_scores = self.logits_processor(
|
||||
self.input_ids, next_token_logits
|
||||
)
|
||||
next_token_scores = self.logits_warper(
|
||||
self.input_ids, next_token_scores
|
||||
)
|
||||
|
||||
# sample
|
||||
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
|
||||
|
||||
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
|
||||
# finished sentences should have their next token be a padding token
|
||||
if self.eos_token_id is not None:
|
||||
if self.pad_token_id is None:
|
||||
raise ValueError(
|
||||
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
|
||||
)
|
||||
next_token = (
|
||||
next_token * self.unfinished_sequences
|
||||
+ self.pad_token_id * (1 - self.unfinished_sequences)
|
||||
)
|
||||
|
||||
self.input_ids = torch.cat(
|
||||
[self.input_ids, next_token[:, None]], dim=-1
|
||||
)
|
||||
|
||||
self.model_kwargs["past_key_values"] = None
|
||||
if "attention_mask" in self.model_kwargs:
|
||||
attention_mask = self.model_kwargs["attention_mask"]
|
||||
self.model_kwargs["attention_mask"] = torch.cat(
|
||||
[
|
||||
attention_mask,
|
||||
attention_mask.new_ones((attention_mask.shape[0], 1)),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
self.input_ids = self.input_ids[:, 1:]
|
||||
self.model_kwargs["attention_mask"] = self.model_kwargs[
|
||||
"attention_mask"
|
||||
][:, 1:]
|
||||
|
||||
return next_token
|
||||
|
||||
|
||||
class UnshardedFalcon(SharkLLMBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path="tiiuae/falcon-7b-instruct",
|
||||
hf_auth_token: str = "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",
|
||||
max_num_tokens=150,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
falcon_mlir_path=None,
|
||||
falcon_vmfb_path=None,
|
||||
debug=False,
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
print("hf_model_path: ", self.hf_model_path)
|
||||
|
||||
if "180b" in self.model_name and hf_auth_token == None:
|
||||
raise ValueError(
|
||||
""" HF auth token required for falcon-180b. Pass it using
|
||||
--hf_auth_token flag. You can ask for the access to the model
|
||||
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
|
||||
)
|
||||
self.hf_auth_token = hf_auth_token
|
||||
self.max_padding_length = 100
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.falcon_vmfb_path = falcon_vmfb_path
|
||||
self.falcon_mlir_path = falcon_mlir_path
|
||||
self.debug = debug
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.src_model = self.get_src_model()
|
||||
self.shark_model = self.compile()
|
||||
|
||||
def get_tokenizer(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path,
|
||||
trust_remote_code=True,
|
||||
token=self.hf_auth_token,
|
||||
)
|
||||
tokenizer.padding_side = "left"
|
||||
tokenizer.pad_token_id = 11
|
||||
return tokenizer
|
||||
|
||||
def get_src_model(self):
|
||||
print("Loading src model: ", self.model_name)
|
||||
kwargs = {
|
||||
"torch_dtype": torch.float,
|
||||
"trust_remote_code": True,
|
||||
"token": self.hf_auth_token,
|
||||
}
|
||||
if self.precision == "int4":
|
||||
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
|
||||
kwargs["quantization_config"] = quantization_config
|
||||
kwargs["load_gptq_on_cpu"] = True
|
||||
kwargs["device_map"] = "cpu"
|
||||
falcon_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
)
|
||||
if self.precision == "int4":
|
||||
falcon_model = falcon_model.to(torch.float32)
|
||||
return falcon_model
|
||||
|
||||
def compile(self):
|
||||
if args.use_precompiled_model:
|
||||
if not self.falcon_vmfb_path.exists():
|
||||
# Downloading VMFB from shark_tank
|
||||
@@ -120,37 +706,37 @@ class Falcon(SharkLLMBase):
|
||||
if vmfb is not None:
|
||||
return vmfb
|
||||
|
||||
print(
|
||||
f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}. Trying to work with"
|
||||
f"[DEBUG] mlir path { self.falcon_mlir_path} {'exists' if self.falcon_mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
|
||||
if self.falcon_mlir_path.exists():
|
||||
print(f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}")
|
||||
with open(self.falcon_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
mlir_generated = False
|
||||
# Downloading MLIR from shark_tank
|
||||
download_public_file(
|
||||
"gs://shark_tank/falcon/"
|
||||
+ "falcon_"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "_"
|
||||
+ self.precision
|
||||
+ ".mlir",
|
||||
self.falcon_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
print(
|
||||
f"[DEBUG] mlir not found at {self.falcon_mlir_path.absolute()}"
|
||||
)
|
||||
if self.falcon_mlir_path.exists():
|
||||
with open(self.falcon_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
mlir_generated = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"MLIR not found at {self.falcon_mlir_path.absolute()}"
|
||||
" after downloading! Please check path and try again"
|
||||
if args.load_mlir_from_shark_tank:
|
||||
# Downloading MLIR from shark_tank
|
||||
print(f"[DEBUG] Trying to download mlir from shark_tank")
|
||||
download_public_file(
|
||||
"gs://shark_tank/falcon/"
|
||||
+ "falcon_"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "_"
|
||||
+ self.precision
|
||||
+ ".mlir",
|
||||
self.falcon_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if self.falcon_mlir_path.exists():
|
||||
print(
|
||||
f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}"
|
||||
)
|
||||
mlir_generated = True
|
||||
|
||||
if not mlir_generated:
|
||||
print(f"[DEBUG] generating MLIR locally")
|
||||
compilation_input_ids = torch.randint(
|
||||
low=1, high=10000, size=(1, 100)
|
||||
)
|
||||
@@ -167,9 +753,10 @@ class Falcon(SharkLLMBase):
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
falconCompileInput,
|
||||
is_f16=self.precision == "fp16",
|
||||
is_f16=self.precision in ["fp16", "int4"],
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
is_gptq=self.precision == "int4",
|
||||
)
|
||||
del model
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
@@ -189,35 +776,37 @@ class Falcon(SharkLLMBase):
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
del module
|
||||
|
||||
print(f"[DEBUG] writing mlir to file")
|
||||
with open(f"{self.model_name}.mlir", "wb") as f_:
|
||||
with redirect_stdout(f_):
|
||||
print(module.operation.get_asm())
|
||||
f_ = open(self.falcon_mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
print("Saved falcon mlir at ", str(self.falcon_mlir_path))
|
||||
f_.close()
|
||||
del bytecode
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=self.device, mlir_dialect="linalg"
|
||||
mlir_module=self.falcon_mlir_path,
|
||||
device=self.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
self.falcon_vmfb_path.parent.absolute(),
|
||||
self.falcon_vmfb_path.stem,
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
"--iree-spirv-index-bits=64",
|
||||
],
|
||||
]
|
||||
+ [
|
||||
"--iree-llvmcpu-use-fast-min-max-ops",
|
||||
]
|
||||
if self.precision == "int4"
|
||||
else [],
|
||||
debug=self.debug,
|
||||
)
|
||||
print("Saved falcon vmfb at ", str(path))
|
||||
shark_module.load_module(path)
|
||||
|
||||
return shark_module
|
||||
|
||||
def compile(self):
|
||||
falcon_shark_model = self.compile_falcon()
|
||||
return falcon_shark_model
|
||||
|
||||
def generate(self, prompt):
|
||||
model_inputs = self.tokenizer(
|
||||
prompt,
|
||||
@@ -345,7 +934,11 @@ class Falcon(SharkLLMBase):
|
||||
|
||||
all_text = prompt
|
||||
|
||||
start = time.time()
|
||||
count = 0
|
||||
for i in range(self.max_num_tokens - 1):
|
||||
count = count + 1
|
||||
|
||||
next_token = self.generate_new_token()
|
||||
new_word = self.tokenizer.decode(
|
||||
next_token.cpu().numpy(),
|
||||
@@ -372,6 +965,13 @@ class Falcon(SharkLLMBase):
|
||||
):
|
||||
break
|
||||
|
||||
end = time.time()
|
||||
print(
|
||||
"\n\nTime taken is {:.2f} seconds/token\n".format(
|
||||
(end - start) / count
|
||||
)
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
@@ -387,7 +987,7 @@ class Falcon(SharkLLMBase):
|
||||
(model_inputs["input_ids"], model_inputs["attention_mask"]),
|
||||
)
|
||||
)
|
||||
if self.precision == "fp16":
|
||||
if self.precision in ["fp16", "int4"]:
|
||||
outputs = outputs.to(dtype=torch.float32)
|
||||
next_token_logits = outputs
|
||||
|
||||
@@ -466,18 +1066,39 @@ if __name__ == "__main__":
|
||||
else Path(args.falcon_vmfb_path)
|
||||
)
|
||||
|
||||
falcon = Falcon(
|
||||
"falcon_" + args.falcon_variant_to_use,
|
||||
hf_model_path="tiiuae/falcon-"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "-instruct",
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
falcon_mlir_path=falcon_mlir_path,
|
||||
falcon_vmfb_path=falcon_vmfb_path,
|
||||
)
|
||||
if args.precision == "int4":
|
||||
if args.falcon_variant_to_use == "180b":
|
||||
hf_model_path_value = "TheBloke/Falcon-180B-Chat-GPTQ"
|
||||
else:
|
||||
hf_model_path_value = (
|
||||
"TheBloke/falcon-"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "-instruct-GPTQ"
|
||||
)
|
||||
else:
|
||||
if args.falcon_variant_to_use == "180b":
|
||||
hf_model_path_value = "tiiuae/falcon-180B-chat"
|
||||
else:
|
||||
hf_model_path_value = (
|
||||
"tiiuae/falcon-" + args.falcon_variant_to_use + "-instruct"
|
||||
)
|
||||
|
||||
import gc
|
||||
if not args.sharded:
|
||||
falcon = UnshardedFalcon(
|
||||
model_name="falcon_" + args.falcon_variant_to_use,
|
||||
hf_model_path=hf_model_path_value,
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
falcon_mlir_path=falcon_mlir_path,
|
||||
falcon_vmfb_path=falcon_vmfb_path,
|
||||
)
|
||||
else:
|
||||
falcon = ShardedFalcon(
|
||||
model_name="falcon_" + args.falcon_variant_to_use,
|
||||
hf_model_path=hf_model_path_value,
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
)
|
||||
|
||||
default_prompt_text = "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:"
|
||||
continue_execution = True
|
||||
@@ -497,7 +1118,11 @@ if __name__ == "__main__":
|
||||
prompt = input("Please enter the prompt text: ")
|
||||
print("\nPrompt Text: ", prompt)
|
||||
|
||||
res_str = falcon.generate(prompt)
|
||||
prompt_template = f"""A helpful assistant who helps the user with any questions asked.
|
||||
User: {prompt}
|
||||
Assistant:"""
|
||||
|
||||
res_str = falcon.generate(prompt_template)
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
print(
|
||||
|
||||
@@ -126,7 +126,7 @@ def is_url(input_url):
|
||||
import os
|
||||
import tempfile
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
import torch
|
||||
import torch_mlir
|
||||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||
@@ -136,7 +136,8 @@ from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
|
||||
# fmt: off
|
||||
def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
|
||||
if len(lhs) == 3 and len(rhs) == 2:
|
||||
return [lhs[0], lhs[1], rhs[0]]
|
||||
elif len(lhs) == 2 and len(rhs) == 2:
|
||||
@@ -145,20 +146,21 @@ def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rh
|
||||
raise ValueError("Input shapes not supported.")
|
||||
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
|
||||
def quant〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
|
||||
# output dtype is the dtype of the lhs float input
|
||||
lhs_rank, lhs_dtype = lhs_rank_dtype
|
||||
return lhs_dtype
|
||||
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
|
||||
def quant〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
|
||||
return
|
||||
|
||||
|
||||
brevitas_matmul_rhs_group_quant_library = [
|
||||
brevitas〇matmul_rhs_group_quant〡shape,
|
||||
brevitas〇matmul_rhs_group_quant〡dtype,
|
||||
brevitas〇matmul_rhs_group_quant〡has_value_semantics]
|
||||
quant〇matmul_rhs_group_quant〡shape,
|
||||
quant〇matmul_rhs_group_quant〡dtype,
|
||||
quant〇matmul_rhs_group_quant〡has_value_semantics]
|
||||
# fmt: on
|
||||
|
||||
|
||||
def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
|
||||
@@ -176,7 +178,7 @@ def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
|
||||
|
||||
|
||||
def compile_module(
|
||||
shark_module, extended_model_name, generate_vmfb, extra_args=[]
|
||||
shark_module, extended_model_name, generate_vmfb, extra_args=[], debug=False,
|
||||
):
|
||||
if generate_vmfb:
|
||||
vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
|
||||
@@ -188,7 +190,7 @@ def compile_module(
|
||||
"No vmfb found. Compiling and saving to {}".format(vmfb_path)
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
os.getcwd(), extended_model_name, extra_args
|
||||
os.getcwd(), extended_model_name, extra_args, debug=debug
|
||||
)
|
||||
shark_module.load_module(path, extra_args=extra_args)
|
||||
else:
|
||||
@@ -197,7 +199,7 @@ def compile_module(
|
||||
|
||||
|
||||
def compile_int_precision(
|
||||
model, inputs, precision, device, generate_vmfb, extended_model_name
|
||||
model, inputs, precision, device, generate_vmfb, extended_model_name, debug=False
|
||||
):
|
||||
torchscript_module = import_with_fx(
|
||||
model,
|
||||
@@ -209,7 +211,7 @@ def compile_int_precision(
|
||||
torchscript_module,
|
||||
inputs,
|
||||
output_type="torch",
|
||||
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
|
||||
backend_legal_ops=["quant.matmul_rhs_group_quant"],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
@@ -217,7 +219,7 @@ def compile_int_precision(
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
mlir_module,
|
||||
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
from contextlib import redirect_stdout
|
||||
@@ -233,6 +235,12 @@ def compile_int_precision(
|
||||
mlir_module = BytesIO(mlir_module)
|
||||
bytecode = mlir_module.read()
|
||||
print(f"Elided IR written for {extended_model_name}")
|
||||
bytecode = save_mlir(
|
||||
bytecode,
|
||||
model_name=extended_model_name,
|
||||
frontend="torch",
|
||||
dir=os.getcwd(),
|
||||
)
|
||||
return bytecode
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
|
||||
@@ -249,6 +257,7 @@ def compile_int_precision(
|
||||
extended_model_name=extended_model_name,
|
||||
generate_vmfb=generate_vmfb,
|
||||
extra_args=extra_args,
|
||||
debug=debug,
|
||||
),
|
||||
bytecode,
|
||||
)
|
||||
@@ -292,6 +301,7 @@ def shark_compile_through_fx_int(
|
||||
device,
|
||||
generate_or_load_vmfb,
|
||||
extended_model_name,
|
||||
debug,
|
||||
)
|
||||
extra_args = [
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
|
||||
@@ -32,11 +32,13 @@ class SharkStableLM(SharkLLMBase):
|
||||
max_num_tokens=512,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
debug="False",
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_sequence_len = 256
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.debug = debug
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.shark_model = self.compile()
|
||||
|
||||
@@ -111,7 +113,7 @@ class SharkStableLM(SharkLLMBase):
|
||||
shark_module.compile()
|
||||
|
||||
path = shark_module.save_module(
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem, debug=self.debug
|
||||
)
|
||||
print("Saved vmfb at ", str(path))
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from shark.shark_downloader import download_public_file
|
||||
|
||||
# expects a Path / str as arg
|
||||
# returns None if path not found or SharkInference module
|
||||
def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
|
||||
def get_vmfb_from_path(vmfb_path, device, mlir_dialect, device_id=None):
|
||||
if not isinstance(vmfb_path, Path):
|
||||
vmfb_path = Path(vmfb_path)
|
||||
|
||||
@@ -20,7 +20,7 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
|
||||
print("Loading vmfb from: ", vmfb_path)
|
||||
print("Device from get_vmfb_from_path - ", device)
|
||||
shark_module = SharkInference(
|
||||
None, device=device, mlir_dialect=mlir_dialect
|
||||
None, device=device, mlir_dialect=mlir_dialect, device_idx=device_id
|
||||
)
|
||||
shark_module.load_module(vmfb_path)
|
||||
print("Successfully loaded vmfb")
|
||||
@@ -28,7 +28,13 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
|
||||
|
||||
|
||||
def get_vmfb_from_config(
|
||||
shark_container, model, precision, device, vmfb_path, padding=None
|
||||
shark_container,
|
||||
model,
|
||||
precision,
|
||||
device,
|
||||
vmfb_path,
|
||||
padding=None,
|
||||
device_id=None,
|
||||
):
|
||||
vmfb_url = (
|
||||
f"gs://shark_tank/{shark_container}/{model}_{precision}_{device}"
|
||||
@@ -37,4 +43,6 @@ def get_vmfb_from_config(
|
||||
vmfb_url = vmfb_url + f"_{padding}"
|
||||
vmfb_url = vmfb_url + ".vmfb"
|
||||
download_public_file(vmfb_url, vmfb_path.absolute(), single_file=True)
|
||||
return get_vmfb_from_path(vmfb_path, device, "tm_tensor")
|
||||
return get_vmfb_from_path(
|
||||
vmfb_path, device, "tm_tensor", device_id=device_id
|
||||
)
|
||||
|
||||
@@ -7,16 +7,16 @@ Compile Commands FP32/FP16:
|
||||
|
||||
```shell
|
||||
Vulkan AMD:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux /path/to/input/mlir -o /path/to/output/vmfb
|
||||
|
||||
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
|
||||
# use –iree-input-type=auto or "mhlo_legacy" or "stablehlo" for TF models
|
||||
|
||||
CUDA NVIDIA:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda /path/to/input/mlir -o /path/to/output/vmfb
|
||||
|
||||
CPU:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu /path/to/input/mlir -o /path/to/output/vmfb
|
||||
```
|
||||
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models.cross_attention import LoRACrossAttnProcessor
|
||||
from diffusers.models.attention_processor import LoRAXFormersAttnProcessor
|
||||
|
||||
import torch_mlir
|
||||
from torch_mlir.dynamo import make_simple_dynamo_backend
|
||||
@@ -287,7 +287,7 @@ def lora_train(
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
|
||||
lora_attn_procs[name] = LoRACrossAttnProcessor(
|
||||
lora_attn_procs[name] = LoRAXFormersAttnProcessor(
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
|
||||
@@ -15,8 +15,8 @@ pathex = [
|
||||
|
||||
# datafiles for pyinstaller
|
||||
datas = []
|
||||
datas += collect_data_files("torch")
|
||||
datas += copy_metadata("torch")
|
||||
datas += copy_metadata("tokenizers")
|
||||
datas += copy_metadata("tqdm")
|
||||
datas += copy_metadata("regex")
|
||||
datas += copy_metadata("requests")
|
||||
@@ -30,20 +30,21 @@ datas += copy_metadata("safetensors")
|
||||
datas += copy_metadata("Pillow")
|
||||
datas += copy_metadata("sentencepiece")
|
||||
datas += copy_metadata("pyyaml")
|
||||
datas += copy_metadata("huggingface-hub")
|
||||
datas += collect_data_files("torch")
|
||||
datas += collect_data_files("tokenizers")
|
||||
datas += collect_data_files("tiktoken")
|
||||
datas += collect_data_files("accelerate")
|
||||
datas += collect_data_files("diffusers")
|
||||
datas += collect_data_files("transformers")
|
||||
datas += collect_data_files("pytorch_lightning")
|
||||
datas += collect_data_files("opencv_python")
|
||||
datas += collect_data_files("skimage")
|
||||
datas += collect_data_files("gradio")
|
||||
datas += collect_data_files("gradio_client")
|
||||
datas += collect_data_files("iree")
|
||||
datas += collect_data_files("google_cloud_storage")
|
||||
datas += collect_data_files("shark", include_py_files=True)
|
||||
datas += collect_data_files("timm", include_py_files=True)
|
||||
datas += collect_data_files("tqdm")
|
||||
datas += collect_data_files("tkinter")
|
||||
datas += collect_data_files("webview")
|
||||
datas += collect_data_files("sentencepiece")
|
||||
@@ -51,6 +52,8 @@ datas += collect_data_files("jsonschema")
|
||||
datas += collect_data_files("jsonschema_specifications")
|
||||
datas += collect_data_files("cpuinfo")
|
||||
datas += collect_data_files("langchain")
|
||||
datas += collect_data_files("cv2")
|
||||
datas += collect_data_files("einops")
|
||||
datas += [
|
||||
("src/utils/resources/prompts.json", "resources"),
|
||||
("src/utils/resources/model_db.json", "resources"),
|
||||
@@ -73,6 +76,13 @@ datas += [
|
||||
hiddenimports = ["shark", "shark.shark_inference", "apps"]
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
hiddenimports += [
|
||||
x for x in collect_submodules("transformers") if "tests" not in x
|
||||
x for x in collect_submodules("diffusers") if "tests" not in x
|
||||
]
|
||||
blacklist = ["tests", "convert"]
|
||||
hiddenimports += [
|
||||
x
|
||||
for x in collect_submodules("transformers")
|
||||
if not any(kw in x for kw in blacklist)
|
||||
]
|
||||
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
|
||||
hiddenimports += ["iree._runtime", "iree.compiler._mlir_libs._mlir.ir"]
|
||||
|
||||
@@ -8,6 +8,7 @@ import traceback
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import requests
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
compile_through_fx,
|
||||
get_opt_flags,
|
||||
@@ -16,6 +17,7 @@ from apps.stable_diffusion.src.utils import (
|
||||
preprocessCKPT,
|
||||
convert_original_vae,
|
||||
get_path_to_diffusers_checkpoint,
|
||||
get_civitai_checkpoint,
|
||||
fetch_and_update_base_model_id,
|
||||
get_path_stem,
|
||||
get_extended_name,
|
||||
@@ -94,21 +96,19 @@ class SharkifyStableDiffusionModel:
|
||||
self.height = height // 8
|
||||
self.width = width // 8
|
||||
self.batch_size = batch_size
|
||||
self.custom_weights = custom_weights
|
||||
self.custom_weights = custom_weights.strip()
|
||||
self.use_quantize = use_quantize
|
||||
if custom_weights != "":
|
||||
if "civitai" in custom_weights:
|
||||
weights_id = custom_weights.split("/")[-1]
|
||||
# TODO: use model name and identify file type by civitai rest api
|
||||
weights_path = (
|
||||
str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
|
||||
)
|
||||
if not os.path.isfile(weights_path):
|
||||
subprocess.run(
|
||||
["wget", custom_weights, "-O", weights_path]
|
||||
)
|
||||
if custom_weights.startswith("https://civitai.com/api/"):
|
||||
# download the checkpoint from civitai if we don't already have it
|
||||
weights_path = get_civitai_checkpoint(custom_weights)
|
||||
|
||||
# act as if we were given the local file as custom_weights originally
|
||||
custom_weights = get_path_to_diffusers_checkpoint(weights_path)
|
||||
self.custom_weights = weights_path
|
||||
|
||||
# needed to ensure webui sets the correct model name metadata
|
||||
args.ckpt_loc = weights_path
|
||||
else:
|
||||
assert custom_weights.lower().endswith(
|
||||
(".ckpt", ".safetensors")
|
||||
@@ -116,6 +116,7 @@ class SharkifyStableDiffusionModel:
|
||||
custom_weights = get_path_to_diffusers_checkpoint(
|
||||
custom_weights
|
||||
)
|
||||
|
||||
self.model_id = model_id if custom_weights == "" else custom_weights
|
||||
# TODO: remove the following line when stable-diffusion-2-1 works
|
||||
if self.model_id == "stabilityai/stable-diffusion-2-1":
|
||||
@@ -177,9 +178,11 @@ class SharkifyStableDiffusionModel:
|
||||
"unet",
|
||||
"unet512",
|
||||
"stencil_unet",
|
||||
"stencil_unet_512",
|
||||
"vae",
|
||||
"vae_encode",
|
||||
"stencil_adaptor",
|
||||
"stencil_adaptor_512",
|
||||
]
|
||||
index = 0
|
||||
for model in sub_model_list:
|
||||
@@ -339,7 +342,7 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
return shark_vae, vae_mlir
|
||||
|
||||
def get_controlled_unet(self):
|
||||
def get_controlled_unet(self, use_large=False):
|
||||
class ControlledUnetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -415,6 +418,16 @@ class SharkifyStableDiffusionModel:
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
|
||||
inputs = tuple(self.inputs["unet"])
|
||||
model_name = "stencil_unet"
|
||||
if use_large:
|
||||
pad = (0, 0) * (len(inputs[2].shape) - 2)
|
||||
pad = pad + (0, 512 - inputs[2].shape[1])
|
||||
inputs = (
|
||||
inputs[:2]
|
||||
+ (torch.nn.functional.pad(inputs[2], pad),)
|
||||
+ inputs[3:]
|
||||
)
|
||||
model_name = "stencil_unet_512"
|
||||
input_mask = [
|
||||
True,
|
||||
True,
|
||||
@@ -437,19 +450,19 @@ class SharkifyStableDiffusionModel:
|
||||
shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
|
||||
unet,
|
||||
inputs,
|
||||
extended_model_name=self.model_name["stencil_unet"],
|
||||
extended_model_name=self.model_name[model_name],
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
model_name="stencil_unet",
|
||||
model_name=model_name,
|
||||
precision=self.precision,
|
||||
return_mlir=self.return_mlir,
|
||||
)
|
||||
return shark_controlled_unet, controlled_unet_mlir
|
||||
|
||||
def get_control_net(self):
|
||||
def get_control_net(self, use_large=False):
|
||||
class StencilControlNetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self, model_id=self.use_stencil, low_cpu_mem_usage=False
|
||||
@@ -497,17 +510,34 @@ class SharkifyStableDiffusionModel:
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
|
||||
inputs = tuple(self.inputs["stencil_adaptor"])
|
||||
if use_large:
|
||||
pad = (0, 0) * (len(inputs[2].shape) - 2)
|
||||
pad = pad + (0, 512 - inputs[2].shape[1])
|
||||
inputs = (
|
||||
inputs[0],
|
||||
inputs[1],
|
||||
torch.nn.functional.pad(inputs[2], pad),
|
||||
inputs[3],
|
||||
)
|
||||
save_dir = os.path.join(
|
||||
self.sharktank_dir, self.model_name["stencil_adaptor_512"]
|
||||
)
|
||||
else:
|
||||
save_dir = os.path.join(
|
||||
self.sharktank_dir, self.model_name["stencil_adaptor"]
|
||||
)
|
||||
input_mask = [True, True, True, True]
|
||||
model_name = "stencil_adaptor" if use_large else "stencil_adaptor_512"
|
||||
shark_cnet, cnet_mlir = compile_through_fx(
|
||||
scnet,
|
||||
inputs,
|
||||
extended_model_name=self.model_name["stencil_adaptor"],
|
||||
extended_model_name=self.model_name[model_name],
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
model_name="stencil_adaptor",
|
||||
model_name=model_name,
|
||||
precision=self.precision,
|
||||
return_mlir=self.return_mlir,
|
||||
)
|
||||
@@ -681,8 +711,11 @@ class SharkifyStableDiffusionModel:
|
||||
return self.text_encoder(input)[0]
|
||||
|
||||
clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["clip"])
|
||||
save_dir = ""
|
||||
if self.debug:
|
||||
save_dir = os.path.join(
|
||||
self.sharktank_dir, self.model_name["clip"]
|
||||
)
|
||||
os.makedirs(
|
||||
save_dir,
|
||||
exist_ok=True,
|
||||
@@ -748,7 +781,7 @@ class SharkifyStableDiffusionModel:
|
||||
else:
|
||||
return self.get_unet(use_large=use_large)
|
||||
else:
|
||||
return self.get_controlled_unet()
|
||||
return self.get_controlled_unet(use_large=use_large)
|
||||
|
||||
def vae_encode(self):
|
||||
try:
|
||||
@@ -847,12 +880,14 @@ class SharkifyStableDiffusionModel:
|
||||
except Exception as e:
|
||||
sys.exit(e)
|
||||
|
||||
def controlnet(self):
|
||||
def controlnet(self, use_large=False):
|
||||
try:
|
||||
self.inputs["stencil_adaptor"] = self.get_input_info_for(
|
||||
base_models["stencil_adaptor"]
|
||||
)
|
||||
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net()
|
||||
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net(
|
||||
use_large=use_large
|
||||
)
|
||||
|
||||
check_compilation(compiled_stencil_adaptor, "Stencil")
|
||||
if self.return_mlir:
|
||||
|
||||
@@ -29,6 +29,10 @@ from apps.stable_diffusion.src.models import (
|
||||
SharkifyStableDiffusionModel,
|
||||
get_vae_encode,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
resamplers,
|
||||
resampler_list,
|
||||
)
|
||||
|
||||
|
||||
class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
@@ -84,13 +88,21 @@ class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
num_inference_steps,
|
||||
strength,
|
||||
dtype,
|
||||
resample_type,
|
||||
):
|
||||
# Pre process image -> get image encoded -> process latents
|
||||
|
||||
# TODO: process with variable HxW combos
|
||||
|
||||
# Pre process image
|
||||
image = image.resize((width, height))
|
||||
# Pre-process image
|
||||
resample_type = (
|
||||
resamplers[resample_type]
|
||||
if resample_type in resampler_list
|
||||
# Fallback to Lanczos
|
||||
else Image.Resampling.LANCZOS
|
||||
)
|
||||
|
||||
image = image.resize((width, height), resample=resample_type)
|
||||
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
|
||||
image_arr = image_arr / 255.0
|
||||
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
|
||||
@@ -147,6 +159,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
use_stencil,
|
||||
resample_type,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
@@ -186,6 +199,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
num_inference_steps=num_inference_steps,
|
||||
strength=strength,
|
||||
dtype=dtype,
|
||||
resample_type=resample_type,
|
||||
)
|
||||
|
||||
# Get Image latents
|
||||
|
||||
@@ -58,6 +58,7 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
self.controlnet = None
|
||||
self.controlnet_512 = None
|
||||
|
||||
def load_controlnet(self):
|
||||
if self.controlnet is not None:
|
||||
@@ -68,6 +69,15 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
del self.controlnet
|
||||
self.controlnet = None
|
||||
|
||||
def load_controlnet_512(self):
|
||||
if self.controlnet_512 is not None:
|
||||
return
|
||||
self.controlnet_512 = self.sd_model.controlnet(use_large=True)
|
||||
|
||||
def unload_controlnet_512(self):
|
||||
del self.controlnet_512
|
||||
self.controlnet_512 = None
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
@@ -111,8 +121,12 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
latent_history = [latents]
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
self.load_unet()
|
||||
self.load_controlnet()
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
self.load_unet()
|
||||
self.load_controlnet()
|
||||
else:
|
||||
self.load_unet_512()
|
||||
self.load_controlnet_512()
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
timestep = torch.tensor([t]).to(dtype)
|
||||
@@ -135,43 +149,82 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
).to(dtype)
|
||||
else:
|
||||
latent_model_input_1 = latent_model_input
|
||||
control = self.controlnet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
control = self.controlnet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
else:
|
||||
control = self.controlnet_512(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
timestep = timestep.detach().numpy()
|
||||
# Profiling Unet.
|
||||
profile_device = start_profiling(file_path="unet.rdc")
|
||||
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
control[0],
|
||||
control[1],
|
||||
control[2],
|
||||
control[3],
|
||||
control[4],
|
||||
control[5],
|
||||
control[6],
|
||||
control[7],
|
||||
control[8],
|
||||
control[9],
|
||||
control[10],
|
||||
control[11],
|
||||
control[12],
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
control[0],
|
||||
control[1],
|
||||
control[2],
|
||||
control[3],
|
||||
control[4],
|
||||
control[5],
|
||||
control[6],
|
||||
control[7],
|
||||
control[8],
|
||||
control[9],
|
||||
control[10],
|
||||
control[11],
|
||||
control[12],
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
else:
|
||||
print(self.unet_512)
|
||||
noise_pred = self.unet_512(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
control[0],
|
||||
control[1],
|
||||
control[2],
|
||||
control[3],
|
||||
control[4],
|
||||
control[5],
|
||||
control[6],
|
||||
control[7],
|
||||
control[8],
|
||||
control[9],
|
||||
control[10],
|
||||
control[11],
|
||||
control[12],
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
end_profiling(profile_device)
|
||||
|
||||
if cpu_scheduling:
|
||||
@@ -191,7 +244,9 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
self.unload_unet_512()
|
||||
self.unload_controlnet()
|
||||
self.unload_controlnet_512()
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
@@ -218,6 +273,7 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
use_stencil,
|
||||
resample_type,
|
||||
):
|
||||
# Control Embedding check & conversion
|
||||
# TODO: 1. Change `num_images_per_prompt`.
|
||||
|
||||
@@ -84,9 +84,6 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
# Disable bindings fusion to work with moltenVK.
|
||||
if sys.platform == "darwin":
|
||||
iree_flags.append("-iree-stream-fuse-binding=false")
|
||||
|
||||
def _import(self):
|
||||
scaling_model = ScalingModel()
|
||||
|
||||
@@ -41,3 +41,8 @@ from apps.stable_diffusion.src.utils.utils import (
|
||||
resize_stencil,
|
||||
_compile_module,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.civitai import get_civitai_checkpoint
|
||||
from apps.stable_diffusion.src.utils.resamplers import (
|
||||
resamplers,
|
||||
resampler_list,
|
||||
)
|
||||
|
||||
42
apps/stable_diffusion/src/utils/civitai.py
Normal file
42
apps/stable_diffusion/src/utils/civitai.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import re
|
||||
import requests
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_civitai_checkpoint(url: str):
|
||||
with requests.get(url, allow_redirects=True, stream=True) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
# civitai api returns the filename in the content disposition
|
||||
base_filename = re.findall(
|
||||
'"([^"]*)"', response.headers["Content-Disposition"]
|
||||
)[0]
|
||||
destination_path = (
|
||||
Path.cwd() / (args.ckpt_dir or "models") / base_filename
|
||||
)
|
||||
|
||||
# we don't have this model downloaded yet
|
||||
if not destination_path.is_file():
|
||||
print(
|
||||
f"downloading civitai model from {url} to {destination_path}"
|
||||
)
|
||||
|
||||
size = int(response.headers["content-length"], 0)
|
||||
progress_bar = tqdm(total=size, unit="iB", unit_scale=True)
|
||||
|
||||
with open(destination_path, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=65536):
|
||||
f.write(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
|
||||
progress_bar.close()
|
||||
|
||||
# we already have this model downloaded
|
||||
else:
|
||||
print(f"civitai model already downloaded to {destination_path}")
|
||||
|
||||
response.close()
|
||||
return destination_path.as_posix()
|
||||
12
apps/stable_diffusion/src/utils/resamplers.py
Normal file
12
apps/stable_diffusion/src/utils/resamplers.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import PIL.Image as Image
|
||||
|
||||
resamplers = {
|
||||
"Lanczos": Image.Resampling.LANCZOS,
|
||||
"Nearest Neighbor": Image.Resampling.NEAREST,
|
||||
"Bilinear": Image.Resampling.BILINEAR,
|
||||
"Bicubic": Image.Resampling.BICUBIC,
|
||||
"Hamming": Image.Resampling.HAMMING,
|
||||
"Box": Image.Resampling.BOX,
|
||||
}
|
||||
|
||||
resampler_list = resamplers.keys()
|
||||
@@ -11,12 +11,12 @@
|
||||
"untuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -28,7 +28,7 @@
|
||||
"specified_compilation_flags": {
|
||||
"cuda": [],
|
||||
"default_device": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
}
|
||||
},
|
||||
@@ -37,7 +37,7 @@
|
||||
"specified_compilation_flags": {
|
||||
"cuda": [],
|
||||
"default_device": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -45,12 +45,12 @@
|
||||
"untuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -109,7 +109,7 @@ def load_lower_configs(base_model_id=None):
|
||||
spec = spec.split("-")[0]
|
||||
|
||||
if args.annotation_model == "vae":
|
||||
if not spec or spec in ["rdna3", "sm_80"]:
|
||||
if not spec or spec in ["sm_80"]:
|
||||
config_name = (
|
||||
f"{args.annotation_model}_{args.precision}_{device}.json"
|
||||
)
|
||||
@@ -158,9 +158,9 @@ def load_lower_configs(base_model_id=None):
|
||||
f"{spec}.json"
|
||||
)
|
||||
|
||||
full_gs_url = config_bucket + config_name
|
||||
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
|
||||
print("Loading lowering config file from ", lowering_config_dir)
|
||||
full_gs_url = config_bucket + config_name
|
||||
download_public_file(full_gs_url, lowering_config_dir, True)
|
||||
return lowering_config_dir
|
||||
|
||||
@@ -203,8 +203,8 @@ def dump_after_mlir(input_mlir, use_winograd):
|
||||
if use_winograd:
|
||||
preprocess_flag = (
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module"
|
||||
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
|
||||
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
|
||||
"(func.func(iree-global-opt-detach-elementwise-from-named-ops,"
|
||||
"iree-global-opt-convert-1x1-filter-conv2d-to-matmul,"
|
||||
"iree-preprocessing-convert-conv2d-to-img2col,"
|
||||
"iree-preprocessing-pad-linalg-ops{pad-size=32},"
|
||||
"iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
@@ -212,8 +212,8 @@ def dump_after_mlir(input_mlir, use_winograd):
|
||||
else:
|
||||
preprocess_flag = (
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module"
|
||||
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
|
||||
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
|
||||
"(func.func(iree-global-opt-detach-elementwise-from-named-ops,"
|
||||
"iree-global-opt-convert-1x1-filter-conv2d-to-matmul,"
|
||||
"iree-preprocessing-convert-conv2d-to-img2col,"
|
||||
"iree-preprocessing-pad-linalg-ops{pad-size=32}))"
|
||||
)
|
||||
|
||||
@@ -2,6 +2,8 @@ import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from apps.stable_diffusion.src.utils.resamplers import resampler_list
|
||||
|
||||
|
||||
def path_expand(s):
|
||||
return Path(s).expanduser().resolve()
|
||||
@@ -132,6 +134,47 @@ p.add_argument(
|
||||
"img2img.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_hiresfix",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Use Hires Fix to do higher resolution images, while trying to "
|
||||
"avoid the issues that come with it. This is accomplished by first "
|
||||
"generating an image using txt2img, then running it through img2img.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hiresfix_height",
|
||||
type=int,
|
||||
default=768,
|
||||
choices=range(128, 769, 8),
|
||||
help="The height of the Hires Fix image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hiresfix_width",
|
||||
type=int,
|
||||
default=768,
|
||||
choices=range(128, 769, 8),
|
||||
help="The width of the Hires Fix image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hiresfix_strength",
|
||||
type=float,
|
||||
default=0.6,
|
||||
help="The denoising strength to apply for the Hires Fix.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--resample_type",
|
||||
type=str,
|
||||
default="Nearest Neighbor",
|
||||
choices=resampler_list,
|
||||
help="The resample type to use when resizing an image before being run "
|
||||
"through stable diffusion.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# Stable Diffusion Training Params
|
||||
##############################################################################
|
||||
@@ -202,28 +245,30 @@ p.add_argument(
|
||||
"--left",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend left for outpainting.",
|
||||
help="If extend left for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--right",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend right for outpainting.",
|
||||
help="If extend right for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--up",
|
||||
"--top",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend top for outpainting.",
|
||||
help="If extend top for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--down",
|
||||
"--bottom",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend bottom for outpainting.",
|
||||
help="If extend bottom for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -255,7 +300,7 @@ p.add_argument(
|
||||
|
||||
p.add_argument(
|
||||
"--import_mlir",
|
||||
default=False,
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Imports the model from torch module to shark_module otherwise "
|
||||
"downloads the model from shark_tank.",
|
||||
@@ -278,7 +323,7 @@ p.add_argument(
|
||||
|
||||
p.add_argument(
|
||||
"--use_tuned",
|
||||
default=True,
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Download and use the tuned version of the model if available.",
|
||||
)
|
||||
@@ -371,7 +416,7 @@ p.add_argument(
|
||||
|
||||
p.add_argument(
|
||||
"--use_stencil",
|
||||
choices=["canny", "openpose", "scribble"],
|
||||
choices=["canny", "openpose", "scribble", "zoedepth"],
|
||||
help="Enable the stencil feature.",
|
||||
)
|
||||
|
||||
@@ -407,6 +452,14 @@ p.add_argument(
|
||||
help="Specify your own huggingface authentication tokens for models like Llama2.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--device_allocator_heap_key",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify heap key for device caching allocator."
|
||||
"Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count"
|
||||
"Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
|
||||
)
|
||||
##############################################################################
|
||||
# IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
@@ -520,9 +573,17 @@ p.add_argument(
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--iree_constant_folding",
|
||||
"--compile_debug",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag to toggle debug assert/verify flags for imported IR in the"
|
||||
"iree-compiler. Default to false.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--iree_constant_folding",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Controls constant folding in iree-compile for all SD models.",
|
||||
)
|
||||
|
||||
@@ -574,6 +635,25 @@ p.add_argument(
|
||||
help="Flag for enabling rest API.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--api_accept_origin",
|
||||
action="append",
|
||||
type=str,
|
||||
help="An origin to be accepted by the REST api for Cross Origin"
|
||||
"Resource Sharing (CORS). Use multiple times for multiple origins, "
|
||||
'or use --api_accept_origin="*" to accept all origins. If no origins '
|
||||
"are set no CORS headers will be returned by the api. Use, for "
|
||||
"instance, if you need to access the REST api from Javascript running "
|
||||
"in a web browser.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--debug",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for enabling debugging log in WebUI.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_gallery",
|
||||
default=True,
|
||||
@@ -651,6 +731,18 @@ p.add_argument(
|
||||
help="Specifies whether the docuchat's web version is running or not.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# rocm Flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--iree_rocm_target_chip",
|
||||
type=str,
|
||||
default="",
|
||||
help="Add the rocm device architecture ex gfx1100, gfx90a, etc. Use `hipinfo` "
|
||||
"or `iree-run-module --dump_devices=rocm` or `hipinfo` to get desired arch name",
|
||||
)
|
||||
|
||||
args, unknown = p.parse_known_args()
|
||||
if args.import_debug:
|
||||
os.environ["IREE_SAVE_TEMPS"] = os.path.join(
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
from apps.stable_diffusion.src.utils.stencils.canny import CannyDetector
|
||||
from apps.stable_diffusion.src.utils.stencils.openpose import OpenposeDetector
|
||||
from apps.stable_diffusion.src.utils.stencils.zoe import ZoeDetector
|
||||
|
||||
@@ -4,6 +4,7 @@ import torch
|
||||
from apps.stable_diffusion.src.utils.stencils import (
|
||||
CannyDetector,
|
||||
OpenposeDetector,
|
||||
ZoeDetector,
|
||||
)
|
||||
|
||||
stencil = {}
|
||||
@@ -117,6 +118,9 @@ def controlnet_hint_conversion(
|
||||
case "scribble":
|
||||
print("Working with scribble")
|
||||
controlnet_hint = hint_scribble(image)
|
||||
case "zoedepth":
|
||||
print("Working with ZoeDepth")
|
||||
controlnet_hint = hint_zoedepth(image)
|
||||
case _:
|
||||
return None
|
||||
controlnet_hint = controlnet_hint_shaping(
|
||||
@@ -127,7 +131,7 @@ def controlnet_hint_conversion(
|
||||
|
||||
stencil_to_model_id_map = {
|
||||
"canny": "lllyasviel/control_v11p_sd15_canny",
|
||||
"depth": "lllyasviel/control_v11p_sd15_depth",
|
||||
"zoedepth": "lllyasviel/control_v11f1p_sd15_depth",
|
||||
"hed": "lllyasviel/sd-controlnet-hed",
|
||||
"mlsd": "lllyasviel/control_v11p_sd15_mlsd",
|
||||
"normal": "lllyasviel/control_v11p_sd15_normalbae",
|
||||
@@ -184,3 +188,16 @@ def hint_scribble(image: Image.Image):
|
||||
detected_map = np.zeros_like(input_image, dtype=np.uint8)
|
||||
detected_map[np.min(input_image, axis=2) < 127] = 255
|
||||
return detected_map
|
||||
|
||||
|
||||
# Stencil 4. Depth (Only Zoe Preprocessing)
|
||||
def hint_zoedepth(image: Image.Image):
|
||||
with torch.no_grad():
|
||||
input_image = np.array(image)
|
||||
|
||||
if not "depth" in stencil:
|
||||
stencil["depth"] = ZoeDetector()
|
||||
|
||||
detected_map = stencil["depth"](input_image)
|
||||
detected_map = HWC3(detected_map)
|
||||
return detected_map
|
||||
|
||||
58
apps/stable_diffusion/src/utils/stencils/zoe/__init__.py
Normal file
58
apps/stable_diffusion/src/utils/stencils/zoe/__init__.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from pathlib import Path
|
||||
import requests
|
||||
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
remote_model_path = (
|
||||
"https://huggingface.co/lllyasviel/Annotators/resolve/main/ZoeD_M12_N.pt"
|
||||
)
|
||||
|
||||
|
||||
class ZoeDetector:
|
||||
def __init__(self):
|
||||
cwd = Path.cwd()
|
||||
ckpt_path = Path(cwd, "stencil_annotator")
|
||||
ckpt_path.mkdir(parents=True, exist_ok=True)
|
||||
modelpath = ckpt_path / "ZoeD_M12_N.pt"
|
||||
|
||||
with requests.get(remote_model_path, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
with open(modelpath, "wb") as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
model = torch.hub.load(
|
||||
"monorimet/ZoeDepth:torch_update",
|
||||
"ZoeD_N",
|
||||
pretrained=False,
|
||||
force_reload=False,
|
||||
)
|
||||
model.load_state_dict(
|
||||
torch.load(modelpath, map_location=model.device)["model"]
|
||||
)
|
||||
model.eval()
|
||||
self.model = model
|
||||
|
||||
def __call__(self, input_image):
|
||||
assert input_image.ndim == 3
|
||||
image_depth = input_image
|
||||
with torch.no_grad():
|
||||
image_depth = torch.from_numpy(image_depth).float()
|
||||
image_depth = image_depth / 255.0
|
||||
image_depth = rearrange(image_depth, "h w c -> 1 c h w")
|
||||
depth = self.model.infer(image_depth)
|
||||
|
||||
depth = depth[0, 0].cpu().numpy()
|
||||
|
||||
vmin = np.percentile(depth, 2)
|
||||
vmax = np.percentile(depth, 85)
|
||||
|
||||
depth -= vmin
|
||||
depth /= vmax - vmin
|
||||
depth = 1.0 - depth
|
||||
depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8)
|
||||
|
||||
return depth_image
|
||||
@@ -18,14 +18,14 @@ import tempfile
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
set_iree_vulkan_runtime_flags,
|
||||
get_vulkan_target_triple,
|
||||
get_iree_vulkan_runtime_flags,
|
||||
)
|
||||
from shark.iree_utils.metal_utils import get_metal_target_triple
|
||||
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
|
||||
from shark.iree_utils.gpu_utils import get_cuda_sm_cc, get_iree_rocm_args
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
from apps.stable_diffusion.src.utils.resources import opt_flags
|
||||
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
|
||||
@@ -78,7 +78,7 @@ def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
)
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
os.getcwd(), model_name, extra_args
|
||||
os.getcwd(), model_name, extra_args, debug=args.compile_debug
|
||||
)
|
||||
shark_module.load_module(path, extra_args=extra_args)
|
||||
else:
|
||||
@@ -154,8 +154,8 @@ def compile_through_fx(
|
||||
f16_input_mask=f16_input_mask,
|
||||
debug=debug,
|
||||
model_name=extended_model_name,
|
||||
save_dir=save_dir,
|
||||
)
|
||||
|
||||
if use_tuned:
|
||||
if "vae" in extended_model_name.split("_")[0]:
|
||||
args.annotation_model = "vae"
|
||||
@@ -168,6 +168,14 @@ def compile_through_fx(
|
||||
mlir_module, extended_model_name, base_model_id
|
||||
)
|
||||
|
||||
if not os.path.isdir(save_dir):
|
||||
save_dir = ""
|
||||
|
||||
mlir_module = save_mlir(
|
||||
mlir_module,
|
||||
model_name=extended_model_name,
|
||||
dir=save_dir,
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device=args.device if device is None else device,
|
||||
@@ -179,17 +187,22 @@ def compile_through_fx(
|
||||
mlir_module,
|
||||
)
|
||||
|
||||
del mlir_module
|
||||
gc.collect()
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
|
||||
# TODO: This function should be device-agnostic and piped properly
|
||||
# to general runtime driver init.
|
||||
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
|
||||
if args.enable_rgp:
|
||||
vulkan_runtime_flags += [
|
||||
f"--enable_rgp=true",
|
||||
f"--vulkan_debug_utils=true",
|
||||
]
|
||||
if args.device_allocator_heap_key:
|
||||
vulkan_runtime_flags += [
|
||||
f"--device_allocator=caching:device_local={args.device_allocator_heap_key}",
|
||||
]
|
||||
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
|
||||
@@ -464,18 +477,38 @@ def get_available_devices():
|
||||
f"{device_name} => {driver_name.replace('local', 'cpu')}"
|
||||
)
|
||||
else:
|
||||
device_list.append(f"{device_name} => {driver_name}://{i}")
|
||||
# for drivers with single devices
|
||||
# let the default device be selected without any indexing
|
||||
if len(device_list_dict) == 1:
|
||||
device_list.append(f"{device_name} => {driver_name}")
|
||||
else:
|
||||
device_list.append(
|
||||
f"{device_name} => {driver_name}://{i}"
|
||||
)
|
||||
return device_list
|
||||
|
||||
set_iree_runtime_flags()
|
||||
|
||||
available_devices = []
|
||||
vulkan_devices = get_devices_by_name("vulkan")
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
get_all_vulkan_devices,
|
||||
)
|
||||
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
vulkan_devices = []
|
||||
id = 0
|
||||
for device in vulkaninfo_list:
|
||||
vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
|
||||
id += 1
|
||||
if id != 0:
|
||||
print(f"vulkan devices are available.")
|
||||
available_devices.extend(vulkan_devices)
|
||||
metal_devices = get_devices_by_name("metal")
|
||||
available_devices.extend(metal_devices)
|
||||
cuda_devices = get_devices_by_name("cuda")
|
||||
available_devices.extend(cuda_devices)
|
||||
rocm_devices = get_devices_by_name("rocm")
|
||||
available_devices.extend(rocm_devices)
|
||||
cpu_device = get_devices_by_name("cpu-sync")
|
||||
available_devices.extend(cpu_device)
|
||||
cpu_device = get_devices_by_name("cpu-task")
|
||||
@@ -499,17 +532,16 @@ def get_opt_flags(model, precision="fp16"):
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
|
||||
if "rocm" in args.device:
|
||||
rocm_args = get_iree_rocm_args()
|
||||
iree_flags.extend(rocm_args)
|
||||
print(iree_flags)
|
||||
if args.iree_constant_folding == False:
|
||||
iree_flags.append("--iree-opt-const-expr-hoisting=False")
|
||||
iree_flags.append(
|
||||
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
|
||||
)
|
||||
|
||||
# Disable bindings fusion to work with moltenVK.
|
||||
if sys.platform == "darwin":
|
||||
iree_flags.append("-iree-stream-fuse-binding=false")
|
||||
|
||||
if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
|
||||
iree_flags += opt_flags[model][is_tuned][precision][
|
||||
"default_compilation_flags"
|
||||
@@ -572,7 +604,7 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
|
||||
)
|
||||
num_in_channels = 9 if is_inpaint else 4
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path=custom_weights,
|
||||
checkpoint_path_or_dict=custom_weights,
|
||||
extract_ema=extract_ema,
|
||||
from_safetensors=from_safetensors,
|
||||
num_in_channels=num_in_channels,
|
||||
@@ -779,11 +811,12 @@ def batch_seeds(
|
||||
seeds = seeds[:batch_count] + [-1] * (batch_count - len(seeds))
|
||||
|
||||
if repeatable:
|
||||
# set seed for the rng based on what we have so far
|
||||
saved_random_state = random_getstate()
|
||||
if all(seed < 0 for seed in seeds):
|
||||
seeds[0] = sanitize_seed(seeds[0])
|
||||
seed_random(str(seeds))
|
||||
|
||||
# set seed for the rng based on what we have so far
|
||||
saved_random_state = random_getstate()
|
||||
seed_random(str([n for n in seeds if n > -1]))
|
||||
|
||||
# generate any seeds that are unspecified
|
||||
seeds = [sanitize_seed(seed) for seed in seeds]
|
||||
@@ -822,6 +855,8 @@ def clear_all():
|
||||
elif os.name == "unix":
|
||||
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
|
||||
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
|
||||
if args.local_tank_cache != "":
|
||||
shutil.rmtree(args.local_tank_cache)
|
||||
|
||||
|
||||
def get_generated_imgs_path() -> Path:
|
||||
@@ -867,6 +902,13 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
pngInfo = PngImagePlugin.PngInfo()
|
||||
|
||||
if args.write_metadata_to_png:
|
||||
# Using a conditional expression caused problems, so setting a new
|
||||
# variable for now.
|
||||
if args.use_hiresfix:
|
||||
png_size_text = f"{args.hiresfix_width}x{args.hiresfix_height}"
|
||||
else:
|
||||
png_size_text = f"{args.width}x{args.height}"
|
||||
|
||||
pngInfo.add_text(
|
||||
"parameters",
|
||||
f"{args.prompts[0]}"
|
||||
@@ -875,7 +917,7 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
f"Sampler: {args.scheduler}, "
|
||||
f"CFG scale: {args.guidance_scale}, "
|
||||
f"Seed: {img_seed},"
|
||||
f"Size: {args.width}x{args.height}, "
|
||||
f"Size: {png_size_text}, "
|
||||
f"Model: {img_model}, "
|
||||
f"VAE: {img_vae}, "
|
||||
f"LoRA: {img_lora}",
|
||||
@@ -902,8 +944,10 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
"CFG_SCALE": args.guidance_scale,
|
||||
"PRECISION": args.precision,
|
||||
"STEPS": args.steps,
|
||||
"HEIGHT": args.height,
|
||||
"WIDTH": args.width,
|
||||
"HEIGHT": args.height
|
||||
if not args.use_hiresfix
|
||||
else args.hiresfix_height,
|
||||
"WIDTH": args.width if not args.use_hiresfix else args.hiresfix_width,
|
||||
"MAX_LENGTH": args.max_length,
|
||||
"OUTPUT": out_img_path,
|
||||
"VAE": img_vae,
|
||||
@@ -941,6 +985,10 @@ def get_generation_text_info(seeds, device):
|
||||
)
|
||||
text_output += (
|
||||
f"\nsize={args.height}x{args.width}, "
|
||||
if not args.use_hiresfix
|
||||
else f"\nsize={args.hiresfix_height}x{args.hiresfix_width}, "
|
||||
)
|
||||
text_output += (
|
||||
f"batch_count={args.batch_count}, "
|
||||
f"batch_size={args.batch_size}, "
|
||||
f"max_length={args.max_length}"
|
||||
|
||||
1
apps/stable_diffusion/web/api/__init__.py
Normal file
1
apps/stable_diffusion/web/api/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from apps.stable_diffusion.web.api.sdapi_v1 import sdapi
|
||||
579
apps/stable_diffusion/web/api/sdapi_v1.py
Normal file
579
apps/stable_diffusion/web/api/sdapi_v1.py
Normal file
@@ -0,0 +1,579 @@
|
||||
import os
|
||||
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel, Field, conlist, model_validator
|
||||
|
||||
from apps.stable_diffusion.web.api.utils import (
|
||||
frozen_args,
|
||||
sampler_aliases,
|
||||
encode_pil_to_base64,
|
||||
decode_base64_to_image,
|
||||
get_model_from_request,
|
||||
get_scheduler_from_request,
|
||||
get_lora_params,
|
||||
get_device,
|
||||
GenerationInputData,
|
||||
GenerationResponseData,
|
||||
)
|
||||
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_files,
|
||||
get_custom_model_pathfile,
|
||||
predefined_models,
|
||||
predefined_paint_models,
|
||||
predefined_upscaler_models,
|
||||
scheduler_list,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.txt2img_ui import txt2img_inf
|
||||
from apps.stable_diffusion.web.ui.img2img_ui import img2img_inf
|
||||
from apps.stable_diffusion.web.ui.inpaint_ui import inpaint_inf
|
||||
from apps.stable_diffusion.web.ui.outpaint_ui import outpaint_inf
|
||||
from apps.stable_diffusion.web.ui.upscaler_ui import upscaler_inf
|
||||
|
||||
sdapi = FastAPI()
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/sd-models (lists available models)
|
||||
class AppParam(str, Enum):
|
||||
txt2img = "txt2img"
|
||||
img2img = "img2img"
|
||||
inpaint = "inpaint"
|
||||
outpaint = "outpaint"
|
||||
upscaler = "upscaler"
|
||||
|
||||
|
||||
@sdapi.get(
|
||||
"/v1/sd-models",
|
||||
summary="lists available models",
|
||||
description=(
|
||||
"This is all the models that this server currently knows about.\n "
|
||||
"Models listed may still have a compilation and build pending that "
|
||||
"will be triggered the first time they are used."
|
||||
),
|
||||
)
|
||||
def sd_models_api(app: AppParam = frozen_args.app):
|
||||
match app:
|
||||
case "inpaint" | "outpaint":
|
||||
checkpoint_type = "inpainting"
|
||||
predefined = predefined_paint_models
|
||||
case "upscaler":
|
||||
checkpoint_type = "upscaler"
|
||||
predefined = predefined_upscaler_models
|
||||
case _:
|
||||
checkpoint_type = ""
|
||||
predefined = predefined_models
|
||||
|
||||
return [
|
||||
{
|
||||
"title": model_file,
|
||||
"model_name": model_file,
|
||||
"hash": None,
|
||||
"sha256": None,
|
||||
"filename": get_custom_model_pathfile(model_file),
|
||||
"config": None,
|
||||
}
|
||||
for model_file in get_custom_model_files(
|
||||
custom_checkpoint_type=checkpoint_type
|
||||
)
|
||||
] + [
|
||||
{
|
||||
"title": model,
|
||||
"model_name": model,
|
||||
"hash": None,
|
||||
"sha256": None,
|
||||
"filename": None,
|
||||
"config": None,
|
||||
}
|
||||
for model in predefined
|
||||
]
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/samplers (lists schedulers)
|
||||
@sdapi.get(
|
||||
"/v1/samplers",
|
||||
summary="lists available schedulers/samplers",
|
||||
description=(
|
||||
"These are all the Schedulers defined and available. Not "
|
||||
"every scheduler is compatible with all apis. Aliases are "
|
||||
"equivalent samplers in A1111 if they are known."
|
||||
),
|
||||
)
|
||||
def sd_samplers_api():
|
||||
reverse_sampler_aliases = defaultdict(list)
|
||||
for key, value in sampler_aliases.items():
|
||||
reverse_sampler_aliases[value].append(key)
|
||||
|
||||
return (
|
||||
{
|
||||
"name": scheduler,
|
||||
"aliases": reverse_sampler_aliases.get(scheduler, []),
|
||||
"options": {},
|
||||
}
|
||||
for scheduler in scheduler_list
|
||||
)
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/options (lists application level options)
|
||||
@sdapi.get(
|
||||
"/v1/options",
|
||||
summary="lists current settings of application level options",
|
||||
description=(
|
||||
"A subset of the command line arguments set at startup renamed "
|
||||
"to correspond to the A1111 naming. Only a small subset of A1111 "
|
||||
"options are returned."
|
||||
),
|
||||
)
|
||||
def options_api():
|
||||
# This is mostly just enough to support what Koboldcpp wants, with a
|
||||
# few other things that seemed obvious
|
||||
return {
|
||||
"samples_save": True,
|
||||
"samples_format": frozen_args.output_img_format,
|
||||
"sd_model_checkpoint": os.path.basename(frozen_args.ckpt_loc)
|
||||
if frozen_args.ckpt_loc
|
||||
else frozen_args.hf_model_id,
|
||||
"sd_lora": frozen_args.use_lora,
|
||||
"sd_vae": frozen_args.custom_vae or "Automatic",
|
||||
"enable_pnginfo": frozen_args.write_metadata_to_png,
|
||||
}
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/cmd-flags (lists command line argument settings)
|
||||
@sdapi.get(
|
||||
"/v1/cmd-flags",
|
||||
summary="lists the command line arguments value that were set on startup.",
|
||||
)
|
||||
def cmd_flags_api():
|
||||
return vars(frozen_args)
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/txt2img (Text to image)
|
||||
class ModelOverrideSettings(BaseModel):
|
||||
sd_model_checkpoint: str = get_model_from_request(
|
||||
fallback_model="stabilityai/stable-diffusion-2-1-base"
|
||||
)
|
||||
|
||||
|
||||
class Txt2ImgInputData(GenerationInputData):
|
||||
enable_hr: bool = frozen_args.use_hiresfix
|
||||
hr_resize_y: int = Field(
|
||||
default=frozen_args.hiresfix_height, ge=128, le=768, multiple_of=8
|
||||
)
|
||||
hr_resize_x: int = Field(
|
||||
default=frozen_args.hiresfix_width, ge=128, le=768, multiple_of=8
|
||||
)
|
||||
override_settings: ModelOverrideSettings = None
|
||||
|
||||
|
||||
@sdapi.post(
|
||||
"/v1/txt2img",
|
||||
summary="Does text to image generation",
|
||||
response_model=GenerationResponseData,
|
||||
)
|
||||
def txt2img_api(InputData: Txt2ImgInputData):
|
||||
model_id = get_model_from_request(
|
||||
InputData,
|
||||
fallback_model="stabilityai/stable-diffusion-2-1-base",
|
||||
)
|
||||
scheduler = get_scheduler_from_request(
|
||||
InputData, "txt2img_hires" if InputData.enable_hr else "txt2img"
|
||||
)
|
||||
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
|
||||
|
||||
print(
|
||||
f"Prompt: {InputData.prompt}, "
|
||||
f"Negative Prompt: {InputData.negative_prompt}, "
|
||||
f"Seed: {InputData.seed},"
|
||||
f"Model: {model_id}, "
|
||||
f"Scheduler: {scheduler}. "
|
||||
)
|
||||
|
||||
res = txt2img_inf(
|
||||
InputData.prompt,
|
||||
InputData.negative_prompt,
|
||||
InputData.height,
|
||||
InputData.width,
|
||||
InputData.steps,
|
||||
InputData.cfg_scale,
|
||||
InputData.seed,
|
||||
batch_count=InputData.n_iter,
|
||||
batch_size=1,
|
||||
scheduler=scheduler,
|
||||
model_id=model_id,
|
||||
custom_vae=frozen_args.custom_vae or "None",
|
||||
precision="fp16",
|
||||
device=get_device(frozen_args.device),
|
||||
max_length=frozen_args.max_length,
|
||||
save_metadata_to_json=frozen_args.save_metadata_to_json,
|
||||
save_metadata_to_png=frozen_args.write_metadata_to_png,
|
||||
lora_weights=lora_weights,
|
||||
lora_hf_id=lora_hf_id,
|
||||
ondemand=frozen_args.ondemand,
|
||||
repeatable_seeds=False,
|
||||
use_hiresfix=InputData.enable_hr,
|
||||
hiresfix_height=InputData.hr_resize_y,
|
||||
hiresfix_width=InputData.hr_resize_x,
|
||||
hiresfix_strength=frozen_args.hiresfix_strength,
|
||||
resample_type=frozen_args.resample_type,
|
||||
)
|
||||
|
||||
# Since we're not streaming we just want the last generator result
|
||||
for items_so_far in res:
|
||||
items = items_so_far
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(items[0]),
|
||||
"parameters": {},
|
||||
"info": items[1],
|
||||
}
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/img2img (Image to image)
|
||||
class StencilParam(str, Enum):
|
||||
canny = "canny"
|
||||
openpose = "openpose"
|
||||
scribble = "scribble"
|
||||
zoedepth = "zoedepth"
|
||||
|
||||
|
||||
class Img2ImgInputData(GenerationInputData):
|
||||
init_images: conlist(str, min_length=1, max_length=2)
|
||||
denoising_strength: float = frozen_args.strength
|
||||
use_stencil: StencilParam = frozen_args.use_stencil
|
||||
override_settings: ModelOverrideSettings = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_image_supplied_for_scribble_stencil(self) -> "Img2ImgInputData":
|
||||
if (
|
||||
self.use_stencil == StencilParam.scribble
|
||||
and len(self.init_images) < 2
|
||||
):
|
||||
raise ValueError(
|
||||
"a second image must be supplied for the controlnet:scribble stencil"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
@sdapi.post(
|
||||
"/v1/img2img",
|
||||
summary="Does image to image generation",
|
||||
response_model=GenerationResponseData,
|
||||
)
|
||||
def img2img_api(
|
||||
InputData: Img2ImgInputData,
|
||||
):
|
||||
model_id = get_model_from_request(
|
||||
InputData,
|
||||
fallback_model="stabilityai/stable-diffusion-2-1-base",
|
||||
)
|
||||
scheduler = get_scheduler_from_request(InputData, "img2img")
|
||||
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
|
||||
|
||||
init_image = decode_base64_to_image(InputData.init_images[0])
|
||||
mask_image = (
|
||||
decode_base64_to_image(InputData.init_images[1])
|
||||
if len(InputData.init_images) > 1
|
||||
else None
|
||||
)
|
||||
|
||||
print(
|
||||
f"Prompt: {InputData.prompt}, "
|
||||
f"Negative Prompt: {InputData.negative_prompt}, "
|
||||
f"Seed: {InputData.seed}, "
|
||||
f"Model: {model_id}, "
|
||||
f"Scheduler: {scheduler}."
|
||||
)
|
||||
|
||||
res = img2img_inf(
|
||||
InputData.prompt,
|
||||
InputData.negative_prompt,
|
||||
{"image": init_image, "mask": mask_image},
|
||||
InputData.height,
|
||||
InputData.width,
|
||||
InputData.steps,
|
||||
InputData.denoising_strength,
|
||||
InputData.cfg_scale,
|
||||
InputData.seed,
|
||||
batch_count=InputData.n_iter,
|
||||
batch_size=1,
|
||||
scheduler=scheduler,
|
||||
model_id=model_id,
|
||||
custom_vae=frozen_args.custom_vae or "None",
|
||||
precision="fp16",
|
||||
device=get_device(frozen_args.device),
|
||||
max_length=frozen_args.max_length,
|
||||
use_stencil=InputData.use_stencil,
|
||||
save_metadata_to_json=frozen_args.save_metadata_to_json,
|
||||
save_metadata_to_png=frozen_args.write_metadata_to_png,
|
||||
lora_weights=lora_weights,
|
||||
lora_hf_id=lora_hf_id,
|
||||
ondemand=frozen_args.ondemand,
|
||||
repeatable_seeds=False,
|
||||
resample_type=frozen_args.resample_type,
|
||||
)
|
||||
|
||||
# Since we're not streaming we just want the last generator result
|
||||
for items_so_far in res:
|
||||
items = items_so_far
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(items[0]),
|
||||
"parameters": {},
|
||||
"info": items[1],
|
||||
}
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/inpaint (Inpainting)
|
||||
class PaintModelOverideSettings(BaseModel):
|
||||
sd_model_checkpoint: str = get_model_from_request(
|
||||
checkpoint_type="inpainting",
|
||||
fallback_model="stabilityai/stable-diffusion-2-inpainting",
|
||||
)
|
||||
|
||||
|
||||
class InpaintInputData(GenerationInputData):
|
||||
image: str = Field(description="Base64 encoded input image")
|
||||
mask: str = Field(description="Base64 encoded mask image")
|
||||
is_full_res: bool = False # Is this setting backwards in the UI?
|
||||
full_res_padding: int = Field(default=32, ge=0, le=256, multiple_of=4)
|
||||
denoising_strength: float = frozen_args.strength
|
||||
use_stencil: StencilParam = frozen_args.use_stencil
|
||||
override_settings: PaintModelOverideSettings = None
|
||||
|
||||
|
||||
@sdapi.post(
|
||||
"/v1/inpaint",
|
||||
summary="Does inpainting generation on an image",
|
||||
response_model=GenerationResponseData,
|
||||
)
|
||||
def inpaint_api(
|
||||
InputData: InpaintInputData,
|
||||
):
|
||||
model_id = get_model_from_request(
|
||||
InputData,
|
||||
checkpoint_type="inpainting",
|
||||
fallback_model="stabilityai/stable-diffusion-2-inpainting",
|
||||
)
|
||||
scheduler = get_scheduler_from_request(InputData, "inpaint")
|
||||
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
|
||||
|
||||
init_image = decode_base64_to_image(InputData.image)
|
||||
mask = decode_base64_to_image(InputData.mask)
|
||||
|
||||
print(
|
||||
f"Prompt: {InputData.prompt}, "
|
||||
f'Negative Prompt: {InputData.negative_prompt}", '
|
||||
f'Seed: {InputData.seed}", '
|
||||
f"Model: {model_id}, "
|
||||
f"Scheduler: {scheduler}."
|
||||
)
|
||||
|
||||
res = inpaint_inf(
|
||||
InputData.prompt,
|
||||
InputData.negative_prompt,
|
||||
{"image": init_image, "mask": mask},
|
||||
InputData.height,
|
||||
InputData.width,
|
||||
InputData.is_full_res,
|
||||
InputData.full_res_padding,
|
||||
InputData.steps,
|
||||
InputData.cfg_scale,
|
||||
InputData.seed,
|
||||
batch_count=InputData.n_iter,
|
||||
batch_size=1,
|
||||
scheduler=scheduler,
|
||||
model_id=model_id,
|
||||
custom_vae=frozen_args.custom_vae or "None",
|
||||
precision="fp16",
|
||||
device=get_device(frozen_args.device),
|
||||
max_length=frozen_args.max_length,
|
||||
save_metadata_to_json=frozen_args.save_metadata_to_json,
|
||||
save_metadata_to_png=frozen_args.write_metadata_to_png,
|
||||
lora_weights=lora_weights,
|
||||
lora_hf_id=lora_hf_id,
|
||||
ondemand=frozen_args.ondemand,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
# Since we're not streaming we just want the last generator result
|
||||
for items_so_far in res:
|
||||
items = items_so_far
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(items[0]),
|
||||
"parameters": {},
|
||||
"info": items[1],
|
||||
}
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/outpaint (Outpainting)
|
||||
class DirectionParam(str, Enum):
|
||||
left = "left"
|
||||
right = "right"
|
||||
up = "up"
|
||||
down = "down"
|
||||
|
||||
|
||||
class OutpaintInputData(GenerationInputData):
|
||||
init_images: list[str]
|
||||
pixels: int = Field(
|
||||
default=frozen_args.pixels, ge=8, le=256, multiple_of=8
|
||||
)
|
||||
mask_blur: int = Field(default=frozen_args.mask_blur, ge=0, le=64)
|
||||
directions: set[DirectionParam] = [
|
||||
direction
|
||||
for direction in ["left", "right", "up", "down"]
|
||||
if vars(frozen_args)[direction]
|
||||
]
|
||||
noise_q: float = frozen_args.noise_q
|
||||
color_variation: float = frozen_args.color_variation
|
||||
override_settings: PaintModelOverideSettings = None
|
||||
|
||||
|
||||
@sdapi.post(
|
||||
"/v1/outpaint",
|
||||
summary="Does outpainting generation on an image",
|
||||
response_model=GenerationResponseData,
|
||||
)
|
||||
def outpaint_api(
|
||||
InputData: OutpaintInputData,
|
||||
):
|
||||
model_id = get_model_from_request(
|
||||
InputData,
|
||||
checkpoint_type="inpainting",
|
||||
fallback_model="stabilityai/stable-diffusion-2-inpainting",
|
||||
)
|
||||
scheduler = get_scheduler_from_request(InputData, "outpaint")
|
||||
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
|
||||
|
||||
init_image = decode_base64_to_image(InputData.init_images[0])
|
||||
|
||||
print(
|
||||
f"Prompt: {InputData.prompt}, "
|
||||
f"Negative Prompt: {InputData.negative_prompt}, "
|
||||
f"Seed: {InputData.seed}, "
|
||||
f"Model: {model_id}, "
|
||||
f"Scheduler: {scheduler}."
|
||||
)
|
||||
|
||||
res = outpaint_inf(
|
||||
InputData.prompt,
|
||||
InputData.negative_prompt,
|
||||
init_image,
|
||||
InputData.pixels,
|
||||
InputData.mask_blur,
|
||||
InputData.directions,
|
||||
InputData.noise_q,
|
||||
InputData.color_variation,
|
||||
InputData.height,
|
||||
InputData.width,
|
||||
InputData.steps,
|
||||
InputData.cfg_scale,
|
||||
InputData.seed,
|
||||
batch_count=InputData.n_iter,
|
||||
batch_size=1,
|
||||
scheduler=scheduler,
|
||||
model_id=model_id,
|
||||
custom_vae=frozen_args.custom_vae or "None",
|
||||
precision="fp16",
|
||||
device=get_device(frozen_args.device),
|
||||
max_length=frozen_args.max_length,
|
||||
save_metadata_to_json=frozen_args.save_metadata_to_json,
|
||||
save_metadata_to_png=frozen_args.write_metadata_to_png,
|
||||
lora_weights=lora_weights,
|
||||
lora_hf_id=lora_hf_id,
|
||||
ondemand=frozen_args.ondemand,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
# Since we're not streaming we just want the last generator result
|
||||
for items_so_far in res:
|
||||
items = items_so_far
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(items[0]),
|
||||
"parameters": {},
|
||||
"info": items[1],
|
||||
}
|
||||
|
||||
|
||||
# Rest API: /sdapi/v1/upscaler (Upscaling)
|
||||
class UpscalerModelOverideSettings(BaseModel):
|
||||
sd_model_checkpoint: str = get_model_from_request(
|
||||
checkpoint_type="upscaler",
|
||||
fallback_model="stabilityai/stable-diffusion-x4-upscaler",
|
||||
)
|
||||
|
||||
|
||||
class UpscalerInputData(GenerationInputData):
|
||||
init_images: list[str] = Field(
|
||||
description="Base64 encoded image to upscale"
|
||||
)
|
||||
noise_level: int = frozen_args.noise_level
|
||||
override_settings: UpscalerModelOverideSettings = None
|
||||
|
||||
|
||||
@sdapi.post(
|
||||
"/v1/upscaler",
|
||||
summary="Does image upscaling",
|
||||
response_model=GenerationResponseData,
|
||||
)
|
||||
def upscaler_api(
|
||||
InputData: UpscalerInputData,
|
||||
):
|
||||
model_id = get_model_from_request(
|
||||
InputData,
|
||||
checkpoint_type="upscaler",
|
||||
fallback_model="stabilityai/stable-diffusion-x4-upscaler",
|
||||
)
|
||||
scheduler = get_scheduler_from_request(InputData, "upscaler")
|
||||
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
|
||||
|
||||
init_image = decode_base64_to_image(InputData.init_images[0])
|
||||
|
||||
print(
|
||||
f"Prompt: {InputData.prompt}, "
|
||||
f"Negative Prompt: {InputData.negative_prompt}, "
|
||||
f"Seed: {InputData.seed}, "
|
||||
f"Model: {model_id}, "
|
||||
f"Scheduler: {scheduler}."
|
||||
)
|
||||
|
||||
res = upscaler_inf(
|
||||
InputData.prompt,
|
||||
InputData.negative_prompt,
|
||||
init_image,
|
||||
InputData.height,
|
||||
InputData.width,
|
||||
InputData.steps,
|
||||
InputData.noise_level,
|
||||
InputData.cfg_scale,
|
||||
InputData.seed,
|
||||
batch_count=InputData.n_iter,
|
||||
batch_size=1,
|
||||
scheduler=scheduler,
|
||||
model_id=model_id,
|
||||
custom_vae=frozen_args.custom_vae or "None",
|
||||
precision="fp16",
|
||||
device=get_device(frozen_args.device),
|
||||
max_length=frozen_args.max_length,
|
||||
save_metadata_to_json=frozen_args.save_metadata_to_json,
|
||||
save_metadata_to_png=frozen_args.write_metadata_to_png,
|
||||
lora_weights=lora_weights,
|
||||
lora_hf_id=lora_hf_id,
|
||||
ondemand=frozen_args.ondemand,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
# Since we're not streaming we just want the last generator result
|
||||
for items_so_far in res:
|
||||
items = items_so_far
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(items[0]),
|
||||
"parameters": {},
|
||||
"info": items[1],
|
||||
}
|
||||
211
apps/stable_diffusion/web/api/utils.py
Normal file
211
apps/stable_diffusion/web/api/utils.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import base64
|
||||
import pickle
|
||||
|
||||
from argparse import Namespace
|
||||
from fastapi.exceptions import HTTPException
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from apps.stable_diffusion.src import args
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
get_custom_model_files,
|
||||
predefined_models,
|
||||
predefined_paint_models,
|
||||
predefined_upscaler_models,
|
||||
scheduler_list,
|
||||
scheduler_list_cpu_only,
|
||||
)
|
||||
|
||||
|
||||
# Probably overly cautious, but try to ensure we only use the starting
|
||||
# args in each api call, as the code does `args.<whatever> = <changed_value>`
|
||||
# in lots of places and in testing, it seemed to me, these changes leaked
|
||||
# into subsequent api calls.
|
||||
|
||||
# Roundtripping through pickle for deepcopy, there is probably a better way
|
||||
frozen_args = Namespace(**(pickle.loads(pickle.dumps(vars(args)))))
|
||||
|
||||
# an attempt to map some of the A1111 sampler names to scheduler names
|
||||
# https://github.com/huggingface/diffusers/issues/4167 is where the
|
||||
# (not so obvious) ones come from
|
||||
sampler_aliases = {
|
||||
# a1111/onnx (these point to diffusers classes in A1111)
|
||||
"pndm": "PNDM",
|
||||
"heun": "HeunDiscrete",
|
||||
"ddim": "DDIM",
|
||||
"ddpm": "DDPM",
|
||||
"euler": "EulerDiscrete",
|
||||
"euler-ancestral": "EulerAncestralDiscrete",
|
||||
"dpm": "DPMSolverMultistep",
|
||||
# a1111/k_diffusion (the obvious ones)
|
||||
"Euler a": "EulerAncestralDiscrete",
|
||||
"Euler": "EulerDiscrete",
|
||||
"LMS": "LMSDiscrete",
|
||||
"Heun": "HeunDiscrete",
|
||||
# a1111/k_diffusion (not so obvious)
|
||||
"DPM++ 2M": "DPMSolverMultistep",
|
||||
"DPM++ 2M Karras": "DPMSolverMultistepKarras",
|
||||
"DPM++ 2M SDE": "DPMSolverMultistep++",
|
||||
"DPM++ 2M SDE Karras": "DPMSolverMultistepKarras++",
|
||||
"DPM2": "KDPM2Discrete",
|
||||
"DPM2 a": "KDPM2AncestralDiscrete",
|
||||
}
|
||||
|
||||
allowed_schedulers = {
|
||||
"txt2img": {
|
||||
"schedulers": scheduler_list,
|
||||
"fallback": "SharkEulerDiscrete",
|
||||
},
|
||||
"txt2img_hires": {
|
||||
"schedulers": scheduler_list_cpu_only,
|
||||
"fallback": "DEISMultistep",
|
||||
},
|
||||
"img2img": {
|
||||
"schedulers": scheduler_list_cpu_only,
|
||||
"fallback": "EulerDiscrete",
|
||||
},
|
||||
"inpaint": {
|
||||
"schedulers": scheduler_list_cpu_only,
|
||||
"fallback": "DDIM",
|
||||
},
|
||||
"outpaint": {
|
||||
"schedulers": scheduler_list_cpu_only,
|
||||
"fallback": "DDIM",
|
||||
},
|
||||
"upscaler": {
|
||||
"schedulers": scheduler_list_cpu_only,
|
||||
"fallback": "DDIM",
|
||||
},
|
||||
}
|
||||
|
||||
# base pydantic model for sd generation apis
|
||||
|
||||
|
||||
class GenerationInputData(BaseModel):
|
||||
prompt: str = ""
|
||||
negative_prompt: str = ""
|
||||
hf_model_id: str | None = None
|
||||
height: int = Field(
|
||||
default=frozen_args.height, ge=128, le=768, multiple_of=8
|
||||
)
|
||||
width: int = Field(
|
||||
default=frozen_args.width, ge=128, le=768, multiple_of=8
|
||||
)
|
||||
sampler_name: str = frozen_args.scheduler
|
||||
cfg_scale: float = Field(default=frozen_args.guidance_scale, ge=1)
|
||||
steps: int = Field(default=frozen_args.steps, ge=1, le=100)
|
||||
seed: int = frozen_args.seed
|
||||
n_iter: int = Field(default=frozen_args.batch_count)
|
||||
|
||||
|
||||
class GenerationResponseData(BaseModel):
|
||||
images: list[str] = Field(description="Generated images, Base64 encoded")
|
||||
properties: dict = {}
|
||||
info: str
|
||||
|
||||
|
||||
# image encoding/decoding
|
||||
|
||||
|
||||
def encode_pil_to_base64(images: list[Image.Image]):
|
||||
encoded_imgs = []
|
||||
for image in images:
|
||||
with BytesIO() as output_bytes:
|
||||
if frozen_args.output_img_format.lower() == "png":
|
||||
image.save(output_bytes, format="PNG")
|
||||
|
||||
elif frozen_args.output_img_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Invalid image format"
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
encoded_imgs.append(base64.b64encode(bytes_data))
|
||||
return encoded_imgs
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding: str):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
|
||||
try:
|
||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
return image
|
||||
except Exception as err:
|
||||
print(err)
|
||||
raise HTTPException(status_code=400, detail="Invalid encoded image")
|
||||
|
||||
|
||||
# get valid sd models/vaes/schedulers etc.
|
||||
|
||||
|
||||
def get_predefined_models(custom_checkpoint_type: str):
|
||||
match custom_checkpoint_type:
|
||||
case "inpainting":
|
||||
return predefined_paint_models
|
||||
case "upscaler":
|
||||
return predefined_upscaler_models
|
||||
case _:
|
||||
return predefined_models
|
||||
|
||||
|
||||
def get_model_from_request(
|
||||
request_data=None,
|
||||
checkpoint_type: str = "",
|
||||
fallback_model: str = "",
|
||||
):
|
||||
model = None
|
||||
if request_data:
|
||||
if request_data.hf_model_id:
|
||||
model = request_data.hf_model_id
|
||||
elif request_data.override_settings:
|
||||
model = request_data.override_settings.sd_model_checkpoint
|
||||
|
||||
# if the request didn't specify a model try the command line args
|
||||
result = model or frozen_args.ckpt_loc or frozen_args.hf_model_id
|
||||
|
||||
# make sure whatever we have is a valid model for the checkpoint type
|
||||
if result in get_custom_model_files(
|
||||
custom_checkpoint_type=checkpoint_type
|
||||
) + get_predefined_models(checkpoint_type):
|
||||
return result
|
||||
# if not return what was specified as the fallback
|
||||
else:
|
||||
return fallback_model
|
||||
|
||||
|
||||
def get_scheduler_from_request(
|
||||
request_data: GenerationInputData, operation: str
|
||||
):
|
||||
allowed = allowed_schedulers[operation]
|
||||
|
||||
requested = request_data.sampler_name
|
||||
requested = sampler_aliases.get(requested, requested)
|
||||
|
||||
return (
|
||||
requested
|
||||
if requested in allowed["schedulers"]
|
||||
else allowed["fallback"]
|
||||
)
|
||||
|
||||
|
||||
def get_lora_params(use_lora: str):
|
||||
# TODO: since the inference functions in the webui, which we are
|
||||
# still calling into for the api, jam these back together again before
|
||||
# handing them off to the pipeline, we should remove this nonsense
|
||||
# and unify their selection in the UI and command line args proper
|
||||
if use_lora in get_custom_model_files("lora"):
|
||||
return (use_lora, "")
|
||||
|
||||
return ("None", use_lora)
|
||||
|
||||
|
||||
def get_device(device_str: str):
|
||||
# first substring match in the list available devices, with first
|
||||
# device when none are matched
|
||||
return next(
|
||||
(device for device in available_devices if device_str in device),
|
||||
available_devices[0],
|
||||
)
|
||||
@@ -1,6 +1,8 @@
|
||||
from multiprocessing import Process, freeze_support
|
||||
from multiprocessing import freeze_support
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import apps.stable_diffusion.web.utils.app as app
|
||||
|
||||
if sys.platform == "darwin":
|
||||
# import before IREE to avoid torch-MLIR library issues
|
||||
@@ -20,64 +22,54 @@ if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
|
||||
def launch_app(address):
|
||||
from tkinter import Tk
|
||||
import webview
|
||||
|
||||
window = Tk()
|
||||
|
||||
# get screen width and height of display and make it more reasonably
|
||||
# sized as we aren't making it full-screen or maximized
|
||||
width = int(window.winfo_screenwidth() * 0.81)
|
||||
height = int(window.winfo_screenheight() * 0.91)
|
||||
webview.create_window(
|
||||
"SHARK AI Studio",
|
||||
url=address,
|
||||
width=width,
|
||||
height=height,
|
||||
text_select=True,
|
||||
)
|
||||
webview.start(private_mode=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.debug:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
# required to do multiprocessing in a pyinstaller freeze
|
||||
freeze_support()
|
||||
if args.api or "api" in args.ui.split(","):
|
||||
from apps.stable_diffusion.web.ui import (
|
||||
txt2img_api,
|
||||
img2img_api,
|
||||
upscaler_api,
|
||||
inpaint_api,
|
||||
outpaint_api,
|
||||
llm_chat_api,
|
||||
)
|
||||
from apps.stable_diffusion.web.api import sdapi
|
||||
|
||||
from fastapi import FastAPI, APIRouter
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import uvicorn
|
||||
|
||||
# init global sd pipeline and config
|
||||
global_obj._init()
|
||||
|
||||
app = FastAPI()
|
||||
app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
|
||||
api = FastAPI()
|
||||
api.mount("/sdapi/", sdapi)
|
||||
|
||||
# chat APIs needed for compatibility with multiple extensions using OpenAI API
|
||||
app.add_api_route(
|
||||
api.add_api_route(
|
||||
"/v1/chat/completions", llm_chat_api, methods=["post"]
|
||||
)
|
||||
app.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
|
||||
app.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
|
||||
app.add_api_route("/completions", llm_chat_api, methods=["post"])
|
||||
app.add_api_route(
|
||||
api.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
|
||||
api.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
|
||||
api.add_api_route("/completions", llm_chat_api, methods=["post"])
|
||||
api.add_api_route(
|
||||
"/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
|
||||
)
|
||||
app.include_router(APIRouter())
|
||||
uvicorn.run(app, host="0.0.0.0", port=args.server_port)
|
||||
api.include_router(APIRouter())
|
||||
|
||||
# deal with CORS requests if CORS accept origins are set
|
||||
if args.api_accept_origin:
|
||||
print(
|
||||
f"API Configured for CORS. Accepting origins: { args.api_accept_origin }"
|
||||
)
|
||||
api.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=args.api_accept_origin,
|
||||
allow_methods=["GET", "POST"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
else:
|
||||
print("API not configured for CORS")
|
||||
|
||||
uvicorn.run(api, host="0.0.0.0", port=args.server_port)
|
||||
sys.exit(0)
|
||||
|
||||
# Setup to use shark_tmp for gradio's temporary image files and clear any
|
||||
@@ -91,7 +83,10 @@ if __name__ == "__main__":
|
||||
import gradio as gr
|
||||
|
||||
# Create custom models folders if they don't exist
|
||||
from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
create_custom_models_folders,
|
||||
nodicon_loc,
|
||||
)
|
||||
|
||||
create_custom_models_folders()
|
||||
|
||||
@@ -107,7 +102,6 @@ if __name__ == "__main__":
|
||||
from apps.stable_diffusion.web.ui import (
|
||||
txt2img_web,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
txt2img_gallery,
|
||||
txt2img_png_info_img,
|
||||
txt2img_status,
|
||||
@@ -119,7 +113,6 @@ if __name__ == "__main__":
|
||||
# h2ogpt_web,
|
||||
img2img_web,
|
||||
img2img_custom_model,
|
||||
img2img_hf_model_id,
|
||||
img2img_gallery,
|
||||
img2img_init_image,
|
||||
img2img_status,
|
||||
@@ -128,7 +121,6 @@ if __name__ == "__main__":
|
||||
img2img_sendto_upscaler,
|
||||
inpaint_web,
|
||||
inpaint_custom_model,
|
||||
inpaint_hf_model_id,
|
||||
inpaint_gallery,
|
||||
inpaint_init_image,
|
||||
inpaint_status,
|
||||
@@ -137,7 +129,6 @@ if __name__ == "__main__":
|
||||
inpaint_sendto_upscaler,
|
||||
outpaint_web,
|
||||
outpaint_custom_model,
|
||||
outpaint_hf_model_id,
|
||||
outpaint_gallery,
|
||||
outpaint_init_image,
|
||||
outpaint_status,
|
||||
@@ -146,16 +137,15 @@ if __name__ == "__main__":
|
||||
outpaint_sendto_upscaler,
|
||||
upscaler_web,
|
||||
upscaler_custom_model,
|
||||
upscaler_hf_model_id,
|
||||
upscaler_gallery,
|
||||
upscaler_init_image,
|
||||
upscaler_status,
|
||||
upscaler_sendto_img2img,
|
||||
upscaler_sendto_inpaint,
|
||||
upscaler_sendto_outpaint,
|
||||
lora_train_web,
|
||||
model_web,
|
||||
model_config_web,
|
||||
# lora_train_web,
|
||||
# model_web,
|
||||
# model_config_web,
|
||||
hf_models,
|
||||
modelmanager_sendto_txt2img,
|
||||
modelmanager_sendto_img2img,
|
||||
@@ -210,9 +200,18 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
with gr.Blocks(
|
||||
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
|
||||
css=dark_theme, analytics_enabled=False, title="SHARK AI Studio"
|
||||
) as sd_web:
|
||||
with gr.Tabs() as tabs:
|
||||
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
|
||||
# have a unique id that doesn't clash with any of the other tabs,
|
||||
# and that the order in the code here is the order they should
|
||||
# appear in the ui, as the id value doesn't determine the order.
|
||||
|
||||
# Where possible, avoid changing the id of any tab that is the
|
||||
# destination of one of the 'send to' buttons. If you do have to change
|
||||
# that id, make sure you update the relevant register_button_click calls
|
||||
# further down with the new id.
|
||||
with gr.TabItem(label="Text-to-Image", id=0):
|
||||
txt2img_web.render()
|
||||
with gr.TabItem(label="Image-to-Image", id=1):
|
||||
@@ -223,16 +222,6 @@ if __name__ == "__main__":
|
||||
outpaint_web.render()
|
||||
with gr.TabItem(label="Upscaler", id=4):
|
||||
upscaler_web.render()
|
||||
with gr.TabItem(label="Model Manager", id=6):
|
||||
model_web.render()
|
||||
with gr.TabItem(label="Chat Bot(Experimental)", id=7):
|
||||
stablelm_chat.render()
|
||||
with gr.TabItem(label="Generate Sharding Config", id=8):
|
||||
model_config_web.render()
|
||||
with gr.TabItem(label="LoRA Training(Experimental)", id=9):
|
||||
lora_train_web.render()
|
||||
with gr.TabItem(label="MultiModal (Experimental)", id=10):
|
||||
minigpt4_web.render()
|
||||
if args.output_gallery:
|
||||
with gr.TabItem(label="Output Gallery", id=5) as og_tab:
|
||||
outputgallery_web.render()
|
||||
@@ -248,10 +237,31 @@ if __name__ == "__main__":
|
||||
upscaler_status,
|
||||
]
|
||||
)
|
||||
# with gr.TabItem(label="Model Manager", id=6):
|
||||
# model_web.render()
|
||||
# with gr.TabItem(label="LoRA Training (Experimental)", id=7):
|
||||
# lora_train_web.render()
|
||||
with gr.TabItem(label="Chat Bot", id=8):
|
||||
stablelm_chat.render()
|
||||
# with gr.TabItem(
|
||||
# label="Generate Sharding Config (Experimental)", id=9
|
||||
# ):
|
||||
# model_config_web.render()
|
||||
with gr.TabItem(label="MultiModal (Experimental)", id=10):
|
||||
minigpt4_web.render()
|
||||
# with gr.TabItem(label="DocuChat Upload", id=11):
|
||||
# h2ogpt_upload.render()
|
||||
# h2ogpt_upload.render()
|
||||
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
|
||||
# h2ogpt_web.render()
|
||||
# h2ogpt_web.render()
|
||||
|
||||
actual_port = app.usable_port()
|
||||
if actual_port != args.server_port:
|
||||
sd_web.load(
|
||||
fn=lambda: gr.Info(
|
||||
f"Port {args.server_port} is in use by another application. "
|
||||
f"Shark is running on port {actual_port} instead."
|
||||
)
|
||||
)
|
||||
|
||||
# send to buttons
|
||||
register_button_click(
|
||||
@@ -385,42 +395,38 @@ if __name__ == "__main__":
|
||||
modelmanager_sendto_txt2img,
|
||||
0,
|
||||
[hf_models],
|
||||
[txt2img_custom_model, txt2img_hf_model_id, tabs],
|
||||
[txt2img_custom_model, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_img2img,
|
||||
1,
|
||||
[hf_models],
|
||||
[img2img_custom_model, img2img_hf_model_id, tabs],
|
||||
[img2img_custom_model, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_inpaint,
|
||||
2,
|
||||
[hf_models],
|
||||
[inpaint_custom_model, inpaint_hf_model_id, tabs],
|
||||
[inpaint_custom_model, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_outpaint,
|
||||
3,
|
||||
[hf_models],
|
||||
[outpaint_custom_model, outpaint_hf_model_id, tabs],
|
||||
[outpaint_custom_model, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_upscaler,
|
||||
4,
|
||||
[hf_models],
|
||||
[upscaler_custom_model, upscaler_hf_model_id, tabs],
|
||||
[upscaler_custom_model, tabs],
|
||||
)
|
||||
|
||||
sd_web.queue()
|
||||
if args.ui == "app":
|
||||
t = Process(
|
||||
target=launch_app, args=[f"http://localhost:{args.server_port}"]
|
||||
)
|
||||
t.start()
|
||||
sd_web.launch(
|
||||
share=args.share,
|
||||
inbrowser=args.ui == "web",
|
||||
inbrowser=not app.launch(actual_port),
|
||||
server_name="0.0.0.0",
|
||||
server_port=args.server_port,
|
||||
server_port=actual_port,
|
||||
favicon_path=nodicon_loc,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from apps.stable_diffusion.web.ui.txt2img_ui import (
|
||||
txt2img_inf,
|
||||
txt2img_api,
|
||||
txt2img_web,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
txt2img_gallery,
|
||||
txt2img_png_info_img,
|
||||
txt2img_status,
|
||||
@@ -14,10 +12,8 @@ from apps.stable_diffusion.web.ui.txt2img_ui import (
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.img2img_ui import (
|
||||
img2img_inf,
|
||||
img2img_api,
|
||||
img2img_web,
|
||||
img2img_custom_model,
|
||||
img2img_hf_model_id,
|
||||
img2img_gallery,
|
||||
img2img_init_image,
|
||||
img2img_status,
|
||||
@@ -27,10 +23,8 @@ from apps.stable_diffusion.web.ui.img2img_ui import (
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.inpaint_ui import (
|
||||
inpaint_inf,
|
||||
inpaint_api,
|
||||
inpaint_web,
|
||||
inpaint_custom_model,
|
||||
inpaint_hf_model_id,
|
||||
inpaint_gallery,
|
||||
inpaint_init_image,
|
||||
inpaint_status,
|
||||
@@ -40,10 +34,8 @@ from apps.stable_diffusion.web.ui.inpaint_ui import (
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.outpaint_ui import (
|
||||
outpaint_inf,
|
||||
outpaint_api,
|
||||
outpaint_web,
|
||||
outpaint_custom_model,
|
||||
outpaint_hf_model_id,
|
||||
outpaint_gallery,
|
||||
outpaint_init_image,
|
||||
outpaint_status,
|
||||
@@ -53,10 +45,8 @@ from apps.stable_diffusion.web.ui.outpaint_ui import (
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.upscaler_ui import (
|
||||
upscaler_inf,
|
||||
upscaler_api,
|
||||
upscaler_web,
|
||||
upscaler_custom_model,
|
||||
upscaler_hf_model_id,
|
||||
upscaler_gallery,
|
||||
upscaler_init_image,
|
||||
upscaler_status,
|
||||
|
||||
@@ -37,8 +37,15 @@ start_message = """
|
||||
|
||||
def create_prompt(history):
|
||||
system_message = start_message
|
||||
for item in history:
|
||||
print("His item: ", item)
|
||||
|
||||
conversation = "".join(["".join([item[0], item[1]]) for item in history])
|
||||
conversation = "<|endoftext|>".join(
|
||||
[
|
||||
"<|endoftext|><|answer|>".join([item[0], item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
@@ -48,10 +55,12 @@ def create_prompt(history):
|
||||
def chat(curr_system_message, history, device, precision):
|
||||
args.run_docuchat_web = True
|
||||
global h2ogpt_model
|
||||
global sharkModel
|
||||
global h2ogpt_tokenizer
|
||||
global model_state
|
||||
global langchain
|
||||
global userpath_selector
|
||||
from apps.language_models.langchain.h2oai_pipeline import generate_token
|
||||
|
||||
if h2ogpt_model == 0:
|
||||
if "cuda" in device:
|
||||
@@ -106,9 +115,14 @@ def chat(curr_system_message, history, device, precision):
|
||||
prompt_type=None,
|
||||
prompt_dict=None,
|
||||
)
|
||||
from apps.language_models.langchain.h2oai_pipeline import (
|
||||
H2OGPTSHARKModel,
|
||||
)
|
||||
|
||||
sharkModel = H2OGPTSHARKModel()
|
||||
|
||||
prompt = create_prompt(history)
|
||||
output = langchain.evaluate(
|
||||
output_dict = langchain.evaluate(
|
||||
model_state=model_state,
|
||||
my_db_state=None,
|
||||
instruction=prompt,
|
||||
@@ -168,7 +182,11 @@ def chat(curr_system_message, history, device, precision):
|
||||
model_lock=True,
|
||||
user_path=userpath_selector.value,
|
||||
)
|
||||
history[-1][1] = output["response"]
|
||||
|
||||
output = generate_token(sharkModel, **output_dict)
|
||||
for partial_text in output:
|
||||
history[-1][1] = partial_text
|
||||
yield history
|
||||
return history
|
||||
|
||||
|
||||
@@ -194,6 +212,7 @@ with gr.Blocks(title="DocuChat") as h2ogpt_web:
|
||||
else "Only CUDA Supported for now",
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
|
||||
@@ -3,10 +3,8 @@ import torch
|
||||
import time
|
||||
import gradio as gr
|
||||
import PIL
|
||||
from math import ceil
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from fastapi.exceptions import HTTPException
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -29,6 +27,7 @@ from apps.stable_diffusion.src import (
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generation_text_info,
|
||||
resampler_list,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
import numpy as np
|
||||
@@ -54,8 +53,7 @@ def img2img_inf(
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
@@ -67,6 +65,7 @@ def img2img_inf(
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
resample_type: str,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
@@ -101,21 +100,17 @@ def img2img_inf(
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
args.custom_vae = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, "
|
||||
"both must not be empty.",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
|
||||
# .safetensor or .chkpt on the custom model path
|
||||
if model_id in get_custom_model_files():
|
||||
args.ckpt_loc = get_custom_model_pathfile(model_id)
|
||||
# civitai download
|
||||
elif "civitai" in model_id:
|
||||
args.ckpt_loc = model_id
|
||||
# either predefined or huggingface
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
args.hf_model_id = model_id
|
||||
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
@@ -245,7 +240,7 @@ def img2img_inf(
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
ceil(steps / strength),
|
||||
strength,
|
||||
guidance_scale,
|
||||
seeds[current_batch],
|
||||
@@ -255,6 +250,7 @@ def img2img_inf(
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
use_stencil=use_stencil,
|
||||
resample_type=resample_type,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(
|
||||
@@ -279,87 +275,6 @@ def img2img_inf(
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
|
||||
try:
|
||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
return image
|
||||
except Exception as err:
|
||||
print(err)
|
||||
raise HTTPException(status_code=500, detail="Invalid encoded image")
|
||||
|
||||
|
||||
def encode_pil_to_base64(images):
|
||||
encoded_imgs = []
|
||||
for image in images:
|
||||
with BytesIO() as output_bytes:
|
||||
if args.output_img_format.lower() == "png":
|
||||
image.save(output_bytes, format="PNG")
|
||||
|
||||
elif args.output_img_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Invalid image format"
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
encoded_imgs.append(base64.b64encode(bytes_data))
|
||||
return encoded_imgs
|
||||
|
||||
|
||||
# Img2Img Rest API.
|
||||
def img2img_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, '
|
||||
f'Negative Prompt: {InputData["negative_prompt"]}, '
|
||||
f'Seed: {InputData["seed"]}.'
|
||||
)
|
||||
init_image = decode_base64_to_image(InputData["init_images"][0])
|
||||
res = img2img_inf(
|
||||
InputData["prompt"],
|
||||
InputData["negative_prompt"],
|
||||
init_image,
|
||||
InputData["height"],
|
||||
InputData["width"],
|
||||
InputData["steps"],
|
||||
InputData["denoising_strength"],
|
||||
InputData["cfg_scale"],
|
||||
InputData["seed"],
|
||||
batch_count=1,
|
||||
batch_size=1,
|
||||
scheduler="EulerDiscrete",
|
||||
custom_model="None",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
max_length=64,
|
||||
use_stencil=InputData["use_stencil"]
|
||||
if "use_stencil" in InputData.keys()
|
||||
else "None",
|
||||
save_metadata_to_json=False,
|
||||
save_metadata_to_png=False,
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
# Converts generator type to subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
"info": res[1],
|
||||
}
|
||||
|
||||
|
||||
with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
@@ -378,31 +293,19 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
i2i_model_info = (str(get_custom_model_path())).replace(
|
||||
"\\", "\n\\"
|
||||
i2i_model_info = (
|
||||
f"Custom Model Path: {str(get_custom_model_path())}"
|
||||
)
|
||||
i2i_model_info = f"Custom Model Path: {i2i_model_info}"
|
||||
img2img_custom_model = gr.Dropdown(
|
||||
label=f"Models",
|
||||
info=i2i_model_info,
|
||||
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
)
|
||||
img2img_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown "
|
||||
"on the left and enter model ID here "
|
||||
"e.g: SG161222/Realistic_Vision_V1.3, "
|
||||
"https://civitai.com/api/download/models/15236",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL",
|
||||
lines=3,
|
||||
choices=get_custom_model_files() + predefined_models,
|
||||
allow_custom_value=True,
|
||||
scale=2,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
i2i_vae_info = (str(get_custom_model_path("vae"))).replace(
|
||||
@@ -417,6 +320,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
scale=1,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
@@ -432,7 +337,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
lines=2,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
|
||||
# TODO: make this import image prompt info if it exists
|
||||
img2img_init_image = gr.Image(
|
||||
label="Input Image",
|
||||
source="upload",
|
||||
@@ -447,7 +352,13 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
elem_id="stencil_model",
|
||||
label="Stencil model",
|
||||
value="None",
|
||||
choices=["None", "canny", "openpose", "scribble"],
|
||||
choices=[
|
||||
"None",
|
||||
"canny",
|
||||
"openpose",
|
||||
"scribble",
|
||||
"zoedepth",
|
||||
],
|
||||
)
|
||||
|
||||
def show_canvas(choice):
|
||||
@@ -508,6 +419,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
).replace("\\", "\n\\")
|
||||
i2i_lora_info = f"LoRA Path: {i2i_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
allow_custom_value=True,
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=i2i_lora_info,
|
||||
elem_id="lora_weights",
|
||||
@@ -531,6 +443,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
label="Scheduler",
|
||||
value="EulerDiscrete",
|
||||
choices=scheduler_list_cpu_only,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -550,15 +463,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
width = gr.Slider(
|
||||
384, 768, value=args.width, step=8, label="Width"
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=args.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
value=args.max_length,
|
||||
@@ -581,11 +485,26 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
step=0.01,
|
||||
label="Denoising Strength",
|
||||
)
|
||||
resample_type = gr.Dropdown(
|
||||
value=args.resample_type,
|
||||
choices=resampler_list,
|
||||
label="Resample Type",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
ondemand = gr.Checkbox(
|
||||
value=args.ondemand,
|
||||
label="Low VRAM",
|
||||
interactive=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=args.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
guidance_scale = gr.Slider(
|
||||
@@ -629,17 +548,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -651,13 +561,26 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
object_fit="contain",
|
||||
)
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at "
|
||||
value=f"{i2i_model_info}\n"
|
||||
f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
lines=2,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
img2img_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
blank_thing_for_row = None
|
||||
with gr.Row():
|
||||
img2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
img2img_sendto_outpaint = gr.Button(
|
||||
@@ -683,7 +606,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
batch_size,
|
||||
scheduler,
|
||||
img2img_custom_model,
|
||||
img2img_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
@@ -695,6 +617,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
resample_type,
|
||||
],
|
||||
outputs=[img2img_gallery, std_output, img2img_status],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
|
||||
@@ -4,9 +4,6 @@ import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from fastapi.exceptions import HTTPException
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -53,8 +50,7 @@ def inpaint_inf(
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
@@ -89,21 +85,17 @@ def inpaint_inf(
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
args.custom_vae = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, "
|
||||
"both must not be empty.",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
|
||||
# .safetensor or .chkpt on the custom model path
|
||||
if model_id in get_custom_model_files(custom_checkpoint_type="inpainting"):
|
||||
args.ckpt_loc = get_custom_model_pathfile(model_id)
|
||||
# civitai download
|
||||
elif "civitai" in model_id:
|
||||
args.ckpt_loc = model_id
|
||||
# either predefined or huggingface
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
args.hf_model_id = model_id
|
||||
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
@@ -228,86 +220,6 @@ def inpaint_inf(
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
|
||||
try:
|
||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
return image
|
||||
except Exception as err:
|
||||
print(err)
|
||||
raise HTTPException(status_code=500, detail="Invalid encoded image")
|
||||
|
||||
|
||||
def encode_pil_to_base64(images):
|
||||
encoded_imgs = []
|
||||
for image in images:
|
||||
with BytesIO() as output_bytes:
|
||||
if args.output_img_format.lower() == "png":
|
||||
image.save(output_bytes, format="PNG")
|
||||
|
||||
elif args.output_img_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Invalid image format"
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
encoded_imgs.append(base64.b64encode(bytes_data))
|
||||
return encoded_imgs
|
||||
|
||||
|
||||
# Inpaint Rest API.
|
||||
def inpaint_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, '
|
||||
f'Negative Prompt: {InputData["negative_prompt"]}, '
|
||||
f'Seed: {InputData["seed"]}.'
|
||||
)
|
||||
init_image = decode_base64_to_image(InputData["image"])
|
||||
mask = decode_base64_to_image(InputData["mask"])
|
||||
res = inpaint_inf(
|
||||
InputData["prompt"],
|
||||
InputData["negative_prompt"],
|
||||
{"image": init_image, "mask": mask},
|
||||
InputData["height"],
|
||||
InputData["width"],
|
||||
InputData["is_full_res"],
|
||||
InputData["full_res_padding"],
|
||||
InputData["steps"],
|
||||
InputData["cfg_scale"],
|
||||
InputData["seed"],
|
||||
batch_count=1,
|
||||
batch_size=1,
|
||||
scheduler="EulerDiscrete",
|
||||
custom_model="None",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
max_length=64,
|
||||
save_metadata_to_json=False,
|
||||
save_metadata_to_png=False,
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
# Converts generator type to subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
"info": res[1],
|
||||
}
|
||||
|
||||
|
||||
with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
@@ -327,34 +239,21 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
inpaint_model_info = (
|
||||
str(get_custom_model_path())
|
||||
).replace("\\", "\n\\")
|
||||
inpaint_model_info = (
|
||||
f"Custom Model Path: {inpaint_model_info}"
|
||||
f"Custom Model Path: {str(get_custom_model_path())}"
|
||||
)
|
||||
inpaint_custom_model = gr.Dropdown(
|
||||
label=f"Models",
|
||||
info=inpaint_model_info,
|
||||
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files(
|
||||
choices=get_custom_model_files(
|
||||
custom_checkpoint_type="inpainting"
|
||||
)
|
||||
+ predefined_paint_models,
|
||||
)
|
||||
inpaint_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown "
|
||||
"on the left and enter model ID here "
|
||||
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
|
||||
"https://civitai.com/api/download/models/3433",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL",
|
||||
lines=3,
|
||||
allow_custom_value=True,
|
||||
scale=2,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
inpaint_vae_info = (
|
||||
@@ -369,6 +268,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
scale=1,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
@@ -406,6 +307,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
@@ -424,6 +326,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
label="Scheduler",
|
||||
value="EulerDiscrete",
|
||||
choices=scheduler_list_cpu_only,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -527,17 +430,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -549,14 +443,26 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
object_fit="contain",
|
||||
)
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at "
|
||||
value=f"{inpaint_model_info}\n"
|
||||
"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
lines=2,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
inpaint_status = gr.Textbox(visible=False)
|
||||
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
blank_thing_for_row = None
|
||||
with gr.Row():
|
||||
inpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
inpaint_sendto_outpaint = gr.Button(
|
||||
@@ -583,7 +489,6 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
batch_size,
|
||||
scheduler,
|
||||
inpaint_custom_model,
|
||||
inpaint_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
|
||||
BIN
apps/stable_diffusion/web/ui/logos/nod-icon.png
Normal file
BIN
apps/stable_diffusion/web/ui/logos/nod-icon.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 16 KiB |
@@ -50,6 +50,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
@@ -73,6 +74,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
@@ -105,6 +107,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
label="Scheduler",
|
||||
value=args.scheduler,
|
||||
choices=scheduler_list,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
@@ -177,6 +180,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
|
||||
@@ -109,7 +109,7 @@ with gr.Blocks() as minigpt4_web:
|
||||
gr.Markdown(description)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=0.5):
|
||||
with gr.Column():
|
||||
image = gr.Image(type="pil")
|
||||
upload_button = gr.Button(
|
||||
value="Upload & Start Chat",
|
||||
@@ -143,6 +143,7 @@ with gr.Blocks() as minigpt4_web:
|
||||
# else "Only CUDA Supported for now",
|
||||
choices=["cuda"],
|
||||
interactive=False,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
with gr.Column():
|
||||
|
||||
@@ -98,6 +98,7 @@ with gr.Blocks() as model_web:
|
||||
choices=None,
|
||||
value=None,
|
||||
visible=False,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
# TODO: select and SendTo
|
||||
civit_models = gr.Gallery(
|
||||
|
||||
@@ -53,8 +53,7 @@ def outpaint_inf(
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
@@ -88,21 +87,17 @@ def outpaint_inf(
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
args.custom_vae = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, "
|
||||
"both must not be empty.",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
|
||||
# .safetensor or .chkpt on the custom model path
|
||||
if model_id in get_custom_model_files(custom_checkpoint_type="inpainting"):
|
||||
args.ckpt_loc = get_custom_model_pathfile(model_id)
|
||||
# civitai download
|
||||
elif "civitai" in model_id:
|
||||
args.ckpt_loc = model_id
|
||||
# either predefined or huggingface
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
args.hf_model_id = model_id
|
||||
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
@@ -233,88 +228,6 @@ def outpaint_inf(
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
|
||||
try:
|
||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
return image
|
||||
except Exception as err:
|
||||
print(err)
|
||||
raise HTTPException(status_code=500, detail="Invalid encoded image")
|
||||
|
||||
|
||||
def encode_pil_to_base64(images):
|
||||
encoded_imgs = []
|
||||
for image in images:
|
||||
with BytesIO() as output_bytes:
|
||||
if args.output_img_format.lower() == "png":
|
||||
image.save(output_bytes, format="PNG")
|
||||
|
||||
elif args.output_img_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Invalid image format"
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
encoded_imgs.append(base64.b64encode(bytes_data))
|
||||
return encoded_imgs
|
||||
|
||||
|
||||
# Inpaint Rest API.
|
||||
def outpaint_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, '
|
||||
f'Negative Prompt: {InputData["negative_prompt"]}, '
|
||||
f'Seed: {InputData["seed"]}'
|
||||
)
|
||||
init_image = decode_base64_to_image(InputData["init_images"][0])
|
||||
res = outpaint_inf(
|
||||
InputData["prompt"],
|
||||
InputData["negative_prompt"],
|
||||
init_image,
|
||||
InputData["pixels"],
|
||||
InputData["mask_blur"],
|
||||
InputData["directions"],
|
||||
InputData["noise_q"],
|
||||
InputData["color_variation"],
|
||||
InputData["height"],
|
||||
InputData["width"],
|
||||
InputData["steps"],
|
||||
InputData["cfg_scale"],
|
||||
InputData["seed"],
|
||||
batch_count=1,
|
||||
batch_size=1,
|
||||
scheduler="EulerDiscrete",
|
||||
custom_model="None",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
max_length=64,
|
||||
save_metadata_to_json=False,
|
||||
save_metadata_to_png=False,
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
# Convert Generator to Subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
"info": res[1],
|
||||
}
|
||||
|
||||
|
||||
with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
@@ -332,36 +245,22 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
outpaint_model_info = (
|
||||
str(get_custom_model_path())
|
||||
).replace("\\", "\n\\")
|
||||
outpaint_model_info = (
|
||||
f"Custom Model Path: {outpaint_model_info}"
|
||||
f"Custom Model Path: {str(get_custom_model_path())}"
|
||||
)
|
||||
outpaint_custom_model = gr.Dropdown(
|
||||
label=f"Models",
|
||||
info=outpaint_model_info,
|
||||
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files(
|
||||
choices=get_custom_model_files(
|
||||
custom_checkpoint_type="inpainting"
|
||||
)
|
||||
+ predefined_paint_models,
|
||||
)
|
||||
outpaint_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown "
|
||||
"on the left and enter model ID here "
|
||||
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
|
||||
"https://civitai.com/api/download/models/3433",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL",
|
||||
lines=3,
|
||||
allow_custom_value=True,
|
||||
scale=2,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
outpaint_vae_info = (
|
||||
@@ -376,8 +275,9 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
scale=1,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
@@ -411,6 +311,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
@@ -429,6 +330,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
label="Scheduler",
|
||||
value="EulerDiscrete",
|
||||
choices=scheduler_list_cpu_only,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -555,17 +457,8 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -577,13 +470,26 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
object_fit="contain",
|
||||
)
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at "
|
||||
value=f"{outpaint_model_info}\n"
|
||||
f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
lines=2,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
outpaint_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
blank_thing_for_row = None
|
||||
with gr.Row():
|
||||
outpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
outpaint_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
@@ -611,7 +517,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
batch_size,
|
||||
scheduler,
|
||||
outpaint_custom_model,
|
||||
outpaint_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
|
||||
@@ -109,6 +109,7 @@ with gr.Blocks() as outputgallery_web:
|
||||
value="",
|
||||
interactive=True,
|
||||
elem_classes="dropdown_no_container",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Column(
|
||||
scale=1,
|
||||
|
||||
@@ -8,6 +8,7 @@ from transformers import (
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
from datetime import datetime as dt
|
||||
import json
|
||||
import sys
|
||||
|
||||
|
||||
def user(message, history):
|
||||
@@ -23,88 +24,81 @@ past_key_values = None
|
||||
|
||||
model_map = {
|
||||
"llama2_7b": "meta-llama/Llama-2-7b-chat-hf",
|
||||
"llama2_13b": "meta-llama/Llama-2-13b-chat-hf",
|
||||
"llama2_70b": "meta-llama/Llama-2-70b-chat-hf",
|
||||
"codegen": "Salesforce/codegen25-7b-multi",
|
||||
"vicuna1p3": "lmsys/vicuna-7b-v1.3",
|
||||
"vicuna": "TheBloke/vicuna-7B-1.1-HF",
|
||||
"vicuna4": "TheBloke/vicuna-7B-1.1-HF",
|
||||
"StableLM": "stabilityai/stablelm-tuned-alpha-3b",
|
||||
}
|
||||
|
||||
# NOTE: Each `model_name` should have its own start message
|
||||
start_message = {
|
||||
"llama2_7b": (
|
||||
"System: You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
|
||||
"content. Please ensure that your responses are socially unbiased and positive "
|
||||
"in nature. If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. If you don't know the "
|
||||
"answer to a question, please don't share false information."
|
||||
"You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
|
||||
"illegal content. Please ensure that your responses are socially "
|
||||
"unbiased and positive in nature. If a question does not make any "
|
||||
"sense, or is not factually coherent, explain why instead of "
|
||||
"answering something not correct. If you don't know the answer "
|
||||
"to a question, please don't share false information."
|
||||
),
|
||||
"llama2_13b": (
|
||||
"You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
|
||||
"illegal content. Please ensure that your responses are socially "
|
||||
"unbiased and positive in nature. If a question does not make any "
|
||||
"sense, or is not factually coherent, explain why instead of "
|
||||
"answering something not correct. If you don't know the answer "
|
||||
"to a question, please don't share false information."
|
||||
),
|
||||
"llama2_70b": (
|
||||
"System: You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
|
||||
"content. Please ensure that your responses are socially unbiased and positive "
|
||||
"in nature. If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. If you don't know the "
|
||||
"answer to a question, please don't share false information."
|
||||
),
|
||||
"StableLM": (
|
||||
"<|SYSTEM|># StableLM Tuned (Alpha version)"
|
||||
"\n- StableLM is a helpful and harmless open-source AI language model "
|
||||
"developed by StabilityAI."
|
||||
"\n- StableLM is excited to be able to help the user, but will refuse "
|
||||
"to do anything that could be considered harmful to the user."
|
||||
"\n- StableLM is more than just an information source, StableLM is also "
|
||||
"able to write poetry, short stories, and make jokes."
|
||||
"\n- StableLM will refuse to participate in anything that "
|
||||
"could harm a human."
|
||||
"You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
|
||||
"illegal content. Please ensure that your responses are socially "
|
||||
"unbiased and positive in nature. If a question does not make any "
|
||||
"sense, or is not factually coherent, explain why instead of "
|
||||
"answering something not correct. If you don't know the answer "
|
||||
"to a question, please don't share false information."
|
||||
),
|
||||
"vicuna": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
"A chat between a curious user and an artificial intelligence "
|
||||
"assistant. The assistant gives helpful, detailed, and "
|
||||
"polite answers to the user's questions.\n"
|
||||
),
|
||||
"vicuna4": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
),
|
||||
"vicuna1p3": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
),
|
||||
"codegen": "",
|
||||
}
|
||||
|
||||
|
||||
def create_prompt(model_name, history):
|
||||
system_message = start_message[model_name]
|
||||
def create_prompt(model_name, history, prompt_prefix):
|
||||
system_message = ""
|
||||
if prompt_prefix:
|
||||
system_message = start_message[model_name]
|
||||
|
||||
if model_name in [
|
||||
"StableLM",
|
||||
"vicuna",
|
||||
"vicuna4",
|
||||
"vicuna1p3",
|
||||
"llama2_7b",
|
||||
"llama2_70b",
|
||||
]:
|
||||
if "llama2" in model_name:
|
||||
B_INST, E_INST = "[INST]", "[/INST]"
|
||||
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
||||
conversation = "".join(
|
||||
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
|
||||
)
|
||||
if prompt_prefix:
|
||||
msg = f"{B_INST} {B_SYS}{system_message}{E_SYS}{history[0][0]} {E_INST} {history[0][1]} {conversation}"
|
||||
else:
|
||||
msg = f"{B_INST} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
|
||||
elif model_name in ["vicuna"]:
|
||||
conversation = "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
else:
|
||||
conversation = "".join(
|
||||
["".join([item[0], item[1]]) for item in history]
|
||||
)
|
||||
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
return msg
|
||||
|
||||
|
||||
@@ -113,124 +107,193 @@ def set_vicuna_model(model):
|
||||
vicuna_model = model
|
||||
|
||||
|
||||
def get_default_config():
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
|
||||
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
|
||||
compilation_prompt = "".join(["0" for _ in range(17)])
|
||||
compilation_input_ids = tokenizer(
|
||||
compilation_prompt,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
|
||||
[1, 19]
|
||||
)
|
||||
firstVicunaCompileInput = (compilation_input_ids,)
|
||||
from apps.language_models.src.model_wrappers.vicuna_model import (
|
||||
CombinedModel,
|
||||
)
|
||||
from shark.shark_generate_model_config import GenerateConfigFile
|
||||
|
||||
model = CombinedModel()
|
||||
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
|
||||
c.split_into_layers()
|
||||
|
||||
|
||||
def clean_device_info(raw_device):
|
||||
# return appropriate device and device_id for consumption by LLM pipeline
|
||||
# Multiple devices only supported for vulkan and rocm (as of now).
|
||||
# default device must be selected for all others
|
||||
|
||||
device_id = None
|
||||
device = (
|
||||
raw_device
|
||||
if "=>" not in raw_device
|
||||
else raw_device.split("=>")[1].strip()
|
||||
)
|
||||
if "://" in device:
|
||||
device, device_id = device.split("://")
|
||||
device_id = int(device_id) # using device index in webui
|
||||
|
||||
if device not in ["rocm", "vulkan"]:
|
||||
device_id = None
|
||||
|
||||
return device, device_id
|
||||
|
||||
|
||||
model_vmfb_key = ""
|
||||
|
||||
|
||||
# TODO: Make chat reusable for UI and API
|
||||
def chat(
|
||||
curr_system_message,
|
||||
prompt_prefix,
|
||||
history,
|
||||
model,
|
||||
devices,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
cli=True,
|
||||
cli=False,
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
global past_key_values
|
||||
|
||||
global model_vmfb_key
|
||||
global vicuna_model
|
||||
model_name, model_path = list(map(str.strip, model.split("=>")))
|
||||
|
||||
if model_name in [
|
||||
"vicuna",
|
||||
"vicuna4",
|
||||
"vicuna1p3",
|
||||
"codegen",
|
||||
"llama2_7b",
|
||||
"llama2_70b",
|
||||
]:
|
||||
model_name, model_path = list(map(str.strip, model.split("=>")))
|
||||
device, device_id = clean_device_info(device)
|
||||
|
||||
from apps.language_models.scripts.vicuna import ShardedVicuna
|
||||
from apps.language_models.scripts.vicuna import UnshardedVicuna
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}"
|
||||
if vicuna_model is None or new_model_vmfb_key != model_vmfb_key:
|
||||
model_vmfb_key = new_model_vmfb_key
|
||||
max_toks = 128 if model_name == "codegen" else 512
|
||||
|
||||
# get iree flags that need to be overridden, from commandline args
|
||||
_extra_args = []
|
||||
# vulkan target triple
|
||||
vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
get_all_vulkan_devices,
|
||||
get_vulkan_target_triple,
|
||||
)
|
||||
|
||||
if device == "vulkan":
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
if vulkan_target_triple == "":
|
||||
# We already have the device_id extracted via WebUI, so we directly use
|
||||
# that to find the target triple.
|
||||
vulkan_target_triple = get_vulkan_target_triple(
|
||||
vulkaninfo_list[device_id]
|
||||
)
|
||||
_extra_args.append(
|
||||
f"-iree-vulkan-target-triple={vulkan_target_triple}"
|
||||
)
|
||||
if "rdna" in vulkan_target_triple:
|
||||
flags_to_add = [
|
||||
"--iree-spirv-index-bits=64",
|
||||
]
|
||||
_extra_args = _extra_args + flags_to_add
|
||||
|
||||
if device_id is None:
|
||||
id = 0
|
||||
for device in vulkaninfo_list:
|
||||
target_triple = get_vulkan_target_triple(
|
||||
vulkaninfo_list[id]
|
||||
)
|
||||
if target_triple == vulkan_target_triple:
|
||||
device_id = id
|
||||
break
|
||||
id += 1
|
||||
|
||||
assert (
|
||||
device_id
|
||||
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
|
||||
print(f"Will use vulkan target triple : {vulkan_target_triple}")
|
||||
|
||||
elif "rocm" in device:
|
||||
# add iree rocm flags
|
||||
if args.iree_rocm_target_chip != "":
|
||||
_extra_args.append(
|
||||
f"--iree-rocm-target-chip={args.iree_rocm_target_chip}"
|
||||
)
|
||||
print(f"extra args = {_extra_args}")
|
||||
|
||||
if model_name == "vicuna4":
|
||||
from apps.language_models.scripts.vicuna import (
|
||||
ShardedVicuna as Vicuna,
|
||||
vicuna_model = ShardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
compressed=True,
|
||||
extra_args_cmd=_extra_args,
|
||||
)
|
||||
else:
|
||||
from apps.language_models.scripts.vicuna import (
|
||||
UnshardedVicuna as Vicuna,
|
||||
# if config_file is None:
|
||||
vicuna_model = UnshardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
hf_auth_token=args.hf_auth_token,
|
||||
device=device,
|
||||
vulkan_target_triple=vulkan_target_triple,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
download_vmfb=download_vmfb,
|
||||
load_mlir_from_shark_tank=True,
|
||||
extra_args_cmd=_extra_args,
|
||||
device_id=device_id,
|
||||
)
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
if vicuna_model == 0:
|
||||
device = devices[0]
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
elif "sync" in device:
|
||||
device = "cpu-sync"
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device = "vulkan"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
if vicuna_model is None:
|
||||
sys.exit("Unable to instantiate the model object, exiting.")
|
||||
|
||||
max_toks = 128 if model_name == "codegen" else 512
|
||||
if model_name == "vicuna4":
|
||||
vicuna_model = Vicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
compressed=True,
|
||||
)
|
||||
else:
|
||||
if len(devices) == 1 and config_file is None:
|
||||
vicuna_model = Vicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
hf_auth_token=args.hf_auth_token,
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
)
|
||||
else:
|
||||
if config_file is not None:
|
||||
config_file = open(config_file)
|
||||
config_json = json.load(config_file)
|
||||
config_file.close()
|
||||
else:
|
||||
config_json = None
|
||||
vicuna_model = Vicuna(
|
||||
model_name,
|
||||
device=device,
|
||||
precision=precision,
|
||||
config_json=config_json,
|
||||
)
|
||||
|
||||
prompt = create_prompt(model_name, history)
|
||||
|
||||
for partial_text in vicuna_model.generate(prompt, cli=cli):
|
||||
history[-1][1] = partial_text
|
||||
yield history
|
||||
|
||||
return history
|
||||
|
||||
# else Model is StableLM
|
||||
global sharkModel
|
||||
from apps.language_models.src.pipelines.stablelm_pipeline import (
|
||||
SharkStableLM,
|
||||
)
|
||||
|
||||
if sharkModel == 0:
|
||||
# max_new_tokens=512
|
||||
shark_slm = SharkStableLM(
|
||||
model_name
|
||||
) # pass elements from UI as required
|
||||
|
||||
# Construct the input message string for the model by concatenating the
|
||||
# current system message and conversation history
|
||||
if len(curr_system_message.split()) > 160:
|
||||
print("clearing context")
|
||||
prompt = create_prompt(model_name, history)
|
||||
generate_kwargs = dict(prompt=prompt)
|
||||
|
||||
words_list = shark_slm.generate(**generate_kwargs)
|
||||
prompt = create_prompt(model_name, history, prompt_prefix)
|
||||
|
||||
partial_text = ""
|
||||
for new_text in words_list:
|
||||
print(new_text)
|
||||
partial_text += new_text
|
||||
history[-1][1] = partial_text
|
||||
# Yield an empty string to clean up the message textbox and the updated
|
||||
# conversation history
|
||||
yield history
|
||||
return words_list
|
||||
token_count = 0
|
||||
total_time_ms = 0.001 # In order to avoid divide by zero error
|
||||
prefill_time = 0
|
||||
is_first = True
|
||||
for text, msg, exec_time in progress.tqdm(
|
||||
vicuna_model.generate(prompt, cli=cli),
|
||||
desc="generating response",
|
||||
):
|
||||
if msg is None:
|
||||
if is_first:
|
||||
prefill_time = exec_time
|
||||
is_first = False
|
||||
else:
|
||||
total_time_ms += exec_time
|
||||
token_count += 1
|
||||
partial_text += text + " "
|
||||
history[-1][1] = partial_text
|
||||
yield history, f"Prefill: {prefill_time:.2f}"
|
||||
elif "formatted" in msg:
|
||||
history[-1][1] = text
|
||||
tokens_per_sec = (token_count / total_time_ms) * 1000
|
||||
yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
|
||||
else:
|
||||
sys.exit(
|
||||
"unexpected message from the vicuna generate call, exiting."
|
||||
)
|
||||
|
||||
return history, ""
|
||||
|
||||
|
||||
def llm_chat_api(InputData: dict):
|
||||
@@ -266,17 +329,9 @@ def llm_chat_api(InputData: dict):
|
||||
UnshardedVicuna,
|
||||
)
|
||||
|
||||
device_id = None
|
||||
if vicuna_model == 0:
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
elif "sync" in device:
|
||||
device = "cpu-sync"
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device = "vulkan"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
device, device_id = clean_device_info(device)
|
||||
|
||||
vicuna_model = UnshardedVicuna(
|
||||
model_name,
|
||||
@@ -284,6 +339,9 @@ def llm_chat_api(InputData: dict):
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
download_vmfb=True,
|
||||
load_mlir_from_shark_tank=True,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
# TODO: add role dict for different models
|
||||
@@ -348,38 +406,53 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
label="Select Model",
|
||||
value=model_choices[0],
|
||||
choices=model_choices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
supported_devices = available_devices
|
||||
enabled = len(supported_devices) > 0
|
||||
# show cpu-task device first in list for chatbot
|
||||
supported_devices = supported_devices[-1:] + supported_devices[:-1]
|
||||
supported_devices = [x for x in supported_devices if "sync" not in x]
|
||||
print(supported_devices)
|
||||
devices = gr.Dropdown(
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=supported_devices[0]
|
||||
if enabled
|
||||
else "Only CUDA Supported for now",
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
multiselect=True,
|
||||
allow_custom_value=True,
|
||||
# multiselect=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="fp16",
|
||||
value="int4",
|
||||
choices=[
|
||||
"int4",
|
||||
"int8",
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=True,
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
tokens_time = gr.Textbox(label="Tokens generated per second")
|
||||
with gr.Column():
|
||||
download_vmfb = gr.Checkbox(
|
||||
label="Download vmfb from Shark tank if available",
|
||||
value=True,
|
||||
interactive=True,
|
||||
)
|
||||
prompt_prefix = gr.Checkbox(
|
||||
label="Add System Prompt",
|
||||
value=False,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Row(visible=False):
|
||||
with gr.Group():
|
||||
config_file = gr.File(label="Upload sharding configuration")
|
||||
json_view_button = gr.Button("View as JSON")
|
||||
json_view = gr.JSON()
|
||||
config_file = gr.File(
|
||||
label="Upload sharding configuration", visible=False
|
||||
)
|
||||
json_view_button = gr.Button(label="View as JSON", visible=False)
|
||||
json_view = gr.JSON(interactive=True, visible=False)
|
||||
json_view_button.click(
|
||||
fn=view_json_file, inputs=[config_file], outputs=[json_view]
|
||||
)
|
||||
@@ -398,24 +471,47 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
submit = gr.Button("Submit", interactive=enabled)
|
||||
stop = gr.Button("Stop", interactive=enabled)
|
||||
clear = gr.Button("Clear", interactive=enabled)
|
||||
system_msg = gr.Textbox(
|
||||
start_message, label="System Message", interactive=False, visible=False
|
||||
)
|
||||
|
||||
submit_event = msg.submit(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
fn=user,
|
||||
inputs=[msg, chatbot],
|
||||
outputs=[msg, chatbot],
|
||||
show_progress=False,
|
||||
queue=False,
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model, devices, precision, config_file],
|
||||
outputs=[chatbot],
|
||||
inputs=[
|
||||
prompt_prefix,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
show_progress=False,
|
||||
queue=True,
|
||||
)
|
||||
submit_click_event = submit.click(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
fn=user,
|
||||
inputs=[msg, chatbot],
|
||||
outputs=[msg, chatbot],
|
||||
show_progress=False,
|
||||
queue=False,
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model, devices, precision, config_file],
|
||||
outputs=[chatbot],
|
||||
inputs=[
|
||||
prompt_prefix,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
show_progress=False,
|
||||
queue=True,
|
||||
)
|
||||
stop.click(
|
||||
|
||||
@@ -4,15 +4,14 @@ import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from fastapi.exceptions import HTTPException
|
||||
from math import ceil
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
get_custom_model_path,
|
||||
get_custom_model_files,
|
||||
scheduler_list,
|
||||
scheduler_list_cpu_only,
|
||||
predefined_models,
|
||||
cancel_sd,
|
||||
)
|
||||
@@ -26,10 +25,12 @@ from apps.stable_diffusion.src import (
|
||||
utils,
|
||||
save_output_img,
|
||||
prompt_examples,
|
||||
Image2ImagePipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generation_text_info,
|
||||
resampler_list,
|
||||
)
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
@@ -50,8 +51,7 @@ def txt2img_inf(
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
@@ -62,6 +62,11 @@ def txt2img_inf(
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
use_hiresfix: bool,
|
||||
hiresfix_height: int,
|
||||
hiresfix_width: int,
|
||||
hiresfix_strength: float,
|
||||
resample_type: str,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
@@ -84,21 +89,17 @@ def txt2img_inf(
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
args.custom_vae = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, "
|
||||
"both must not be empty",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
|
||||
# .safetensor or .chkpt on the custom model path
|
||||
if model_id in get_custom_model_files():
|
||||
args.ckpt_loc = get_custom_model_pathfile(model_id)
|
||||
# civitai download
|
||||
elif "civitai" in model_id:
|
||||
args.ckpt_loc = model_id
|
||||
# either predefined or huggingface
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
args.hf_model_id = model_id
|
||||
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
@@ -138,6 +139,11 @@ def txt2img_inf(
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
args.width = width
|
||||
args.use_hiresfix = use_hiresfix
|
||||
args.hiresfix_height = hiresfix_height
|
||||
args.hiresfix_width = hiresfix_width
|
||||
args.hiresfix_strength = hiresfix_strength
|
||||
args.resample_type = resample_type
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
|
||||
args.iree_metal_target_platform = init_iree_metal_target_platform
|
||||
@@ -200,6 +206,81 @@ def txt2img_inf(
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
# TODO: allow user to save original image
|
||||
# TODO: add option to let user keep both pipelines loaded, and unload
|
||||
# either at will
|
||||
# TODO: add custom step value slider
|
||||
# TODO: add option to use secondary model for the img2img pass
|
||||
if use_hiresfix is True:
|
||||
new_config_obj = Config(
|
||||
"img2img",
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
precision,
|
||||
1,
|
||||
max_length,
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil="None",
|
||||
ondemand=ondemand,
|
||||
)
|
||||
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
set_init_device_flags()
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-1-base"
|
||||
)
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(args.scheduler)
|
||||
|
||||
global_obj.set_sd_obj(
|
||||
Image2ImagePipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
1,
|
||||
hiresfix_height,
|
||||
hiresfix_width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
|
||||
global_obj.set_sd_scheduler(args.scheduler)
|
||||
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
out_imgs[0],
|
||||
batch_size,
|
||||
hiresfix_height,
|
||||
hiresfix_width,
|
||||
ceil(steps / hiresfix_strength),
|
||||
hiresfix_strength,
|
||||
guidance_scale,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
use_stencil="None",
|
||||
resample_type=resample_type,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(
|
||||
seeds[: current_batch + 1], device
|
||||
@@ -219,70 +300,6 @@ def txt2img_inf(
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def encode_pil_to_base64(images):
|
||||
encoded_imgs = []
|
||||
for image in images:
|
||||
with BytesIO() as output_bytes:
|
||||
if args.output_img_format.lower() == "png":
|
||||
image.save(output_bytes, format="PNG")
|
||||
|
||||
elif args.output_img_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Invalid image format"
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
encoded_imgs.append(base64.b64encode(bytes_data))
|
||||
return encoded_imgs
|
||||
|
||||
|
||||
# Text2Img Rest API.
|
||||
def txt2img_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, '
|
||||
f'Negative Prompt: {InputData["negative_prompt"]}, '
|
||||
f'Seed: {InputData["seed"]}.'
|
||||
)
|
||||
res = txt2img_inf(
|
||||
InputData["prompt"],
|
||||
InputData["negative_prompt"],
|
||||
InputData["height"],
|
||||
InputData["width"],
|
||||
InputData["steps"],
|
||||
InputData["cfg_scale"],
|
||||
InputData["seed"],
|
||||
batch_count=1,
|
||||
batch_size=1,
|
||||
scheduler="EulerDiscrete",
|
||||
custom_model="None",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
max_length=64,
|
||||
save_metadata_to_json=False,
|
||||
save_metadata_to_png=False,
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
# Convert Generator to Subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
"info": res[1],
|
||||
}
|
||||
|
||||
|
||||
with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
@@ -302,32 +319,18 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=10):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
t2i_model_info = (
|
||||
str(get_custom_model_path())
|
||||
).replace("\\", "\n\\")
|
||||
t2i_model_info = (
|
||||
f"Custom Model Path: {t2i_model_info}"
|
||||
)
|
||||
t2i_model_info = f"Custom Model Path: {str(get_custom_model_path())}"
|
||||
txt2img_custom_model = gr.Dropdown(
|
||||
label=f"Models",
|
||||
info=t2i_model_info,
|
||||
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
choices=get_custom_model_files()
|
||||
+ predefined_models,
|
||||
)
|
||||
txt2img_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the dropdown "
|
||||
"on the left and enter model ID here.",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL.",
|
||||
lines=3,
|
||||
allow_custom_value=True,
|
||||
scale=2,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
t2i_vae_info = (
|
||||
@@ -343,6 +346,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
else "None",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
scale=1,
|
||||
)
|
||||
with gr.Column(scale=1, min_width=170):
|
||||
txt2img_png_info_img = gr.Image(
|
||||
@@ -379,6 +384,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
@@ -397,6 +403,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
label="Scheduler",
|
||||
value=args.scheduler,
|
||||
choices=scheduler_list,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Column():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -483,6 +490,41 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
args.repeatable_seeds,
|
||||
label="Repeatable Seeds",
|
||||
)
|
||||
with gr.Accordion(label="Hires Fix Options", open=False):
|
||||
with gr.Group():
|
||||
with gr.Row():
|
||||
use_hiresfix = gr.Checkbox(
|
||||
value=args.use_hiresfix,
|
||||
label="Use Hires Fix",
|
||||
interactive=True,
|
||||
)
|
||||
resample_type = gr.Dropdown(
|
||||
value=args.resample_type,
|
||||
choices=resampler_list,
|
||||
label="Resample Type",
|
||||
allow_custom_value=False,
|
||||
)
|
||||
hiresfix_height = gr.Slider(
|
||||
384,
|
||||
768,
|
||||
value=args.hiresfix_height,
|
||||
step=8,
|
||||
label="Hires Fix Height",
|
||||
)
|
||||
hiresfix_width = gr.Slider(
|
||||
384,
|
||||
768,
|
||||
value=args.hiresfix_width,
|
||||
step=8,
|
||||
label="Hires Fix Width",
|
||||
)
|
||||
hiresfix_strength = gr.Slider(
|
||||
0,
|
||||
1,
|
||||
value=args.hiresfix_strength,
|
||||
step=0.01,
|
||||
label="Hires Fix Denoising Strength",
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Textbox(
|
||||
value=args.seed,
|
||||
@@ -494,17 +536,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
with gr.Accordion(label="Prompt Examples!", open=False):
|
||||
ex = gr.Examples(
|
||||
examples=prompt_examples,
|
||||
@@ -523,13 +556,26 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
object_fit="contain",
|
||||
)
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at "
|
||||
value=f"{t2i_model_info}\n"
|
||||
f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
txt2img_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
blank_thing_for_row = None
|
||||
with gr.Row():
|
||||
txt2img_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
txt2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
@@ -554,7 +600,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
batch_size,
|
||||
scheduler,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
@@ -565,6 +610,11 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
use_hiresfix,
|
||||
hiresfix_height,
|
||||
hiresfix_width,
|
||||
hiresfix_strength,
|
||||
resample_type,
|
||||
],
|
||||
outputs=[txt2img_gallery, std_output, txt2img_status],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
@@ -599,7 +649,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
width,
|
||||
height,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
custom_vae,
|
||||
@@ -615,9 +664,28 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
width,
|
||||
height,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
custom_vae,
|
||||
],
|
||||
)
|
||||
|
||||
# SharkEulerDiscrete doesn't work with img2img which hires_fix uses
|
||||
def set_compatible_schedulers(hires_fix_selected):
|
||||
if hires_fix_selected:
|
||||
return gr.Dropdown.update(
|
||||
choices=scheduler_list_cpu_only,
|
||||
value="DEISMultistep",
|
||||
)
|
||||
else:
|
||||
return gr.Dropdown.update(
|
||||
choices=scheduler_list,
|
||||
value="SharkEulerDiscrete",
|
||||
)
|
||||
|
||||
use_hiresfix.change(
|
||||
fn=set_compatible_schedulers,
|
||||
inputs=[use_hiresfix],
|
||||
outputs=[scheduler],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
@@ -3,9 +3,6 @@ import torch
|
||||
import time
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from fastapi.exceptions import HTTPException
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -46,8 +43,7 @@ def upscaler_inf(
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
@@ -85,21 +81,17 @@ def upscaler_inf(
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
args.custom_vae = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, "
|
||||
"both must not be empty.",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
else:
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
|
||||
# .safetensor or .chkpt on the custom model path
|
||||
if model_id in get_custom_model_files(custom_checkpoint_type="upscaler"):
|
||||
args.ckpt_loc = get_custom_model_pathfile(model_id)
|
||||
# civitai download
|
||||
elif "civitai" in model_id:
|
||||
args.ckpt_loc = model_id
|
||||
# either predefined or huggingface
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
args.hf_model_id = model_id
|
||||
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
@@ -252,83 +244,6 @@ def upscaler_inf(
|
||||
yield generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
|
||||
try:
|
||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
return image
|
||||
except Exception as err:
|
||||
print(err)
|
||||
raise HTTPException(status_code=500, detail="Invalid encoded image")
|
||||
|
||||
|
||||
def encode_pil_to_base64(images):
|
||||
encoded_imgs = []
|
||||
for image in images:
|
||||
with BytesIO() as output_bytes:
|
||||
if args.output_img_format.lower() == "png":
|
||||
image.save(output_bytes, format="PNG")
|
||||
|
||||
elif args.output_img_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Invalid image format"
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
encoded_imgs.append(base64.b64encode(bytes_data))
|
||||
return encoded_imgs
|
||||
|
||||
|
||||
# Upscaler Rest API.
|
||||
def upscaler_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, '
|
||||
f'Negative Prompt: {InputData["negative_prompt"]}, '
|
||||
f'Seed: {InputData["seed"]}'
|
||||
)
|
||||
init_image = decode_base64_to_image(InputData["init_images"][0])
|
||||
res = upscaler_inf(
|
||||
InputData["prompt"],
|
||||
InputData["negative_prompt"],
|
||||
init_image,
|
||||
InputData["height"],
|
||||
InputData["width"],
|
||||
InputData["steps"],
|
||||
InputData["noise_level"],
|
||||
InputData["cfg_scale"],
|
||||
InputData["seed"],
|
||||
batch_count=1,
|
||||
batch_size=1,
|
||||
scheduler="EulerDiscrete",
|
||||
custom_model="None",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
max_length=64,
|
||||
save_metadata_to_json=False,
|
||||
save_metadata_to_png=False,
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
# Converts generator type to subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
"info": res[1],
|
||||
}
|
||||
|
||||
|
||||
with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
@@ -346,36 +261,22 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
upscaler_model_info = (
|
||||
str(get_custom_model_path())
|
||||
).replace("\\", "\n\\")
|
||||
upscaler_model_info = (
|
||||
f"Custom Model Path: {upscaler_model_info}"
|
||||
f"Custom Model Path: {str(get_custom_model_path())}"
|
||||
)
|
||||
upscaler_custom_model = gr.Dropdown(
|
||||
label=f"Models",
|
||||
info=upscaler_model_info,
|
||||
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "stabilityai/stable-diffusion-x4-upscaler",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files(
|
||||
choices=get_custom_model_files(
|
||||
custom_checkpoint_type="upscaler"
|
||||
)
|
||||
+ predefined_upscaler_models,
|
||||
)
|
||||
upscaler_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown "
|
||||
"on the left and enter model ID here "
|
||||
"e.g: SG161222/Realistic_Vision_V1.3, "
|
||||
"https://civitai.com/api/download/models/15236",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL",
|
||||
lines=3,
|
||||
allow_custom_value=True,
|
||||
scale=2,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
upscaler_vae_info = (
|
||||
@@ -390,6 +291,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
scale=1,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
@@ -425,6 +328,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
@@ -443,6 +347,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
label="Scheduler",
|
||||
value="DDIM",
|
||||
choices=scheduler_list_cpu_only,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -547,17 +452,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -569,14 +465,26 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
object_fit="contain",
|
||||
)
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at "
|
||||
value=f"{upscaler_model_info}\n"
|
||||
f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
lines=2,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
upscaler_status = gr.Textbox(visible=False)
|
||||
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
blank_thing_for_row = None
|
||||
with gr.Row():
|
||||
upscaler_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
upscaler_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
@@ -600,7 +508,6 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
batch_size,
|
||||
scheduler,
|
||||
upscaler_custom_model,
|
||||
upscaler_hf_model_id,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
|
||||
@@ -25,7 +25,7 @@ class Config:
|
||||
device: str
|
||||
use_lora: str
|
||||
use_stencil: str
|
||||
ondemand: str
|
||||
ondemand: str # should this be expecting a bool instead?
|
||||
|
||||
|
||||
custom_model_filetypes = (
|
||||
@@ -170,4 +170,5 @@ def cancel_sd():
|
||||
|
||||
|
||||
nodlogo_loc = resource_path("logos/nod-logo.png")
|
||||
nodicon_loc = resource_path("logos/nod-icon.png")
|
||||
available_devices = get_available_devices()
|
||||
|
||||
105
apps/stable_diffusion/web/utils/app.py
Normal file
105
apps/stable_diffusion/web/utils/app.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import os
|
||||
import sys
|
||||
import webview
|
||||
import webview.util
|
||||
import socket
|
||||
|
||||
from contextlib import closing
|
||||
from multiprocessing import Process
|
||||
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
|
||||
def webview2_installed():
|
||||
if sys.platform != "win32":
|
||||
return False
|
||||
|
||||
# On windows we want to ensure we have MS webview2 available so we don't fall back
|
||||
# to MSHTML (aka ye olde Internet Explorer) which is deprecated by pywebview, and
|
||||
# apparently causes SHARK not to load in properly.
|
||||
|
||||
# Checking these registry entries is how Microsoft says to detect a webview2 installation:
|
||||
# https://learn.microsoft.com/en-us/microsoft-edge/webview2/concepts/distribution
|
||||
import winreg
|
||||
|
||||
path = r"SOFTWARE\WOW6432Node\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}"
|
||||
|
||||
# only way can find if a registry entry even exists is to try and open it
|
||||
try:
|
||||
# check for an all user install
|
||||
with winreg.OpenKey(
|
||||
winreg.HKEY_LOCAL_MACHINE,
|
||||
path,
|
||||
0,
|
||||
winreg.KEY_QUERY_VALUE | winreg.KEY_WOW64_64KEY,
|
||||
) as registry_key:
|
||||
value, type = winreg.QueryValueEx(registry_key, "pv")
|
||||
|
||||
# if it didn't exist, we want to continue on...
|
||||
except WindowsError:
|
||||
try:
|
||||
# ...to check for a current user install
|
||||
with winreg.OpenKey(
|
||||
winreg.HKEY_CURRENT_USER,
|
||||
path,
|
||||
0,
|
||||
winreg.KEY_QUERY_VALUE | winreg.KEY_WOW64_64KEY,
|
||||
) as registry_key:
|
||||
value, type = winreg.QueryValueEx(registry_key, "pv")
|
||||
except WindowsError:
|
||||
value = None
|
||||
finally:
|
||||
return (value is not None) and value != "" and value != "0.0.0.0"
|
||||
|
||||
|
||||
def window(address):
|
||||
from tkinter import Tk
|
||||
|
||||
window = Tk()
|
||||
|
||||
# get screen width and height of display and make it more reasonably
|
||||
# sized as we aren't making it full-screen or maximized
|
||||
width = int(window.winfo_screenwidth() * 0.81)
|
||||
height = int(window.winfo_screenheight() * 0.91)
|
||||
webview.create_window(
|
||||
"SHARK AI Studio",
|
||||
url=address,
|
||||
width=width,
|
||||
height=height,
|
||||
text_select=True,
|
||||
)
|
||||
webview.start(private_mode=False, storage_path=os.getcwd())
|
||||
|
||||
|
||||
def usable_port():
|
||||
# Make sure we can actually use the port given in args.server_port. If
|
||||
# not ask the OS for a port and return that as our port to use.
|
||||
|
||||
port = args.server_port
|
||||
|
||||
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
|
||||
try:
|
||||
sock.bind(("0.0.0.0", port))
|
||||
except OSError:
|
||||
with closing(
|
||||
socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
) as sock:
|
||||
sock.bind(("0.0.0.0", 0))
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
return sock.getsockname()[1]
|
||||
|
||||
return port
|
||||
|
||||
|
||||
def launch(port):
|
||||
# setup to launch as an app if app mode has been requested and we're able
|
||||
# to do it, answering whether we succeeded.
|
||||
if args.ui == "app" and (sys.platform != "win32" or webview2_installed()):
|
||||
try:
|
||||
t = Process(target=window, args=[f"http://localhost:{port}"])
|
||||
t.start()
|
||||
return True
|
||||
except webview.util.WebViewException:
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
@@ -149,7 +149,6 @@ def import_png_metadata(
|
||||
width,
|
||||
height,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
custom_lora,
|
||||
hf_lora_id,
|
||||
custom_vae,
|
||||
@@ -175,10 +174,8 @@ def import_png_metadata(
|
||||
|
||||
if "Model" in metadata and png_custom_model:
|
||||
custom_model = png_custom_model
|
||||
hf_model_id = ""
|
||||
if "Model" in metadata and png_hf_model_id:
|
||||
custom_model = "None"
|
||||
hf_model_id = png_hf_model_id
|
||||
elif "Model" in metadata and png_hf_model_id:
|
||||
custom_model = png_hf_model_id
|
||||
|
||||
if "LoRA" in metadata and lora_custom_model:
|
||||
custom_lora = lora_custom_model
|
||||
@@ -217,7 +214,6 @@ def import_png_metadata(
|
||||
width,
|
||||
height,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
custom_lora,
|
||||
hf_lora_id,
|
||||
custom_vae,
|
||||
|
||||
@@ -129,12 +129,12 @@ pytest_benchmark_param = pytest.mark.parametrize(
|
||||
pytest.param(True, "cpu", marks=pytest.mark.skip),
|
||||
pytest.param(
|
||||
False,
|
||||
"gpu",
|
||||
"cuda",
|
||||
marks=pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("cuda"), reason="nvidia-smi not found"
|
||||
),
|
||||
),
|
||||
pytest.param(True, "gpu", marks=pytest.mark.skip),
|
||||
pytest.param(True, "cuda", marks=pytest.mark.skip),
|
||||
pytest.param(
|
||||
False,
|
||||
"vulkan",
|
||||
|
||||
@@ -24,13 +24,13 @@ def get_image(url, local_filename):
|
||||
shutil.copyfileobj(res.raw, f)
|
||||
|
||||
|
||||
def compare_images(new_filename, golden_filename):
|
||||
def compare_images(new_filename, golden_filename, upload=False):
|
||||
new = np.array(Image.open(new_filename)) / 255.0
|
||||
golden = np.array(Image.open(golden_filename)) / 255.0
|
||||
diff = np.abs(new - golden)
|
||||
mean = np.mean(diff)
|
||||
if mean > 0.1:
|
||||
if os.name != "nt":
|
||||
if os.name != "nt" and upload == True:
|
||||
subprocess.run(
|
||||
[
|
||||
"gsutil",
|
||||
@@ -39,7 +39,7 @@ def compare_images(new_filename, golden_filename):
|
||||
"gs://shark_tank/testdata/builder/",
|
||||
]
|
||||
)
|
||||
raise SystemExit("new and golden not close")
|
||||
raise AssertionError("new and golden not close")
|
||||
else:
|
||||
print("SUCCESS")
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
|
||||
IMPORTER=1 BENCHMARK=1 NO_BREVITAS=1 ./setup_venv.sh
|
||||
source $GITHUB_WORKSPACE/shark.venv/bin/activate
|
||||
python build_tools/stable_diffusion_testing.py --gen
|
||||
python tank/generate_sharktank.py
|
||||
|
||||
@@ -63,7 +63,14 @@ def get_inpaint_inputs():
|
||||
open("./test_images/inputs/mask.png", "wb").write(mask.content)
|
||||
|
||||
|
||||
def test_loop(device="vulkan", beta=False, extra_flags=[]):
|
||||
def test_loop(
|
||||
device="vulkan",
|
||||
beta=False,
|
||||
extra_flags=[],
|
||||
upload_bool=True,
|
||||
exit_on_fail=True,
|
||||
do_gen=False,
|
||||
):
|
||||
# Get golden values from tank
|
||||
shutil.rmtree("./test_images", ignore_errors=True)
|
||||
model_metrics = []
|
||||
@@ -81,6 +88,8 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
|
||||
if beta:
|
||||
extra_flags.append("--beta_models=True")
|
||||
extra_flags.append("--no-progress_bar")
|
||||
if do_gen:
|
||||
extra_flags.append("--import_debug")
|
||||
to_skip = [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"prompthero/openjourney",
|
||||
@@ -181,7 +190,14 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
|
||||
"./test_images/golden/" + model_name + "/*.png"
|
||||
)
|
||||
golden_file = glob(golden_path)[0]
|
||||
compare_images(test_file, golden_file)
|
||||
try:
|
||||
compare_images(
|
||||
test_file, golden_file, upload=upload_bool
|
||||
)
|
||||
except AssertionError as e:
|
||||
print(e)
|
||||
if exit_on_fail == True:
|
||||
raise
|
||||
else:
|
||||
print(command)
|
||||
print("failed to generate image for this configuration")
|
||||
@@ -200,6 +216,9 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
|
||||
extra_flags.remove(
|
||||
"--iree_vulkan_target_triple=rdna2-unknown-windows"
|
||||
)
|
||||
if do_gen:
|
||||
prepare_artifacts()
|
||||
|
||||
with open(os.path.join(os.getcwd(), "sd_testing_metrics.csv"), "w+") as f:
|
||||
header = "model_name;device;use_tune;import_opt;Clip Inference time(ms);Average Step (ms/it);VAE Inference time(ms);total image generation(s);command\n"
|
||||
f.write(header)
|
||||
@@ -218,15 +237,49 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
|
||||
f.write(";".join(output) + "\n")
|
||||
|
||||
|
||||
def prepare_artifacts():
|
||||
gen_path = os.path.join(os.getcwd(), "gen_shark_tank")
|
||||
if not os.path.isdir(gen_path):
|
||||
os.mkdir(gen_path)
|
||||
for dirname in os.listdir(os.getcwd()):
|
||||
for modelname in ["clip", "unet", "vae"]:
|
||||
if modelname in dirname and "vmfb" not in dirname:
|
||||
if not os.path.isdir(os.path.join(gen_path, dirname)):
|
||||
shutil.move(os.path.join(os.getcwd(), dirname), gen_path)
|
||||
print(f"Moved dir: {dirname} to {gen_path}.")
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("-d", "--device", default="vulkan")
|
||||
parser.add_argument(
|
||||
"-b", "--beta", action=argparse.BooleanOptionalAction, default=False
|
||||
)
|
||||
|
||||
parser.add_argument("-e", "--extra_args", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"-u", "--upload", action=argparse.BooleanOptionalAction, default=True
|
||||
)
|
||||
parser.add_argument(
|
||||
"-x", "--exit_on_fail", action=argparse.BooleanOptionalAction, default=True
|
||||
)
|
||||
parser.add_argument(
|
||||
"-g", "--gen", action=argparse.BooleanOptionalAction, default=False
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
test_loop(args.device, args.beta, [])
|
||||
extra_args = []
|
||||
if args.extra_args:
|
||||
for arg in args.extra_args.split(","):
|
||||
extra_args.append(arg)
|
||||
test_loop(
|
||||
args.device,
|
||||
args.beta,
|
||||
extra_args,
|
||||
args.upload,
|
||||
args.exit_on_fail,
|
||||
args.gen,
|
||||
)
|
||||
if args.gen:
|
||||
prepare_artifacts()
|
||||
|
||||
@@ -27,7 +27,7 @@ include(FetchContent)
|
||||
|
||||
FetchContent_Declare(
|
||||
iree
|
||||
GIT_REPOSITORY https://github.com/nod-ai/shark-runtime.git
|
||||
GIT_REPOSITORY https://github.com/nod-ai/srt.git
|
||||
GIT_TAG shark
|
||||
GIT_SUBMODULES_RECURSE OFF
|
||||
GIT_SHALLOW OFF
|
||||
|
||||
@@ -40,7 +40,7 @@ cmake --build build/
|
||||
*Prepare the model*
|
||||
```bash
|
||||
wget https://storage.googleapis.com/shark_tank/latest/resnet50_tf/resnet50_tf.mlir
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux resnet50_tf.mlir -o resnet50_tf.vmfb
|
||||
```
|
||||
*Prepare the input*
|
||||
|
||||
@@ -65,18 +65,18 @@ A tool for benchmarking other models is built and can be invoked with a command
|
||||
see `./build/vulkan_gui/iree-vulkan-gui --help` for an explanation on the function input. For example, stable diffusion unet can be tested with the following commands:
|
||||
```bash
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/stable_diff_tf.mlir
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux stable_diff_tf.mlir -o stable_diff_tf.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=2x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32
|
||||
```
|
||||
VAE and Autoencoder are also available
|
||||
```bash
|
||||
# VAE
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/vae_tf/vae.mlir
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux vae.mlir -o vae.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x4x64x64xf32
|
||||
|
||||
# CLIP Autoencoder
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/clip_tf/clip_autoencoder.mlir
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux clip_autoencoder.mlir -o clip_autoencoder.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x77xi32 --function_input=1x77xi32
|
||||
```
|
||||
|
||||
@@ -55,7 +55,7 @@ The command line for compilation will start something like this, where the `-` n
|
||||
The `-o output_filename.vmfb` flag can be used to specify the location to save the compiled vmfb. Note that a dump of the
|
||||
dispatches that can be compiled + run in isolation can be generated by adding `--iree-hal-dump-executable-benchmarks-to=/some/directory`. Say, if they are in the `benchmarks` directory, the following compile/run commands would work for Vulkan on RDNA3.
|
||||
```
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna3-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.mlir -o benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna3-unknown-linux benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.mlir -o benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb
|
||||
|
||||
iree-benchmark-module --module=benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb --function=forward --device=vulkan
|
||||
```
|
||||
@@ -63,8 +63,8 @@ Where `${NUM}` is the dispatch number that you want to benchmark/profile in isol
|
||||
|
||||
### Enabling Tracy for Vulkan profiling
|
||||
|
||||
To begin profiling with Tracy, a build of IREE runtime with tracing enabled is needed. SHARK-Runtime builds an
|
||||
instrumented version alongside the normal version nightly (.whls typically found [here](https://github.com/nod-ai/SHARK-Runtime/releases)), however this is only available for Linux. For Windows, tracing can be enabled by enabling a CMake flag.
|
||||
To begin profiling with Tracy, a build of IREE runtime with tracing enabled is needed. SHARK-Runtime (SRT) builds an
|
||||
instrumented version alongside the normal version nightly (.whls typically found [here](https://github.com/nod-ai/SRT/releases)), however this is only available for Linux. For Windows, tracing can be enabled by enabling a CMake flag.
|
||||
```
|
||||
$env:IREE_ENABLE_RUNTIME_TRACING="ON"
|
||||
```
|
||||
|
||||
140
docs/shark_sd_koboldcpp.md
Normal file
140
docs/shark_sd_koboldcpp.md
Normal file
@@ -0,0 +1,140 @@
|
||||
# Overview
|
||||
|
||||
In [1.47.2](https://github.com/LostRuins/koboldcpp/releases/tag/v1.47.2) [Koboldcpp](https://github.com/LostRuins/koboldcpp) added AUTOMATIC1111 integration for image generation. Since SHARK implements a small subset of the A1111 REST api, you can also use SHARK for this. This document gives a starting point for how to get this working.
|
||||
|
||||
## In Action
|
||||
|
||||

|
||||
|
||||
## Memory considerations
|
||||
|
||||
Since both Koboldcpp and SHARK will use VRAM on your graphic card(s) running both at the same time using the same card will impose extra limitations on the model size you can fully offload to the video card in Koboldcpp. For me, on a RX 7900 XTX on Windows with 24 GiB of VRAM, the limit was about a 13 Billion parameter model with Q5_K_M quantisation.
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
When using SHARK for image generation, especially with Koboldcpp, you need to be aware that it is currently designed to pay a large upfront cost in time compiling and tuning the model you select, to get an optimal individual image generation time. You need to be the judge as to whether this trade-off is going to be worth it for your OS and hardware combination.
|
||||
|
||||
It means that the first time you run a particular Stable Diffusion model for a particular combination of image size, LoRA, and VAE, SHARK will spend *many minutes* - even on a beefy machaine with very fast graphics card with lots of memory - building that model combination just so it can save it to disk. It may even have to go away and download the model if it doesn't already have it locally. Once it has done its build of a model combination for your hardware once, it shouldn't need to do it again until you upgrade to a newer SHARK version, install different drivers or change your graphics hardware. It will just upload the files it generated the first time to your graphics card and proceed from there.
|
||||
|
||||
This does mean however, that on a brand new fresh install of SHARK that has not generated any images on a model you haven't selected before, the first image Koboldcpp requests may look like it is *never* going finish and that the whole process has broken. Be forewarned, make yourself a cup of coffee, and expect a lot of messages about compilation and tuning from SHARK in the terminal you ran it from.
|
||||
|
||||
## Setup SHARK and prerequisites:
|
||||
|
||||
* Make sure you have suitable drivers for your graphics card installed. See the prerequisties section of the [README](https://github.com/nod-ai/SHARK#readme).
|
||||
* Download the latest SHARK studio .exe from [here](https://github.com/nod-ai/SHARK/releases) or follow the instructions in the [README](https://github.com/nod-ai/SHARK#readme) for an advanced, Linux or Mac install.
|
||||
* Run SHARK from terminal/PowerShell with the `--api` flag. Since koboldcpp also expects both CORS support and the image generator to be running on port `7860` rather than SHARK default of `8080`, also include both the `--api_cors_origin` flag with a suitable origin (use `="*"` to enable all origins) and `--server_port=7860` on the command line. (See the if you want to run SHARK on a different port)
|
||||
|
||||
```powershell
|
||||
## Run the .exe in API mode, with CORS support, on the A1111 endpoint port:
|
||||
.\node_ai_shark_studio_<date>_<ver>.exe --api --api_cors_origin="*" --server_port=7860
|
||||
|
||||
## Run trom the base directory of a source clone of SHARK on Windows:
|
||||
.\setup_venv.ps1
|
||||
python .\apps\stable_diffusion\web\index.py --api --api_cors_origin="*" --server_port=7860
|
||||
|
||||
## Run a the base directory of a source clone of SHARK on Linux:
|
||||
./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
python ./apps/stable_diffusion/web/index.py --api --api_cors_origin="*" --server_port=7860
|
||||
|
||||
## An example giving improved performance on AMD cards using vulkan, that runs on the same port as A1111
|
||||
.\node_ai_shark_studio_20320901_2525.exe --api --api_cors_origin="*" --device_allocator="caching" --server_port=7860
|
||||
|
||||
## Since the api respects most applicable SHARK command line arguments for options not specified,
|
||||
## or currently unimplemented by API, there might be some you want to set, as listed in `--help`
|
||||
.\node_ai_shark_studio_20320901_2525.exe --help
|
||||
|
||||
## For instance, the example above, but with a a custom VAE specified
|
||||
.\node_ai_shark_studio_20320901_2525.exe --api --api_cors_origin="*" --device_allocator="caching" --server_port=7860 --custom_vae="clearvae_v23.safetensors"
|
||||
|
||||
## An example with multiple specific CORS origins
|
||||
python apps/stable_diffusion/web/index.py --api --api_cors_origin="koboldcpp.example.com:7001" --api_cors_origin="koboldcpp.example.com:7002" --server_port=7860
|
||||
```
|
||||
|
||||
SHARK should start in server mode, and you should see something like this:
|
||||
|
||||

|
||||
|
||||
* Note: When running in api mode with `--api`, the .exe will not function as a webUI. Thus, the address or port shown in the terminal output will only be useful for API requests.
|
||||
|
||||
|
||||
## Configure Koboldcpp for local image generation:
|
||||
|
||||
* Get the latest [Koboldcpp](https://github.com/LostRuins/koboldcpp/releases) if you don't already have it. If you have a recent AMD card that has ROCm HIP [support for Windows](https://rocmdocs.amd.com/en/latest/release/windows_support.html#windows-supported-gpus) or [support for Linux](https://rocmdocs.amd.com/en/latest/release/gpu_os_support.html#linux-supported-gpus), you'll likely prefer [YellowRosecx's ROCm fork](https://github.com/YellowRoseCx/koboldcpp-rocm).
|
||||
* Start Koboldcpp in another terminal/Powershell and setup your model configuration. Refer to the [Koboldcpp README](https://github.com/YellowRoseCx/koboldcpp-rocm) for more details on how to do this if this is your first time using Koboldcpp.
|
||||
* Once the main UI has loaded into your browser click the settings button, go to the advanced tab, and then choose *Local A1111* from the generate images dropdown:
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
*if you get an error here, see the next section [below](#connecting-to-shark-on-a-different-address-or-port)*
|
||||
|
||||
* A list of Stable Diffusion models available to your SHARK instance should now be listed in the box below *generate images*. The default value will usually be set to `stabilityai/stable-diffusion-2-1-base`. Choose the model you want to use for image generation from the list (but see [performance considerations](#performance-considerations)).
|
||||
* You should now be ready to generate images, either by clicking the 'Add Img' button above the text entry box:
|
||||
|
||||

|
||||
|
||||
...or by selecting the 'Autogenerate' option in the settings:
|
||||
|
||||

|
||||
|
||||
*I often find that even if I have selected autogenerate I have to do an 'add img' to get things started off*
|
||||
|
||||
* There is one final piece of image generation configuration within Koboldcpp you might want to do. This is also in the generate images section of advanced settings. Here there is, not very obviously, a 'style' button:
|
||||
|
||||

|
||||
|
||||
This will bring up a dialog box where you can enter a short text that will sent as a prefix to the Prompt sent to SHARK:
|
||||
|
||||

|
||||
|
||||
|
||||
## Connecting to SHARK on a different address or port
|
||||
|
||||
If you didn't set the port to `--server_port=7860` when starting SHARK, or you are running it on different machine on your network than you are running Koboldcpp, or to where you are running the koboldcpp's kdlite client frontend, then you very likely got the following error:
|
||||
|
||||

|
||||
|
||||
As long as SHARK is running correctly, this means you need to set the url and port to the correct values in Koboldcpp. For instance. to set the port that Koboldcpp looks for an image generator to SHARK's default port of 8080:
|
||||
|
||||
* Select the cog icon the Generate Images section of Advanced settings:
|
||||
|
||||

|
||||
|
||||
* Then edit the port number at the end of the url in the 'A1111 Endpoint Selection' dialog box to read 8080:
|
||||
|
||||

|
||||
|
||||
* Similarly, when running SHARK on a different machine you will need to change host part of the endpoint url to the hostname or ip address where SHARK is running, similarly:
|
||||
|
||||

|
||||
|
||||
## Examples
|
||||
|
||||
Here's how Koboldcpp shows an image being requested:
|
||||
|
||||

|
||||
|
||||
The generated image in context in story mode:
|
||||
|
||||

|
||||
|
||||
And the same image when clicked on:
|
||||
|
||||

|
||||
|
||||
|
||||
## Where to find the images in SHARK
|
||||
|
||||
Even though Koboldcpp requests images at a size of 512x512, it resizes then to 256x256, converts them to `.jpeg`, and only shows them at 200x200 in the main text window. It does this so it can save them compactly embedded in your story as a `data://` uri.
|
||||
|
||||
However the images at the original size are saved by SHARK in its `output_dir` which is usually a folder named for the current date. inside `generated_imgs` folder in the SHARK installation directory.
|
||||
|
||||
You can browse these, either using the Output Gallery tab from within the SHARK web ui:
|
||||
|
||||

|
||||
|
||||
...or by browsing to the `output_dir` in your operating system's file manager:
|
||||
|
||||

|
||||
@@ -1,192 +0,0 @@
|
||||
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cmake_minimum_required(VERSION 3.17)
|
||||
|
||||
project(sharkbackend LANGUAGES C CXX)
|
||||
|
||||
#
|
||||
# Options
|
||||
#
|
||||
|
||||
option(TRITON_ENABLE_GPU "Enable GPU support in backend" ON)
|
||||
option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON)
|
||||
|
||||
set(TRITON_COMMON_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/common repo")
|
||||
set(TRITON_CORE_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/core repo")
|
||||
set(TRITON_BACKEND_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/backend repo")
|
||||
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
set(CMAKE_BUILD_TYPE Release)
|
||||
endif()
|
||||
|
||||
#
|
||||
# Dependencies
|
||||
#
|
||||
# FetchContent requires us to include the transitive closure of all
|
||||
# repos that we depend on so that we can override the tags.
|
||||
#
|
||||
include(FetchContent)
|
||||
|
||||
FetchContent_Declare(
|
||||
repo-common
|
||||
GIT_REPOSITORY https://github.com/triton-inference-server/common.git
|
||||
GIT_TAG ${TRITON_COMMON_REPO_TAG}
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_Declare(
|
||||
repo-core
|
||||
GIT_REPOSITORY https://github.com/triton-inference-server/core.git
|
||||
GIT_TAG ${TRITON_CORE_REPO_TAG}
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_Declare(
|
||||
repo-backend
|
||||
GIT_REPOSITORY https://github.com/triton-inference-server/backend.git
|
||||
GIT_TAG ${TRITON_BACKEND_REPO_TAG}
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_MakeAvailable(repo-common repo-core repo-backend)
|
||||
|
||||
#
|
||||
# The backend must be built into a shared library. Use an ldscript to
|
||||
# hide all symbols except for the TRITONBACKEND API.
|
||||
#
|
||||
configure_file(src/libtriton_dshark.ldscript libtriton_dshark.ldscript COPYONLY)
|
||||
|
||||
add_library(
|
||||
triton-dshark-backend SHARED
|
||||
src/dshark.cc
|
||||
#src/dshark_driver_module.c
|
||||
)
|
||||
|
||||
add_library(
|
||||
SharkBackend::triton-dshark-backend ALIAS triton-dshark-backend
|
||||
)
|
||||
|
||||
target_include_directories(
|
||||
triton-dshark-backend
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/src
|
||||
)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${PROJECT_BINARY_DIR}/lib/cmake/mlir")
|
||||
|
||||
add_subdirectory(thirdparty/shark-runtime EXCLUDE_FROM_ALL)
|
||||
|
||||
target_link_libraries(triton-dshark-backend PRIVATE iree_base_base
|
||||
iree_hal_hal
|
||||
iree_hal_cuda_cuda
|
||||
iree_hal_cuda_registration_registration
|
||||
iree_hal_vmvx_registration_registration
|
||||
iree_hal_dylib_registration_registration
|
||||
iree_modules_hal_hal
|
||||
iree_vm_vm
|
||||
iree_vm_bytecode_module
|
||||
iree_hal_local_loaders_system_library_loader
|
||||
iree_hal_local_loaders_vmvx_module_loader
|
||||
)
|
||||
|
||||
target_compile_features(triton-dshark-backend PRIVATE cxx_std_11)
|
||||
|
||||
|
||||
target_link_libraries(
|
||||
triton-dshark-backend
|
||||
PRIVATE
|
||||
triton-core-serverapi # from repo-core
|
||||
triton-core-backendapi # from repo-core
|
||||
triton-core-serverstub # from repo-core
|
||||
triton-backend-utils # from repo-backend
|
||||
)
|
||||
|
||||
if(WIN32)
|
||||
set_target_properties(
|
||||
triton-dshark-backend PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
OUTPUT_NAME triton_dshark
|
||||
)
|
||||
else()
|
||||
set_target_properties(
|
||||
triton-dshark-backend PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
OUTPUT_NAME triton_dshark
|
||||
LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libtriton_dshark.ldscript
|
||||
LINK_FLAGS "-Wl,--version-script libtriton_dshark.ldscript"
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
#
|
||||
# Install
|
||||
#
|
||||
include(GNUInstallDirs)
|
||||
set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/SharkBackend)
|
||||
|
||||
install(
|
||||
TARGETS
|
||||
triton-dshark-backend
|
||||
EXPORT
|
||||
triton-dshark-backend-targets
|
||||
LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
|
||||
RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
|
||||
)
|
||||
|
||||
install(
|
||||
EXPORT
|
||||
triton-dshark-backend-targets
|
||||
FILE
|
||||
SharkBackendTargets.cmake
|
||||
NAMESPACE
|
||||
SharkBackend::
|
||||
DESTINATION
|
||||
${INSTALL_CONFIGDIR}
|
||||
)
|
||||
|
||||
include(CMakePackageConfigHelpers)
|
||||
configure_package_config_file(
|
||||
${CMAKE_CURRENT_LIST_DIR}/cmake/SharkBackendConfig.cmake.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
|
||||
INSTALL_DESTINATION ${INSTALL_CONFIGDIR}
|
||||
)
|
||||
|
||||
install(
|
||||
FILES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
|
||||
DESTINATION ${INSTALL_CONFIGDIR}
|
||||
)
|
||||
|
||||
#
|
||||
# Export from build tree
|
||||
#
|
||||
export(
|
||||
EXPORT triton-dshark-backend-targets
|
||||
FILE ${CMAKE_CURRENT_BINARY_DIR}/SharkBackendTargets.cmake
|
||||
NAMESPACE SharkBackend::
|
||||
)
|
||||
|
||||
export(PACKAGE SharkBackend)
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
# SHARK Triton Backend
|
||||
|
||||
The triton backend for shark.
|
||||
|
||||
# Build
|
||||
|
||||
Install SHARK
|
||||
|
||||
```
|
||||
git clone https://github.com/nod-ai/SHARK.git
|
||||
# skip above step if dshark is already installed
|
||||
cd SHARK/inference
|
||||
```
|
||||
|
||||
install dependancies
|
||||
|
||||
```
|
||||
apt-get install patchelf rapidjson-dev python3-dev
|
||||
git submodule update --init
|
||||
```
|
||||
|
||||
update the submodules of iree
|
||||
|
||||
```
|
||||
cd thirdparty/shark-runtime
|
||||
git submodule update --init
|
||||
```
|
||||
|
||||
Next, make the backend and install it
|
||||
|
||||
```
|
||||
cd ../..
|
||||
mkdir build && cd build
|
||||
cmake -DTRITON_ENABLE_GPU=ON \
|
||||
-DIREE_HAL_DRIVER_CUDA=ON \
|
||||
-DIREE_TARGET_BACKEND_CUDA=ON \
|
||||
-DMLIR_ENABLE_CUDA_RUNNER=ON \
|
||||
-DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install \
|
||||
-DTRITON_BACKEND_REPO_TAG=r22.02 \
|
||||
-DTRITON_CORE_REPO_TAG=r22.02 \
|
||||
-DTRITON_COMMON_REPO_TAG=r22.02 ..
|
||||
make install
|
||||
```
|
||||
|
||||
# Incorporating into Triton
|
||||
|
||||
There are much more in depth explenations for the following steps in triton's documentation:
|
||||
https://github.com/triton-inference-server/server/blob/main/docs/compose.md#triton-with-unsupported-and-custom-backends
|
||||
|
||||
There should be a file at /build/install/backends/dshark/libtriton_dshark.so. You will need to copy it into your triton server image.
|
||||
More documentation is in the link above, but to create the docker image, you need to run the compose.py command in the triton-backend server repo
|
||||
|
||||
|
||||
To first build your image, clone the tritonserver repo.
|
||||
|
||||
```
|
||||
git clone https://github.com/triton-inference-server/server.git
|
||||
```
|
||||
|
||||
then run `compose.py` to build a docker compose file
|
||||
```
|
||||
cd server
|
||||
python3 compose.py --repoagent checksum --dry-run
|
||||
```
|
||||
|
||||
Because dshark is a third party backend, you will need to manually modify the `Dockerfile.compose` to include the dshark backend. To do this, in the Dockerfile.compose file produced, copy this line.
|
||||
the dshark backend will be located in the build folder from earlier under `/build/install/backends`
|
||||
|
||||
```
|
||||
COPY /path/to/build/install/backends/dshark /opt/tritonserver/backends/dshark
|
||||
```
|
||||
|
||||
Next run
|
||||
```
|
||||
docker build -t tritonserver_custom -f Dockerfile.compose .
|
||||
docker run -it --gpus=1 --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
|
||||
```
|
||||
|
||||
where `path/to/model_repos` is where you are storing the models you want to run
|
||||
|
||||
if your not using gpus, omit `--gpus=1`
|
||||
|
||||
```
|
||||
docker run -it --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
|
||||
```
|
||||
|
||||
# Setting up a model
|
||||
|
||||
to include a model in your backend, add a directory with your model name to your model repository directory. examples of models can be seen here: https://github.com/triton-inference-server/backend/tree/main/examples/model_repos/minimal_models
|
||||
|
||||
make sure to adjust the input correctly in the config.pbtxt file, and save a vmfb file under 1/model.vmfb
|
||||
|
||||
# CUDA
|
||||
|
||||
if you're having issues with cuda, make sure your correct drivers are installed, and that `nvidia-smi` works, and also make sure that the nvcc compiler is on the path.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
include(CMakeFindDependencyMacro)
|
||||
|
||||
get_filename_component(
|
||||
SHARKBACKEND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH
|
||||
)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH ${SHARKBACKEND_CMAKE_DIR})
|
||||
|
||||
if(NOT TARGET SharkBackend::triton-dshark-backend)
|
||||
include("${SHARKBACKEND_CMAKE_DIR}/SharkBackendTargets.cmake")
|
||||
endif()
|
||||
|
||||
set(SHARKBACKEND_LIBRARIES SharkBackend::triton-dshark-backend)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,30 +0,0 @@
|
||||
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
{
|
||||
global:
|
||||
TRITONBACKEND_*;
|
||||
local: *;
|
||||
};
|
||||
1
inference/thirdparty/shark-runtime
vendored
1
inference/thirdparty/shark-runtime
vendored
Submodule inference/thirdparty/shark-runtime deleted from 7b82d90c72
@@ -6,15 +6,15 @@ from distutils.sysconfig import get_python_lib
|
||||
import fileinput
|
||||
from pathlib import Path
|
||||
|
||||
# Temorary workaround for transformers/__init__.py.
|
||||
path_to_tranformers_hook = Path(
|
||||
# Temporary workaround for transformers/__init__.py.
|
||||
path_to_transformers_hook = Path(
|
||||
get_python_lib()
|
||||
+ "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py"
|
||||
)
|
||||
if path_to_tranformers_hook.is_file():
|
||||
if path_to_transformers_hook.is_file():
|
||||
pass
|
||||
else:
|
||||
with open(path_to_tranformers_hook, "w") as f:
|
||||
with open(path_to_transformers_hook, "w") as f:
|
||||
f.write("module_collection_mode = 'pyz+py'")
|
||||
|
||||
path_to_skipfiles = Path(get_python_lib() + "/torch/_dynamo/skipfiles.py")
|
||||
|
||||
@@ -5,7 +5,7 @@ requires = [
|
||||
"packaging",
|
||||
|
||||
"numpy>=1.22.4",
|
||||
"torch-mlir>=20221021.633",
|
||||
"torch-mlir>=20230620.875",
|
||||
"iree-compiler>=20221022.190",
|
||||
"iree-runtime>=20221022.190",
|
||||
]
|
||||
|
||||
@@ -8,19 +8,8 @@ torchvision
|
||||
tqdm
|
||||
|
||||
#iree-compiler | iree-runtime should already be installed
|
||||
#these dont work ok osx
|
||||
#iree-tools-tflite
|
||||
#iree-tools-xla
|
||||
#iree-tools-tf
|
||||
|
||||
# TensorFlow and JAX.
|
||||
gin-config
|
||||
tensorflow-macos
|
||||
tensorflow-metal
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
transformers
|
||||
tensorflow-probability
|
||||
#jax[cpu]
|
||||
|
||||
# tflitehub dependencies.
|
||||
|
||||
@@ -3,29 +3,19 @@
|
||||
|
||||
numpy>1.22.4
|
||||
pytorch-triton
|
||||
torchvision==0.16.0.dev20230322
|
||||
torchvision
|
||||
tabulate
|
||||
|
||||
tqdm
|
||||
|
||||
#iree-compiler | iree-runtime should already be installed
|
||||
iree-tools-tflite
|
||||
iree-tools-xla
|
||||
iree-tools-tf
|
||||
|
||||
# TensorFlow and JAX.
|
||||
# Modelling and JAX.
|
||||
gin-config
|
||||
tensorflow>2.11
|
||||
keras
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
transformers
|
||||
diffusers
|
||||
#tensorflow-probability
|
||||
#jax[cpu]
|
||||
|
||||
|
||||
# tflitehub dependencies.
|
||||
Pillow
|
||||
|
||||
# Testing and support.
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
||||
--pre
|
||||
|
||||
setuptools
|
||||
wheel
|
||||
|
||||
@@ -15,16 +18,18 @@ Pillow
|
||||
parameterized
|
||||
|
||||
# Add transformers, diffusers and scipy since it most commonly used
|
||||
tokenizers==0.13.3
|
||||
transformers
|
||||
diffusers
|
||||
#accelerate is now required for diffusers import from ckpt.
|
||||
accelerate
|
||||
scipy
|
||||
ftfy
|
||||
gradio
|
||||
gradio==3.44.3
|
||||
altair
|
||||
omegaconf
|
||||
safetensors
|
||||
# 0.3.2 doesn't have binaries for arm64
|
||||
safetensors==0.3.1
|
||||
opencv-python
|
||||
scikit-image
|
||||
pytorch_lightning # for runwayml models
|
||||
@@ -36,10 +41,12 @@ tiktoken # for codegen
|
||||
joblib # for langchain
|
||||
timm # for MiniGPT4
|
||||
langchain
|
||||
einops # for zoedepth
|
||||
pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions
|
||||
|
||||
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
|
||||
pefile
|
||||
pyinstaller
|
||||
|
||||
# vicuna quantization
|
||||
brevitas @ git+https://github.com/Xilinx/brevitas.git@dev
|
||||
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea
|
||||
|
||||
@@ -4,7 +4,7 @@ import base64
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
def upscaler_test():
|
||||
def upscaler_test(verbose=False):
|
||||
# Define values here
|
||||
prompt = ""
|
||||
negative_prompt = ""
|
||||
@@ -44,10 +44,17 @@ def upscaler_test():
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"response from server was : {res.status_code}")
|
||||
print(
|
||||
f"[upscaler] response from server was : {res.status_code} {res.reason}"
|
||||
)
|
||||
|
||||
if verbose or res.status_code != 200:
|
||||
print(
|
||||
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
|
||||
)
|
||||
|
||||
|
||||
def img2img_test():
|
||||
def img2img_test(verbose=False):
|
||||
# Define values here
|
||||
prompt = "Paint a rabbit riding on the dog"
|
||||
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
|
||||
@@ -87,7 +94,16 @@ def img2img_test():
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"response from server was : {res.status_code}")
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(
|
||||
f"[img2img] response from server was : {res.status_code} {res.reason}"
|
||||
)
|
||||
|
||||
if verbose or res.status_code != 200:
|
||||
print(
|
||||
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
|
||||
)
|
||||
|
||||
# NOTE Uncomment below to save the picture
|
||||
|
||||
@@ -103,7 +119,7 @@ def img2img_test():
|
||||
# response_img.save(r"rest_api_tests/response_img.png")
|
||||
|
||||
|
||||
def inpainting_test():
|
||||
def inpainting_test(verbose=False):
|
||||
prompt = "Paint a rabbit riding on the dog"
|
||||
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
|
||||
seed = 2121991605
|
||||
@@ -150,10 +166,17 @@ def inpainting_test():
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"[Inpainting] response from server was : {res.status_code}")
|
||||
print(
|
||||
f"[inpaint] response from server was : {res.status_code} {res.reason}"
|
||||
)
|
||||
|
||||
if verbose or res.status_code != 200:
|
||||
print(
|
||||
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
|
||||
)
|
||||
|
||||
|
||||
def outpainting_test():
|
||||
def outpainting_test(verbose=False):
|
||||
prompt = "Paint a rabbit riding on the dog"
|
||||
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
|
||||
seed = 2121991605
|
||||
@@ -200,10 +223,17 @@ def outpainting_test():
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"[Outpaint] response from server was : {res.status_code}")
|
||||
print(
|
||||
f"[outpaint] response from server was : {res.status_code} {res.reason}"
|
||||
)
|
||||
|
||||
if verbose or res.status_code != 200:
|
||||
print(
|
||||
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
|
||||
)
|
||||
|
||||
|
||||
def txt2img_test():
|
||||
def txt2img_test(verbose=False):
|
||||
prompt = "Paint a rabbit in a top hate"
|
||||
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
|
||||
seed = 2121991605
|
||||
@@ -232,12 +262,119 @@ def txt2img_test():
|
||||
|
||||
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
|
||||
|
||||
print(f"[txt2img] response from server was : {res.status_code}")
|
||||
print(
|
||||
f"[txt2img] response from server was : {res.status_code} {res.reason}"
|
||||
)
|
||||
|
||||
if verbose or res.status_code != 200:
|
||||
print(
|
||||
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
|
||||
)
|
||||
|
||||
|
||||
def sd_models_test(verbose=False):
|
||||
url = "http://127.0.0.1:8080/sdapi/v1/sd-models"
|
||||
|
||||
headers = {
|
||||
"User-Agent": "PythonTest",
|
||||
"Accept": "*/*",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
}
|
||||
|
||||
res = requests.get(url=url, headers=headers, timeout=1000)
|
||||
|
||||
print(
|
||||
f"[sd_models] response from server was : {res.status_code} {res.reason}"
|
||||
)
|
||||
|
||||
if verbose or res.status_code != 200:
|
||||
print(f"\n{res.json() if res.status_code == 200 else res.content}\n")
|
||||
|
||||
|
||||
def sd_samplers_test(verbose=False):
|
||||
url = "http://127.0.0.1:8080/sdapi/v1/samplers"
|
||||
|
||||
headers = {
|
||||
"User-Agent": "PythonTest",
|
||||
"Accept": "*/*",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
}
|
||||
|
||||
res = requests.get(url=url, headers=headers, timeout=1000)
|
||||
|
||||
print(
|
||||
f"[sd_samplers] response from server was : {res.status_code} {res.reason}"
|
||||
)
|
||||
|
||||
if verbose or res.status_code != 200:
|
||||
print(f"\n{res.json() if res.status_code == 200 else res.content}\n")
|
||||
|
||||
|
||||
def options_test(verbose=False):
|
||||
url = "http://127.0.0.1:8080/sdapi/v1/options"
|
||||
|
||||
headers = {
|
||||
"User-Agent": "PythonTest",
|
||||
"Accept": "*/*",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
}
|
||||
|
||||
res = requests.get(url=url, headers=headers, timeout=1000)
|
||||
|
||||
print(
|
||||
f"[options] response from server was : {res.status_code} {res.reason}"
|
||||
)
|
||||
|
||||
if verbose or res.status_code != 200:
|
||||
print(f"\n{res.json() if res.status_code == 200 else res.content}\n")
|
||||
|
||||
|
||||
def cmd_flags_test(verbose=False):
|
||||
url = "http://127.0.0.1:8080/sdapi/v1/cmd-flags"
|
||||
|
||||
headers = {
|
||||
"User-Agent": "PythonTest",
|
||||
"Accept": "*/*",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
}
|
||||
|
||||
res = requests.get(url=url, headers=headers, timeout=1000)
|
||||
|
||||
print(
|
||||
f"[cmd-flags] response from server was : {res.status_code} {res.reason}"
|
||||
)
|
||||
|
||||
if verbose or res.status_code != 200:
|
||||
print(f"\n{res.json() if res.status_code == 200 else res.content}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
txt2img_test()
|
||||
img2img_test()
|
||||
upscaler_test()
|
||||
inpainting_test()
|
||||
outpainting_test()
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Exercises the Stable Diffusion REST API of Shark. Make sure "
|
||||
"Shark is running in API mode on 127.0.0.1:8080 before running"
|
||||
"this script."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
help=(
|
||||
"also display selected info from the JSON response for "
|
||||
"successful requests"
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sd_models_test(args.verbose)
|
||||
sd_samplers_test(args.verbose)
|
||||
options_test(args.verbose)
|
||||
cmd_flags_test(args.verbose)
|
||||
txt2img_test(args.verbose)
|
||||
img2img_test(args.verbose)
|
||||
upscaler_test(args.verbose)
|
||||
inpainting_test(args.verbose)
|
||||
outpainting_test(args.verbose)
|
||||
|
||||
@@ -90,8 +90,8 @@ python -m pip install --upgrade pip
|
||||
pip install wheel
|
||||
pip install -r requirements.txt
|
||||
pip install --pre torch-mlir torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
|
||||
pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
|
||||
pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler iree-runtime
|
||||
Write-Host "Building SHARK..."
|
||||
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
|
||||
Write-Host "Build and installation completed successfully"
|
||||
Write-Host "Source your venv with ./shark.venv/Scripts/activate"
|
||||
|
||||
@@ -86,6 +86,7 @@ $PYTHON -m pip install --upgrade -r "$TD/requirements.txt"
|
||||
if [ "$torch_mlir_bin" = true ]; then
|
||||
if [[ $(uname -s) = 'Darwin' ]]; then
|
||||
echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
|
||||
$PYTHON -m pip uninstall -y timm #TEMP FIX FOR MAC
|
||||
$PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
|
||||
else
|
||||
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
|
||||
@@ -103,7 +104,7 @@ else
|
||||
fi
|
||||
if [[ -z "${USE_IREE}" ]]; then
|
||||
rm .use-iree
|
||||
RUNTIME="https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html"
|
||||
RUNTIME="https://nod-ai.github.io/SRT/pip-release-links.html"
|
||||
else
|
||||
touch ./.use-iree
|
||||
RUNTIME="https://openxla.github.io/iree/pip-release-links.html"
|
||||
@@ -128,16 +129,21 @@ if [[ ! -z "${IMPORTER}" ]]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/torch/
|
||||
if [[ $(uname -s) = 'Darwin' ]]; then
|
||||
PYTORCH_URL=https://download.pytorch.org/whl/nightly/torch/
|
||||
else
|
||||
PYTORCH_URL=https://download.pytorch.org/whl/nightly/cpu/
|
||||
fi
|
||||
|
||||
if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
|
||||
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f ${PYTORCH_URL}
|
||||
|
||||
if [[ $(uname -s) = 'Linux' && ! -z "${IMPORTER}" ]]; then
|
||||
T_VER=$($PYTHON -m pip show torch | grep Version)
|
||||
TORCH_VERSION=${T_VER:9:17}
|
||||
T_VER_MIN=${T_VER:14:12}
|
||||
TV_VER=$($PYTHON -m pip show torchvision | grep Version)
|
||||
TV_VERSION=${TV_VER:9:18}
|
||||
$PYTHON -m pip uninstall -y torch torchvision
|
||||
$PYTHON -m pip install -U --pre --no-warn-conflicts triton
|
||||
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu118/torch-${TORCH_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu118/torchvision-${TV_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl
|
||||
TV_VER_MAJ=${TV_VER:9:6}
|
||||
$PYTHON -m pip uninstall -y torchvision
|
||||
$PYTHON -m pip install torchvision==${TV_VER_MAJ}${T_VER_MIN} --no-deps -f https://download.pytorch.org/whl/nightly/cpu/torchvision/
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch + cu118."
|
||||
else
|
||||
@@ -145,14 +151,8 @@ if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ ! -z "${ONNX}" ]]; then
|
||||
echo "${Yellow}Installing ONNX and onnxruntime for benchmarks..."
|
||||
$PYTHON -m pip install onnx onnxruntime psutil
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully installed ONNX and ONNX runtime."
|
||||
else
|
||||
echo "Could not install ONNX." >&2
|
||||
fi
|
||||
if [[ -z "${NO_BREVITAS}" ]]; then
|
||||
$PYTHON -m pip install git+https://github.com/Xilinx/brevitas.git@dev
|
||||
fi
|
||||
|
||||
if [[ -z "${CONDA_PREFIX}" && "$SKIP_VENV" != "1" ]]; then
|
||||
|
||||
@@ -177,7 +177,7 @@ def compile_through_fx(model, inputs, mlir_loc=None):
|
||||
mlir_model = str(module)
|
||||
func_name = "forward"
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
|
||||
mlir_model, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
|
||||
@@ -43,9 +43,7 @@ if __name__ == "__main__":
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=False, tracing_required=True
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module = SharkInference(minilm_mlir)
|
||||
shark_module.compile()
|
||||
token_logits = torch.tensor(shark_module.forward(inputs))
|
||||
mask_id = torch.where(
|
||||
|
||||
@@ -54,7 +54,7 @@ if __name__ == "__main__":
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=False, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(minilm_mlir, func_name, mlir_dialect="mhlo")
|
||||
shark_module = SharkInference(minilm_mlir, mlir_dialect="mhlo")
|
||||
shark_module.compile()
|
||||
output_idx = 0
|
||||
data_idx = 1
|
||||
|
||||
@@ -6,7 +6,7 @@ mlir_model, func_name, inputs, golden_out = download_model(
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device="cpu", mlir_dialect="tm_tensor"
|
||||
mlir_model, device="cpu", mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(inputs)
|
||||
|
||||
@@ -13,9 +13,7 @@ arg0 = np.ones((1, 4)).astype(np.float32)
|
||||
arg1 = np.ones((4, 1)).astype(np.float32)
|
||||
|
||||
print("Running shark on cpu backend")
|
||||
shark_module = SharkInference(
|
||||
mhlo_ir, function_name="forward", device="cpu", mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module = SharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo")
|
||||
|
||||
# Generate the random inputs and feed into the graph.
|
||||
x = shark_module.generate_random_inputs()
|
||||
@@ -23,15 +21,11 @@ shark_module.compile()
|
||||
print(shark_module.forward(x))
|
||||
|
||||
print("Running shark on cuda backend")
|
||||
shark_module = SharkInference(
|
||||
mhlo_ir, function_name="forward", device="cuda", mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module = SharkInference(mhlo_ir, device="cuda", mlir_dialect="mhlo")
|
||||
shark_module.compile()
|
||||
print(shark_module.forward(x))
|
||||
|
||||
print("Running shark on vulkan backend")
|
||||
shark_module = SharkInference(
|
||||
mhlo_ir, function_name="forward", device="vulkan", mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module = SharkInference(mhlo_ir, device="vulkan", mlir_dialect="mhlo")
|
||||
shark_module.compile()
|
||||
print(shark_module.forward(x))
|
||||
|
||||
@@ -8,9 +8,7 @@ mlir_model, func_name, inputs, golden_out = download_model(
|
||||
)
|
||||
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device="cpu", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module = SharkInference(mlir_model, device="cpu", mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(inputs)
|
||||
print("The obtained result via shark is: ", result)
|
||||
|
||||
@@ -33,7 +33,7 @@ mlir_importer = SharkImporter(
|
||||
|
||||
print(golden_out)
|
||||
|
||||
shark_module = SharkInference(vision_mlir, func_name, mlir_dialect="linalg")
|
||||
shark_module = SharkInference(vision_mlir, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((input,))
|
||||
print("Obtained result", result)
|
||||
|
||||
@@ -49,9 +49,7 @@ module = torch_mlir.compile(
|
||||
mlir_model = module
|
||||
func_name = "forward"
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device="cuda", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module = SharkInference(mlir_model, device="cuda", mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
|
||||
|
||||
|
||||
@@ -360,7 +360,7 @@ mlir_importer = SharkImporter(
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
dlrm_mlir, func_name, device="vulkan", mlir_dialect="linalg"
|
||||
dlrm_mlir, device="vulkan", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(input_dlrm)
|
||||
|
||||
@@ -294,7 +294,7 @@ def test_dlrm() -> None:
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
dlrm_mlir, func_name, device="cpu", mlir_dialect="linalg"
|
||||
dlrm_mlir, device="cpu", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(inputs)
|
||||
|
||||
@@ -33,7 +33,7 @@ mlir_importer = SharkImporter(
|
||||
tracing_required=False
|
||||
)
|
||||
|
||||
shark_module = SharkInference(vision_mlir, func_name, mlir_dialect="linalg")
|
||||
shark_module = SharkInference(vision_mlir, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((input,))
|
||||
np.testing.assert_allclose(golden_out, result, rtol=1e-02, atol=1e-03)
|
||||
|
||||
@@ -7,7 +7,7 @@ mlir_model, func_name, inputs, golden_out = download_model(
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device="vulkan", mlir_dialect="linalg"
|
||||
mlir_model, device="vulkan", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(inputs)
|
||||
|
||||
@@ -19,9 +19,12 @@ import sys
|
||||
import subprocess
|
||||
|
||||
|
||||
def run_cmd(cmd, debug=False):
|
||||
def run_cmd(cmd, debug=False, raise_err=False):
|
||||
"""
|
||||
Inputs: cli command string.
|
||||
Inputs:
|
||||
cmd : cli command string.
|
||||
debug : if True, prints debug info
|
||||
raise_err : if True, raise exception to caller
|
||||
"""
|
||||
if debug:
|
||||
print("IREE run command: \n\n")
|
||||
@@ -39,8 +42,11 @@ def run_cmd(cmd, debug=False):
|
||||
stderr = result.stderr.decode()
|
||||
return stdout, stderr
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e.output)
|
||||
sys.exit(f"Exiting program due to error running {cmd}")
|
||||
if raise_err:
|
||||
raise Exception from e
|
||||
else:
|
||||
print(e.output)
|
||||
sys.exit(f"Exiting program due to error running {cmd}")
|
||||
|
||||
|
||||
def iree_device_map(device):
|
||||
@@ -52,6 +58,8 @@ def iree_device_map(device):
|
||||
)
|
||||
if len(uri_parts) == 1:
|
||||
return iree_driver
|
||||
elif "rocm" in uri_parts:
|
||||
return "rocm"
|
||||
else:
|
||||
return f"{iree_driver}://{uri_parts[1]}"
|
||||
|
||||
@@ -63,7 +71,6 @@ def get_supported_device_list():
|
||||
_IREE_DEVICE_MAP = {
|
||||
"cpu": "local-task",
|
||||
"cpu-task": "local-task",
|
||||
"AMD-AIE": "local-task",
|
||||
"cpu-sync": "local-sync",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
@@ -82,7 +89,6 @@ def iree_target_map(device):
|
||||
_IREE_TARGET_MAP = {
|
||||
"cpu": "llvm-cpu",
|
||||
"cpu-task": "llvm-cpu",
|
||||
"AMD-AIE": "llvm-cpu",
|
||||
"cpu-sync": "llvm-cpu",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
@@ -95,35 +101,31 @@ _IREE_TARGET_MAP = {
|
||||
# Finds whether the required drivers are installed for the given device.
|
||||
@functools.cache
|
||||
def check_device_drivers(device):
|
||||
"""Checks necessary drivers present for gpu and vulkan devices"""
|
||||
"""
|
||||
Checks necessary drivers present for gpu and vulkan devices
|
||||
False => drivers present!
|
||||
"""
|
||||
if "://" in device:
|
||||
device = device.split("://")[0]
|
||||
|
||||
if device == "cuda":
|
||||
try:
|
||||
subprocess.check_output("nvidia-smi")
|
||||
except Exception:
|
||||
return True
|
||||
elif device in ["vulkan"]:
|
||||
try:
|
||||
subprocess.check_output("vulkaninfo")
|
||||
except Exception:
|
||||
return True
|
||||
elif device == "metal":
|
||||
return False
|
||||
elif device in ["intel-gpu"]:
|
||||
try:
|
||||
subprocess.check_output(["dpkg", "-L", "intel-level-zero-gpu"])
|
||||
return False
|
||||
except Exception:
|
||||
return True
|
||||
elif device == "cpu":
|
||||
return False
|
||||
elif device == "rocm":
|
||||
try:
|
||||
subprocess.check_output("rocminfo")
|
||||
except Exception:
|
||||
return True
|
||||
from iree.runtime import get_driver
|
||||
|
||||
device_mapped = iree_device_map(device)
|
||||
|
||||
try:
|
||||
_ = get_driver(device_mapped)
|
||||
except ValueError as ve:
|
||||
print(
|
||||
f"[ERR] device `{device}` not registered with IREE. "
|
||||
"Ensure IREE is configured for use with this device.\n"
|
||||
f"Full Error: \n {repr(ve)}"
|
||||
)
|
||||
return True
|
||||
except RuntimeError as re:
|
||||
print(
|
||||
f"[ERR] Failed to get driver for {device} with error:\n{repr(re)}"
|
||||
)
|
||||
return True
|
||||
|
||||
# Unknown device. We assume drivers are installed.
|
||||
return False
|
||||
@@ -131,11 +133,32 @@ def check_device_drivers(device):
|
||||
|
||||
# Installation info for the missing device drivers.
|
||||
def device_driver_info(device):
|
||||
if device == "cuda":
|
||||
return "nvidia-smi not found, please install the required drivers from https://www.nvidia.in/Download/index.aspx?lang=en-in"
|
||||
elif device in ["metal", "vulkan"]:
|
||||
return "vulkaninfo not found, Install from https://vulkan.lunarg.com/sdk/home or your distribution"
|
||||
elif device == "rocm":
|
||||
return "rocm info not found. Please install rocm"
|
||||
device_driver_err_map = {
|
||||
"cuda": {
|
||||
"debug": "Try `nvidia-smi` on system to check.",
|
||||
"solution": " from https://www.nvidia.in/Download/index.aspx?lang=en-in for your system.",
|
||||
},
|
||||
"vulkan": {
|
||||
"debug": "Try `vulkaninfo` on system to check.",
|
||||
"solution": " from https://vulkan.lunarg.com/sdk/home for your distribution.",
|
||||
},
|
||||
"metal": {
|
||||
"debug": "Check if Bare metal is supported and enabled on your system.",
|
||||
"solution": ".",
|
||||
},
|
||||
"rocm": {
|
||||
"debug": f"Try `{'hip' if sys.platform == 'win32' else 'rocm'}info` on system to check.",
|
||||
"solution": " from https://rocm.docs.amd.com/en/latest/rocm.html for your system.",
|
||||
},
|
||||
}
|
||||
|
||||
if device in device_driver_err_map:
|
||||
err_msg = (
|
||||
f"Required drivers for {device} not found. {device_driver_err_map[device]['debug']} "
|
||||
f"Please install the required drivers{device_driver_err_map[device]['solution']} "
|
||||
f"For further assistance please reach out to the community on discord [https://discord.com/invite/RUqY2h2s9u]"
|
||||
f" and/or file a bug at https://github.com/nod-ai/SHARK/issues"
|
||||
)
|
||||
return err_msg
|
||||
else:
|
||||
return f"{device} is not supported."
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import iree.runtime.scripts.iree_benchmark_module as benchmark_module
|
||||
from shark.iree_utils._common import run_cmd, iree_device_map
|
||||
from shark.iree_utils.cpu_utils import get_cpu_count
|
||||
import numpy as np
|
||||
@@ -62,16 +61,12 @@ def build_benchmark_args(
|
||||
and whether it is training or not.
|
||||
Outputs: string that execute benchmark-module on target model.
|
||||
"""
|
||||
path = benchmark_module.__path__[0]
|
||||
path = os.path.join(os.environ["VIRTUAL_ENV"], "bin")
|
||||
if platform.system() == "Windows":
|
||||
benchmarker_path = os.path.join(
|
||||
path, "..", "..", "iree-benchmark-module.exe"
|
||||
)
|
||||
benchmarker_path = os.path.join(path, "iree-benchmark-module.exe")
|
||||
time_extractor = None
|
||||
else:
|
||||
benchmarker_path = os.path.join(
|
||||
path, "..", "..", "iree-benchmark-module"
|
||||
)
|
||||
benchmarker_path = os.path.join(path, "iree-benchmark-module")
|
||||
time_extractor = "| awk 'END{{print $2 $3}}'"
|
||||
benchmark_cl = [benchmarker_path, f"--module={input_file}"]
|
||||
# TODO: The function named can be passed as one of the args.
|
||||
@@ -106,15 +101,13 @@ def build_benchmark_args_non_tensor_input(
|
||||
and whether it is training or not.
|
||||
Outputs: string that execute benchmark-module on target model.
|
||||
"""
|
||||
path = benchmark_module.__path__[0]
|
||||
path = os.path.join(os.environ["VIRTUAL_ENV"], "bin")
|
||||
if platform.system() == "Windows":
|
||||
benchmarker_path = os.path.join(
|
||||
path, "..", "..", "iree-benchmark-module.exe"
|
||||
)
|
||||
benchmarker_path = os.path.join(path, "iree-benchmark-module.exe")
|
||||
time_extractor = None
|
||||
else:
|
||||
benchmarker_path = os.path.join(
|
||||
path, "..", "..", "iree-benchmark-module"
|
||||
)
|
||||
benchmarker_path = os.path.join(path, "iree-benchmark-module")
|
||||
time_extractor = "| awk 'END{{print $2 $3}}'"
|
||||
benchmark_cl = [benchmarker_path, f"--module={input_file}"]
|
||||
# TODO: The function named can be passed as one of the args.
|
||||
if function_name:
|
||||
@@ -139,7 +132,7 @@ def run_benchmark_module(benchmark_cl):
|
||||
benchmark_path = benchmark_cl[0]
|
||||
assert os.path.exists(
|
||||
benchmark_path
|
||||
), "Cannot find benchmark_module, Please contact SHARK maintainer on discord."
|
||||
), "Cannot find iree_benchmark_module, Please contact SHARK maintainer on discord."
|
||||
bench_stdout, bench_stderr = run_cmd(" ".join(benchmark_cl))
|
||||
try:
|
||||
regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)")
|
||||
|
||||
@@ -34,19 +34,27 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
print("Configuring for device:" + device)
|
||||
device_uri = device.split("://")
|
||||
if len(device_uri) > 1:
|
||||
if device_uri[0] not in ["vulkan"]:
|
||||
if device_uri[0] not in ["vulkan", "rocm"]:
|
||||
print(
|
||||
f"Specific device selection only supported for vulkan now."
|
||||
f"Specific device selection only supported for vulkan and rocm."
|
||||
f"Proceeding with {device} as device."
|
||||
)
|
||||
device_num = device_uri[1]
|
||||
# device_uri can be device_num or device_path.
|
||||
# assuming number of devices for a single driver will be not be >99
|
||||
if len(device_uri[1]) <= 2:
|
||||
# expected to be device index in range 0 - 99
|
||||
device_num = int(device_uri[1])
|
||||
else:
|
||||
# expected to be device path
|
||||
device_num = device_uri[1]
|
||||
|
||||
else:
|
||||
device_num = 0
|
||||
|
||||
if device_uri[0] == "cpu":
|
||||
if "cpu" in device:
|
||||
from shark.iree_utils.cpu_utils import get_iree_cpu_args
|
||||
|
||||
data_tiling_flag = ["--iree-flow-enable-data-tiling"]
|
||||
data_tiling_flag = ["--iree-opt-data-tiling"]
|
||||
u_kernel_flag = ["--iree-llvmcpu-enable-microkernels"]
|
||||
stack_size_flag = ["--iree-llvmcpu-stack-allocation-limit=256000"]
|
||||
|
||||
@@ -55,6 +63,8 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
+ data_tiling_flag
|
||||
+ u_kernel_flag
|
||||
+ stack_size_flag
|
||||
+ ["--iree-flow-enable-quantized-matmul-reassociation"]
|
||||
+ ["--iree-llvmcpu-enable-quantized-matmul-reassociation"]
|
||||
)
|
||||
if device_uri[0] == "cuda":
|
||||
from shark.iree_utils.gpu_utils import get_iree_gpu_args
|
||||
@@ -73,7 +83,7 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
if device_uri[0] == "rocm":
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
return get_iree_rocm_args()
|
||||
return get_iree_rocm_args(device_num=device_num, extra_args=extra_args)
|
||||
return []
|
||||
|
||||
|
||||
@@ -84,7 +94,7 @@ def get_iree_frontend_args(frontend):
|
||||
elif frontend in ["tensorflow", "tf", "mhlo", "stablehlo"]:
|
||||
return [
|
||||
"--iree-llvmcpu-target-cpu-features=host",
|
||||
"--iree-flow-demote-i64-to-i32",
|
||||
"--iree-input-demote-i64-to-i32",
|
||||
]
|
||||
else:
|
||||
# Frontend not found.
|
||||
@@ -92,13 +102,27 @@ def get_iree_frontend_args(frontend):
|
||||
|
||||
|
||||
# Common args to be used given any frontend or device.
|
||||
def get_iree_common_args():
|
||||
return [
|
||||
"--iree-stream-resource-index-bits=64",
|
||||
"--iree-vm-target-index-bits=64",
|
||||
def get_iree_common_args(debug=False):
|
||||
common_args = [
|
||||
"--iree-stream-resource-max-allocation-size=4294967295",
|
||||
"--iree-vm-bytecode-module-strip-source-map=true",
|
||||
"--iree-util-zero-fill-elided-attrs",
|
||||
]
|
||||
if debug == True:
|
||||
common_args.extend(
|
||||
[
|
||||
"--iree-opt-strip-assertions=false",
|
||||
"--verify=true",
|
||||
]
|
||||
)
|
||||
else:
|
||||
common_args.extend(
|
||||
[
|
||||
"--iree-opt-strip-assertions=true",
|
||||
"--verify=false",
|
||||
]
|
||||
)
|
||||
return common_args
|
||||
|
||||
|
||||
# Args that are suitable only for certain models or groups of models.
|
||||
@@ -277,14 +301,17 @@ def compile_module_to_flatbuffer(
|
||||
model_config_path,
|
||||
extra_args,
|
||||
model_name="None",
|
||||
debug=False,
|
||||
compile_str=False,
|
||||
):
|
||||
# Setup Compile arguments wrt to frontends.
|
||||
input_type = ""
|
||||
input_type = "auto"
|
||||
args = get_iree_frontend_args(frontend)
|
||||
args += get_iree_device_args(device, extra_args)
|
||||
args += get_iree_common_args()
|
||||
args += get_iree_common_args(debug=debug)
|
||||
args += get_model_specific_args()
|
||||
args += extra_args
|
||||
args += shark_args.additional_compile_args
|
||||
|
||||
if frontend in ["tensorflow", "tf"]:
|
||||
input_type = "auto"
|
||||
@@ -295,10 +322,7 @@ def compile_module_to_flatbuffer(
|
||||
elif frontend in ["tm_tensor"]:
|
||||
input_type = ireec.InputType.TM_TENSOR
|
||||
|
||||
# TODO: make it simpler.
|
||||
# Compile according to the input type, else just try compiling.
|
||||
if input_type != "":
|
||||
# Currently for MHLO/TOSA.
|
||||
if compile_str:
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
module,
|
||||
target_backends=[iree_target_map(device)],
|
||||
@@ -306,9 +330,10 @@ def compile_module_to_flatbuffer(
|
||||
input_type=input_type,
|
||||
)
|
||||
else:
|
||||
# Currently for Torch.
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
module,
|
||||
assert os.path.isfile(module)
|
||||
flatbuffer_blob = ireec.compile_file(
|
||||
str(module),
|
||||
input_type=input_type,
|
||||
target_backends=[iree_target_map(device)],
|
||||
extra_args=args,
|
||||
)
|
||||
@@ -316,8 +341,12 @@ def compile_module_to_flatbuffer(
|
||||
return flatbuffer_blob
|
||||
|
||||
|
||||
def get_iree_module(flatbuffer_blob, device, device_idx=None):
|
||||
def get_iree_module(
|
||||
flatbuffer_blob, device, device_idx=None, rt_flags: list = []
|
||||
):
|
||||
# Returns the compiled module and the configs.
|
||||
for flag in rt_flags:
|
||||
ireert.flags.parse_flag(flag)
|
||||
if device_idx is not None:
|
||||
device = iree_device_map(device)
|
||||
print("registering device id: ", device_idx)
|
||||
@@ -339,10 +368,24 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
|
||||
|
||||
|
||||
def load_vmfb_using_mmap(
|
||||
flatbuffer_blob_or_path, device: str, device_idx: int = None
|
||||
flatbuffer_blob_or_path,
|
||||
device: str,
|
||||
device_idx: int = None,
|
||||
rt_flags: list = [],
|
||||
):
|
||||
print(f"Loading module {flatbuffer_blob_or_path}...")
|
||||
if "task" in device:
|
||||
print(
|
||||
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
|
||||
)
|
||||
for flag in get_iree_cpu_rt_args():
|
||||
rt_flags.append(flag)
|
||||
for flag in rt_flags:
|
||||
print(flag)
|
||||
ireert.flags.parse_flags(flag)
|
||||
|
||||
if "rocm" in device:
|
||||
device = "rocm"
|
||||
with DetailLogger(timeout=2.5) as dl:
|
||||
# First get configs.
|
||||
if device_idx is not None:
|
||||
@@ -357,6 +400,9 @@ def load_vmfb_using_mmap(
|
||||
)
|
||||
dl.log(f"ireert.create_device()")
|
||||
config = ireert.Config(device=haldevice)
|
||||
config.id = haldriver.query_available_devices()[device_idx][
|
||||
"device_id"
|
||||
]
|
||||
dl.log(f"ireert.Config()")
|
||||
else:
|
||||
config = get_iree_runtime_config(device)
|
||||
@@ -367,6 +413,7 @@ def load_vmfb_using_mmap(
|
||||
)
|
||||
for flag in get_iree_cpu_rt_args():
|
||||
ireert.flags.parse_flags(flag)
|
||||
|
||||
# Now load vmfb.
|
||||
# Two scenarios we have here :-
|
||||
# 1. We either have the vmfb already saved and therefore pass the path of it.
|
||||
@@ -386,7 +433,14 @@ def load_vmfb_using_mmap(
|
||||
)
|
||||
dl.log(f"mmap {flatbuffer_blob_or_path}")
|
||||
ctx = ireert.SystemContext(config=config)
|
||||
for flag in shark_args.additional_runtime_args:
|
||||
ireert.flags.parse_flags(flag)
|
||||
dl.log(f"ireert.SystemContext created")
|
||||
if "vulkan" in device:
|
||||
# Vulkan pipeline creation consumes significant amount of time.
|
||||
print(
|
||||
"\tCompiling Vulkan shaders. This may take a few minutes."
|
||||
)
|
||||
ctx.add_vm_module(mmaped_vmfb)
|
||||
dl.log(f"module initialized")
|
||||
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
|
||||
@@ -407,12 +461,21 @@ def get_iree_compiled_module(
|
||||
frontend: str = "torch",
|
||||
model_config_path: str = None,
|
||||
extra_args: list = [],
|
||||
rt_flags: list = [],
|
||||
device_idx: int = None,
|
||||
mmap: bool = False,
|
||||
debug: bool = False,
|
||||
compile_str: bool = False,
|
||||
):
|
||||
"""Given a module returns the compiled .vmfb and configs"""
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module, device, frontend, model_config_path, extra_args
|
||||
module=module,
|
||||
device=device,
|
||||
frontend=frontend,
|
||||
model_config_path=model_config_path,
|
||||
extra_args=extra_args,
|
||||
debug=debug,
|
||||
compile_str=compile_str,
|
||||
)
|
||||
temp_file_to_unlink = None
|
||||
# TODO: Currently mmap=True control flow path has been switched off for mmap.
|
||||
@@ -421,11 +484,14 @@ def get_iree_compiled_module(
|
||||
# I'm getting hold of the name of the temporary file in `temp_file_to_unlink`.
|
||||
if mmap:
|
||||
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
|
||||
flatbuffer_blob, device, device_idx
|
||||
flatbuffer_blob, device, device_idx, rt_flags
|
||||
)
|
||||
else:
|
||||
vmfb, config = get_iree_module(
|
||||
flatbuffer_blob, device, device_idx=device_idx
|
||||
flatbuffer_blob,
|
||||
device,
|
||||
device_idx=device_idx,
|
||||
rt_flags=rt_flags,
|
||||
)
|
||||
ret_params = {
|
||||
"vmfb": vmfb,
|
||||
@@ -440,17 +506,21 @@ def load_flatbuffer(
|
||||
device: str,
|
||||
device_idx: int = None,
|
||||
mmap: bool = False,
|
||||
rt_flags: list = [],
|
||||
):
|
||||
temp_file_to_unlink = None
|
||||
if mmap:
|
||||
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
|
||||
flatbuffer_path, device, device_idx
|
||||
flatbuffer_path, device, device_idx, rt_flags
|
||||
)
|
||||
else:
|
||||
with open(os.path.join(flatbuffer_path), "rb") as f:
|
||||
flatbuffer_blob = f.read()
|
||||
vmfb, config = get_iree_module(
|
||||
flatbuffer_blob, device, device_idx=device_idx
|
||||
flatbuffer_blob,
|
||||
device,
|
||||
device_idx=device_idx,
|
||||
rt_flags=rt_flags,
|
||||
)
|
||||
ret_params = {
|
||||
"vmfb": vmfb,
|
||||
@@ -468,10 +538,18 @@ def export_iree_module_to_vmfb(
|
||||
model_config_path: str = None,
|
||||
module_name: str = None,
|
||||
extra_args: list = [],
|
||||
debug: bool = False,
|
||||
compile_str: bool = False,
|
||||
):
|
||||
# Compiles the module given specs and saves it as .vmfb file.
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module, device, mlir_dialect, model_config_path, extra_args
|
||||
module=module,
|
||||
device=device,
|
||||
frontend=mlir_dialect,
|
||||
model_config_path=model_config_path,
|
||||
extra_args=extra_args,
|
||||
debug=debug,
|
||||
compile_str=compile_str,
|
||||
)
|
||||
if module_name is None:
|
||||
device_name = (
|
||||
@@ -479,9 +557,9 @@ def export_iree_module_to_vmfb(
|
||||
)
|
||||
module_name = f"{mlir_dialect}_{device_name}"
|
||||
filename = os.path.join(directory, module_name + ".vmfb")
|
||||
print(f"Saved vmfb in {filename}.")
|
||||
with open(filename, "wb") as f:
|
||||
f.write(flatbuffer_blob)
|
||||
print(f"Saved vmfb in {filename}.")
|
||||
return filename
|
||||
|
||||
|
||||
@@ -507,10 +585,17 @@ def get_results(
|
||||
frontend="torch",
|
||||
send_to_host=True,
|
||||
debug_timeout: float = 5.0,
|
||||
device: str = None,
|
||||
):
|
||||
"""Runs a .vmfb file given inputs and config and returns output."""
|
||||
with DetailLogger(debug_timeout) as dl:
|
||||
device_inputs = []
|
||||
if device == "rocm" and hasattr(config, "id"):
|
||||
haldriver = ireert.get_driver("rocm")
|
||||
haldevice = haldriver.create_device(
|
||||
config.id,
|
||||
allocators=shark_args.device_allocator,
|
||||
)
|
||||
for input_array in input:
|
||||
dl.log(f"Load to device: {input_array.shape}")
|
||||
device_inputs.append(
|
||||
@@ -547,9 +632,17 @@ def get_results(
|
||||
def get_iree_runtime_config(device):
|
||||
device = iree_device_map(device)
|
||||
haldriver = ireert.get_driver(device)
|
||||
if "metal" in device and shark_args.device_allocator == "caching":
|
||||
print(
|
||||
"[WARNING] metal devices can not have a `caching` allocator."
|
||||
"\nUsing default allocator `None`"
|
||||
)
|
||||
haldevice = haldriver.create_device_by_uri(
|
||||
device,
|
||||
allocators=shark_args.device_allocator,
|
||||
# metal devices have a failure with caching allocators atm. blcking this util it gets fixed upstream.
|
||||
allocators=shark_args.device_allocator
|
||||
if "metal" not in device
|
||||
else None,
|
||||
)
|
||||
config = ireert.Config(device=haldevice)
|
||||
return config
|
||||
|
||||
@@ -17,7 +17,12 @@
|
||||
import functools
|
||||
import iree.runtime as ireert
|
||||
import ctypes
|
||||
import sys
|
||||
from subprocess import CalledProcessError
|
||||
from shark.parser import shark_args
|
||||
from shark.iree_utils._common import run_cmd
|
||||
|
||||
# TODO: refactor to rocm and cuda utils
|
||||
|
||||
|
||||
# Get the default gpu args given the architecture.
|
||||
@@ -38,26 +43,85 @@ def get_iree_gpu_args():
|
||||
return []
|
||||
|
||||
|
||||
# Get the default gpu args given the architecture.
|
||||
@functools.cache
|
||||
def get_iree_rocm_args():
|
||||
ireert.flags.FUNCTION_INPUT_VALIDATION = False
|
||||
# get arch from rocminfo.
|
||||
import re
|
||||
import subprocess
|
||||
def check_rocm_device_arch_in_args(extra_args):
|
||||
# Check if the target arch flag for rocm device present in extra_args
|
||||
for flag in extra_args:
|
||||
if "iree-rocm-target-chip" in flag:
|
||||
flag_arch = flag.split("=")[1]
|
||||
return flag_arch
|
||||
return None
|
||||
|
||||
rocm_arch = re.match(
|
||||
r".*(gfx\w+)",
|
||||
subprocess.check_output(
|
||||
"rocminfo | grep -i 'gfx'", shell=True, text=True
|
||||
),
|
||||
).group(1)
|
||||
print(f"Found rocm arch {rocm_arch}...")
|
||||
return [
|
||||
f"--iree-rocm-target-chip={rocm_arch}",
|
||||
"--iree-rocm-link-bc=true",
|
||||
"--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
|
||||
]
|
||||
|
||||
def get_rocm_device_arch(device_num=0, extra_args=[]):
|
||||
# ROCM Device Arch selection:
|
||||
# 1 : User given device arch using `--iree-rocm-target-chip` flag
|
||||
# 2 : Device arch from `iree-run-module --dump_devices=rocm` for device on index <device_num>
|
||||
# 3 : default arch : gfx1100
|
||||
|
||||
arch_in_flag = check_rocm_device_arch_in_args(extra_args)
|
||||
if arch_in_flag is not None:
|
||||
print(
|
||||
f"User Specified rocm target device arch from flag : {arch_in_flag} will be used"
|
||||
)
|
||||
return arch_in_flag
|
||||
|
||||
arch_in_device_dump = None
|
||||
|
||||
# get rocm arch from iree dump devices
|
||||
def get_devices_info_from_dump(dump):
|
||||
from os import linesep
|
||||
|
||||
dump_clean = list(
|
||||
filter(
|
||||
lambda s: "--device=rocm" in s or "gpu-arch-name:" in s,
|
||||
dump.split(linesep),
|
||||
)
|
||||
)
|
||||
arch_pairs = [
|
||||
(
|
||||
dump_clean[i].split("=")[1].strip(),
|
||||
dump_clean[i + 1].split(":")[1].strip(),
|
||||
)
|
||||
for i in range(0, len(dump_clean), 2)
|
||||
]
|
||||
return arch_pairs
|
||||
|
||||
dump_device_info = None
|
||||
try:
|
||||
dump_device_info = run_cmd(
|
||||
"iree-run-module --dump_devices=rocm", raise_err=True
|
||||
)
|
||||
except Exception as e:
|
||||
print("could not execute `iree-run-module --dump_devices=rocm`")
|
||||
|
||||
if dump_device_info is not None:
|
||||
device_arch_pairs = get_devices_info_from_dump(dump_device_info[0])
|
||||
if len(device_arch_pairs) > device_num: # can find arch in the list
|
||||
arch_in_device_dump = device_arch_pairs[device_num][1]
|
||||
|
||||
if arch_in_device_dump is not None:
|
||||
print(f"Found ROCm device arch : {arch_in_device_dump}")
|
||||
return arch_in_device_dump
|
||||
|
||||
default_rocm_arch = "gfx_1100"
|
||||
print(
|
||||
"Did not find ROCm architecture from `--iree-rocm-target-chip` flag"
|
||||
"\n or from `iree-run-module --dump_devices=rocm` command."
|
||||
f"\nUsing {default_rocm_arch} as ROCm arch for compilation."
|
||||
)
|
||||
return default_rocm_arch
|
||||
|
||||
|
||||
# Get the default gpu args given the architecture.
|
||||
def get_iree_rocm_args(device_num=0, extra_args=[]):
|
||||
ireert.flags.FUNCTION_INPUT_VALIDATION = False
|
||||
rocm_flags = ["--iree-rocm-link-bc=true"]
|
||||
|
||||
if check_rocm_device_arch_in_args(extra_args) is None:
|
||||
rocm_arch = get_rocm_device_arch(device_num, extra_args)
|
||||
rocm_flags.append(f"--iree-rocm-target-chip={rocm_arch}")
|
||||
|
||||
return rocm_flags
|
||||
|
||||
|
||||
# Some constants taken from cuda.h
|
||||
|
||||
@@ -89,24 +89,10 @@ def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
|
||||
|
||||
|
||||
def get_iree_metal_args(device_num=0, extra_args=[]):
|
||||
# res_metal_flag = ["--iree-flow-demote-i64-to-i32"]
|
||||
|
||||
# Add any metal spefic compilation flags here
|
||||
res_metal_flag = []
|
||||
metal_triple_flag = None
|
||||
for arg in extra_args:
|
||||
if "-iree-metal-target-platform=" in arg:
|
||||
print(f"Using target triple {arg} from command line args")
|
||||
metal_triple_flag = arg
|
||||
break
|
||||
|
||||
if metal_triple_flag is None:
|
||||
metal_triple_flag = get_metal_triple_flag(extra_args=extra_args)
|
||||
|
||||
if metal_triple_flag is not None:
|
||||
vulkan_target_env = get_vulkan_target_env_flag(
|
||||
"-iree-vulkan-target-triple=m1-moltenvk-macos"
|
||||
)
|
||||
res_metal_flag.append(vulkan_target_env)
|
||||
if len(extra_args) > 0:
|
||||
res_metal_flag.extend(extra_args)
|
||||
return res_metal_flag
|
||||
|
||||
|
||||
|
||||
@@ -57,11 +57,8 @@ def get_version(triple):
|
||||
@functools.cache
|
||||
def get_extensions(triple):
|
||||
def make_ext_list(ext_list):
|
||||
res = ""
|
||||
for e in ext_list:
|
||||
res += e + ", "
|
||||
res = f"[{res[:-2]}]"
|
||||
return res
|
||||
res = ", ".join(ext_list)
|
||||
return f"[{res}]"
|
||||
|
||||
arch, product, os = triple
|
||||
if arch == "m1":
|
||||
@@ -119,7 +116,7 @@ def get_extensions(triple):
|
||||
]
|
||||
|
||||
if get_vendor(triple) == "NVIDIA" or arch == "rdna3":
|
||||
ext.append("VK_NV_cooperative_matrix")
|
||||
ext.append("VK_KHR_cooperative_matrix")
|
||||
if get_vendor(triple) == ["NVIDIA", "AMD", "Intel"]:
|
||||
ext.append("VK_KHR_shader_integer_dot_product")
|
||||
return make_ext_list(ext_list=ext)
|
||||
@@ -247,7 +244,7 @@ def get_vulkan_target_capabilities(triple):
|
||||
if arch == "rdna3":
|
||||
# TODO: Get scope value
|
||||
cap["coopmatCases"] = [
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>"
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope<Subgroup>"
|
||||
]
|
||||
|
||||
if product == "rx5700xt":
|
||||
@@ -468,9 +465,9 @@ def get_vulkan_target_capabilities(triple):
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
cap["coopmatCases"] = [
|
||||
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, accSat = false, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, accSat = false, scope = #vk.scope<Subgroup>",
|
||||
]
|
||||
|
||||
elif arch == "adreno":
|
||||
@@ -531,7 +528,7 @@ def get_vulkan_target_capabilities(triple):
|
||||
cmc = ""
|
||||
for case in v:
|
||||
cmc += f"#vk.coop_matrix_props<{case}>, "
|
||||
res += f"cooperativeMatrixPropertiesNV = [{cmc[:-2]}], "
|
||||
res += f"cooperativeMatrixPropertiesKHR = [{cmc[:-2]}], "
|
||||
else:
|
||||
res += f"{k} = {get_comma_sep_str(v)}, "
|
||||
else:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user