Compare commits
333 Commits
ryan/flux
...
ryan/flux-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0ff5355ce3 | ||
|
|
3ad5fc060d | ||
|
|
17d5c85454 | ||
|
|
4698649cc9 | ||
|
|
d41c075768 | ||
|
|
6b129aaba6 | ||
|
|
f58546fd53 | ||
|
|
de3edf47fb | ||
|
|
6dc4baa925 | ||
|
|
943fa6da4b | ||
|
|
bfe31838cc | ||
|
|
16b76f7e7f | ||
|
|
cb115743e7 | ||
|
|
8f4279ba51 | ||
|
|
22a207b50d | ||
|
|
638c6003e3 | ||
|
|
8d35af946e | ||
|
|
24065ec6b6 | ||
|
|
627b0bf644 | ||
|
|
b43da46b82 | ||
|
|
4255a01c64 | ||
|
|
23adbd4002 | ||
|
|
fb5a24fcc6 | ||
|
|
cfdd5a1900 | ||
|
|
2313f326df | ||
|
|
2e092a2313 | ||
|
|
763ef06c18 | ||
|
|
8292f6cd42 | ||
|
|
278bba499e | ||
|
|
dd99ed28e0 | ||
|
|
9a8aca69bf | ||
|
|
7ad62512eb | ||
|
|
bd466661ec | ||
|
|
7ebb509d05 | ||
|
|
0aa13c046c | ||
|
|
a7a33d73f5 | ||
|
|
ffa39857d3 | ||
|
|
e85c3bc465 | ||
|
|
8185ba7054 | ||
|
|
d501865bec | ||
|
|
d62310bb5f | ||
|
|
1835bff196 | ||
|
|
87261bdbc9 | ||
|
|
4e4b6c6dbc | ||
|
|
5e8cf9fb6a | ||
|
|
c738fe051f | ||
|
|
29fe1533f2 | ||
|
|
77090070bd | ||
|
|
6ba9b1b6b0 | ||
|
|
c578b8df1e | ||
|
|
cad9a41433 | ||
|
|
5fefb3b0f4 | ||
|
|
5284a870b0 | ||
|
|
e064377c05 | ||
|
|
3e569c8312 | ||
|
|
16825ee6e9 | ||
|
|
3f5340fa53 | ||
|
|
f2a1a39b33 | ||
|
|
326de55d3e | ||
|
|
b2df909570 | ||
|
|
026ac36b06 | ||
|
|
92125e5fd2 | ||
|
|
c0c139da88 | ||
|
|
404ad6a7fd | ||
|
|
fc39086fb4 | ||
|
|
cd215700fe | ||
|
|
e97fd85904 | ||
|
|
0a263fa5b1 | ||
|
|
fae3836a8d | ||
|
|
b3d2eb4178 | ||
|
|
576f1cbb75 | ||
|
|
50085b40bb | ||
|
|
cff382715a | ||
|
|
54d54d1bf2 | ||
|
|
e84ea68282 | ||
|
|
160dd36782 | ||
|
|
65bb46bcca | ||
|
|
2d185fb766 | ||
|
|
2ba9b02932 | ||
|
|
849da67cc7 | ||
|
|
3ea6c9666e | ||
|
|
cf633e4ef2 | ||
|
|
bbf934d980 | ||
|
|
620f733110 | ||
|
|
67928609a3 | ||
|
|
5f15afb7db | ||
|
|
635d2f480d | ||
|
|
70c278c810 | ||
|
|
56b9906e2e | ||
|
|
a808ce81fd | ||
|
|
83f82c5ddf | ||
|
|
101de8c25d | ||
|
|
3339a4baf0 | ||
|
|
dff4a88baa | ||
|
|
a21f6c4964 | ||
|
|
97562504b7 | ||
|
|
75d8ac378c | ||
|
|
b9dd354e2b | ||
|
|
33c2fbd201 | ||
|
|
5063be92bf | ||
|
|
1047584b3e | ||
|
|
6764dcfdaa | ||
|
|
012864ceb1 | ||
|
|
a0bf20bcee | ||
|
|
14ab339b33 | ||
|
|
25c91efbb6 | ||
|
|
1c1f2c6664 | ||
|
|
d7c22b3bf7 | ||
|
|
185f2a395f | ||
|
|
0c5649491e | ||
|
|
94aba5892a | ||
|
|
ef093dde29 | ||
|
|
34451e5f27 | ||
|
|
1f9bdd1a9a | ||
|
|
c27d59baf7 | ||
|
|
f130ddec7c | ||
|
|
a0a259eef1 | ||
|
|
b66f19d4d1 | ||
|
|
4105a78b83 | ||
|
|
19a68afb3a | ||
|
|
fd68a2475b | ||
|
|
28ff7ba830 | ||
|
|
5d0b248fdb | ||
|
|
01a4e0f6ef | ||
|
|
91e0731506 | ||
|
|
d1f904d41f | ||
|
|
269388c9f4 | ||
|
|
b8486379ce | ||
|
|
400eb94d3b | ||
|
|
e210c96485 | ||
|
|
5f567f41f4 | ||
|
|
5fed573a29 | ||
|
|
cfac7c8189 | ||
|
|
1787de6836 | ||
|
|
ac96f187bd | ||
|
|
72398350b4 | ||
|
|
df9445c351 | ||
|
|
87b7a2e39b | ||
|
|
f7e46622a1 | ||
|
|
71f18353a9 | ||
|
|
4228de707b | ||
|
|
b6a05629ef | ||
|
|
fbaa820643 | ||
|
|
db2a2d5e38 | ||
|
|
8ba6e6b1f8 | ||
|
|
57168d719b | ||
|
|
dee6d2c98e | ||
|
|
e49105ece5 | ||
|
|
0c5e11f521 | ||
|
|
a63f842a13 | ||
|
|
4bd7fda694 | ||
|
|
81f0886d6f | ||
|
|
2eb87f3306 | ||
|
|
723f3ab0a9 | ||
|
|
1bd90e0fd4 | ||
|
|
436f18ff55 | ||
|
|
cde9696214 | ||
|
|
2d9042fb93 | ||
|
|
9ed53af520 | ||
|
|
56fda669fd | ||
|
|
1d8545a76c | ||
|
|
5f59a828f9 | ||
|
|
1fa6bddc89 | ||
|
|
d3a5ca5247 | ||
|
|
f01f56a98e | ||
|
|
99b0f79784 | ||
|
|
e1eb104345 | ||
|
|
5c2f95ef50 | ||
|
|
b63df9bab9 | ||
|
|
a52c899c6d | ||
|
|
eeabb7ebe5 | ||
|
|
8b1cef978c | ||
|
|
152da482cd | ||
|
|
3cf0365a35 | ||
|
|
5870742bb9 | ||
|
|
01d8c62c57 | ||
|
|
55a242b2d6 | ||
|
|
45263b339f | ||
|
|
3319491861 | ||
|
|
e687afac90 | ||
|
|
b39031ea53 | ||
|
|
0b77511271 | ||
|
|
c99cd989c1 | ||
|
|
317fdadb21 | ||
|
|
4e294f9e3e | ||
|
|
526e0f30a0 | ||
|
|
231e5ec94a | ||
|
|
e5bb6f9693 | ||
|
|
da7dee44c6 | ||
|
|
83144f4fe3 | ||
|
|
c451f52ea3 | ||
|
|
8a2c78f2e1 | ||
|
|
bcc78bde9b | ||
|
|
054bb6fe0a | ||
|
|
4f4aa6d92e | ||
|
|
eac51ac6f5 | ||
|
|
9f349a7c0a | ||
|
|
918afa5b15 | ||
|
|
eb1113f95c | ||
|
|
4f4ba7b462 | ||
|
|
2298be0e6b | ||
|
|
63494dfca7 | ||
|
|
36a1d39454 | ||
|
|
a6f6d5c400 | ||
|
|
e85f221aca | ||
|
|
d4797e37dc | ||
|
|
3e7923d072 | ||
|
|
a85d69ce3d | ||
|
|
96db006c99 | ||
|
|
8ca57d03d8 | ||
|
|
6c404ce5f8 | ||
|
|
584e07182b | ||
|
|
f787e9acf6 | ||
|
|
5a24b89e54 | ||
|
|
9b482e2a4f | ||
|
|
df4dbe2d57 | ||
|
|
713bd11177 | ||
|
|
182571df4b | ||
|
|
29bfe492b6 | ||
|
|
3fb4e3050c | ||
|
|
39c7ec3cd9 | ||
|
|
26bfbdec7f | ||
|
|
7a3eaa8da9 | ||
|
|
599db7296f | ||
|
|
042aab4295 | ||
|
|
24f298283f | ||
|
|
68dac6349d | ||
|
|
b675fc19e8 | ||
|
|
659019cfd6 | ||
|
|
dcd61e1f82 | ||
|
|
f5c99b1488 | ||
|
|
810be3e1d4 | ||
|
|
60d754d1df | ||
|
|
bd07c86db9 | ||
|
|
bcbf8b6bd8 | ||
|
|
356661459b | ||
|
|
deb917825e | ||
|
|
15415c6d85 | ||
|
|
76b0380b5f | ||
|
|
2d58754789 | ||
|
|
9cdf1f599c | ||
|
|
268be97ba0 | ||
|
|
a9014673a0 | ||
|
|
d36c43a10f | ||
|
|
54a5c4e482 | ||
|
|
5e09a244e3 | ||
|
|
88648dca1a | ||
|
|
8840df2b00 | ||
|
|
af159acbdf | ||
|
|
471719bbbe | ||
|
|
b126f2ffd5 | ||
|
|
9938f12ef0 | ||
|
|
982c266073 | ||
|
|
5c37391883 | ||
|
|
ddeafc6833 | ||
|
|
41b2d5d013 | ||
|
|
29d6f48901 | ||
|
|
d5c9f4e47f | ||
|
|
24d73387d8 | ||
|
|
e0d3927265 | ||
|
|
e5f7c2a9b7 | ||
|
|
b0760710d5 | ||
|
|
764accc921 | ||
|
|
6a01fce9c1 | ||
|
|
9c732ac3b1 | ||
|
|
b70891c661 | ||
|
|
4dbf851741 | ||
|
|
6c927a9fd4 | ||
|
|
096f001634 | ||
|
|
4837e578b2 | ||
|
|
1e547ef912 | ||
|
|
f6b8970bd1 | ||
|
|
29325a7214 | ||
|
|
8ecf72838d | ||
|
|
c3ab8a6aa8 | ||
|
|
1931aa3e70 | ||
|
|
d3d8055055 | ||
|
|
476b0a0403 | ||
|
|
f66584713c | ||
|
|
33624fc2fa | ||
|
|
41c3e73a3c | ||
|
|
97553a7de2 | ||
|
|
12ba15bfa9 | ||
|
|
09d1e190e7 | ||
|
|
8eb5d08499 | ||
|
|
9be6acde7d | ||
|
|
5f83bb0069 | ||
|
|
b138882abc | ||
|
|
0cd7cdb52e | ||
|
|
1d8b7e2bcf | ||
|
|
6461f4758d | ||
|
|
3189ab6863 | ||
|
|
3f9a674d4b | ||
|
|
587f59b25b | ||
|
|
4952eada87 | ||
|
|
581029ebaa | ||
|
|
42d68780de | ||
|
|
28032a2f80 | ||
|
|
e381e021e9 | ||
|
|
641af64f93 | ||
|
|
a7b83c8b5b | ||
|
|
4cc41e0188 | ||
|
|
442fc02429 | ||
|
|
9a4d075074 | ||
|
|
17ff8196cb | ||
|
|
68f993998a | ||
|
|
7da6120b39 | ||
|
|
6cd40965c4 | ||
|
|
408a1d6dbb | ||
|
|
0b0abfbe8f | ||
|
|
cc96dcf0ed | ||
|
|
2604fd9fde | ||
|
|
857d74bbfe | ||
|
|
fd7a635777 | ||
|
|
af9110e964 | ||
|
|
a61209206b | ||
|
|
e05cc62e5f | ||
|
|
4f8a4b0f22 | ||
|
|
a743f3c9b5 | ||
|
|
217fe40d99 | ||
|
|
b76bf50b93 | ||
|
|
332bc9da5b | ||
|
|
08def3da95 | ||
|
|
daf899f9c4 | ||
|
|
13fb2d1f49 | ||
|
|
95dde802ea | ||
|
|
b4cf78a95d | ||
|
|
18f89ed5ed | ||
|
|
f170697ebe | ||
|
|
556c6a1d84 | ||
|
|
e5d9ca013e | ||
|
|
4166c756ce | ||
|
|
4f0dfbd34d |
2
.github/workflows/python-checks.yml
vendored
@@ -62,7 +62,7 @@ jobs:
|
||||
|
||||
- name: install ruff
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: pip install ruff
|
||||
run: pip install ruff==0.6.0
|
||||
shell: bash
|
||||
|
||||
- name: ruff check
|
||||
|
||||
2
.github/workflows/python-tests.yml
vendored
@@ -60,7 +60,7 @@ jobs:
|
||||
extra-index-url: 'https://download.pytorch.org/whl/cpu'
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: macos-default
|
||||
os: macOS-12
|
||||
os: macOS-14
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: windows-cpu
|
||||
os: windows-2022
|
||||
|
||||
@@ -1,20 +1,22 @@
|
||||
# Invoke in Docker
|
||||
|
||||
- Ensure that Docker can use the GPU on your system
|
||||
- This documentation assumes Linux, but should work similarly under Windows with WSL2
|
||||
First things first:
|
||||
|
||||
- Ensure that Docker can use your [NVIDIA][nvidia docker docs] or [AMD][amd docker docs] GPU.
|
||||
- This document assumes a Linux system, but should work similarly under Windows with WSL2.
|
||||
- We don't recommend running Invoke in Docker on macOS at this time. It works, but very slowly.
|
||||
|
||||
## Quickstart :lightning:
|
||||
## Quickstart
|
||||
|
||||
No `docker compose`, no persistence, just a simple one-liner using the official images:
|
||||
No `docker compose`, no persistence, single command, using the official images:
|
||||
|
||||
**CUDA:**
|
||||
**CUDA (NVIDIA GPU):**
|
||||
|
||||
```bash
|
||||
docker run --runtime=nvidia --gpus=all --publish 9090:9090 ghcr.io/invoke-ai/invokeai
|
||||
```
|
||||
|
||||
**ROCm:**
|
||||
**ROCm (AMD GPU):**
|
||||
|
||||
```bash
|
||||
docker run --device /dev/kfd --device /dev/dri --publish 9090:9090 ghcr.io/invoke-ai/invokeai:main-rocm
|
||||
@@ -22,12 +24,20 @@ docker run --device /dev/kfd --device /dev/dri --publish 9090:9090 ghcr.io/invok
|
||||
|
||||
Open `http://localhost:9090` in your browser once the container finishes booting, install some models, and generate away!
|
||||
|
||||
> [!TIP]
|
||||
> To persist your data (including downloaded models) outside of the container, add a `--volume/-v` flag to the above command, e.g.: `docker run --volume /some/local/path:/invokeai <...the rest of the command>`
|
||||
### Data persistence
|
||||
|
||||
To persist your generated images and downloaded models outside of the container, add a `--volume/-v` flag to the above command, e.g.:
|
||||
|
||||
```bash
|
||||
docker run --volume /some/local/path:/invokeai {...etc...}
|
||||
```
|
||||
|
||||
`/some/local/path/invokeai` will contain all your data.
|
||||
It can *usually* be reused between different installs of Invoke. Tread with caution and read the release notes!
|
||||
|
||||
## Customize the container
|
||||
|
||||
We ship the `run.sh` script, which is a convenient wrapper around `docker compose` for cases where custom image build args are needed. Alternatively, the familiar `docker compose` commands work just as well.
|
||||
The included `run.sh` script is a convenience wrapper around `docker compose`. It can be helpful for passing additional build arguments to `docker compose`. Alternatively, the familiar `docker compose` commands work just as well.
|
||||
|
||||
```bash
|
||||
cd docker
|
||||
@@ -38,11 +48,14 @@ cp .env.sample .env
|
||||
|
||||
It will take a few minutes to build the image the first time. Once the application starts up, open `http://localhost:9090` in your browser to invoke!
|
||||
|
||||
>[!TIP]
|
||||
>When using the `run.sh` script, the container will continue running after Ctrl+C. To shut it down, use the `docker compose down` command.
|
||||
|
||||
## Docker setup in detail
|
||||
|
||||
#### Linux
|
||||
|
||||
1. Ensure builkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
|
||||
1. Ensure buildkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
|
||||
2. Install the `docker compose` plugin using your package manager, or follow a [tutorial](https://docs.docker.com/compose/install/linux/#install-using-the-repository).
|
||||
- The deprecated `docker-compose` (hyphenated) CLI probably won't work. Update to a recent version.
|
||||
3. Ensure docker daemon is able to access the GPU.
|
||||
@@ -98,25 +111,7 @@ GPU_DRIVER=cuda
|
||||
|
||||
Any environment variables supported by InvokeAI can be set here. See the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.
|
||||
|
||||
## Even More Customizing!
|
||||
---
|
||||
|
||||
See the `docker-compose.yml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below.
|
||||
|
||||
### Reconfigure the runtime directory
|
||||
|
||||
Can be used to download additional models from the supported model list
|
||||
|
||||
In conjunction with `INVOKEAI_ROOT` can be also used to initialize a runtime directory
|
||||
|
||||
```yaml
|
||||
command:
|
||||
- invokeai-configure
|
||||
- --yes
|
||||
```
|
||||
|
||||
Or install models:
|
||||
|
||||
```yaml
|
||||
command:
|
||||
- invokeai-model-install
|
||||
```
|
||||
[nvidia docker docs]: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
|
||||
[amd docker docs]: https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/docker.html
|
||||
|
||||
@@ -196,6 +196,22 @@ tips to reduce the problem:
|
||||
=== "12GB VRAM GPU"
|
||||
|
||||
This should be sufficient to generate larger images up to about 1280x1280.
|
||||
|
||||
## Checkpoint Models Load Slowly or Use Too Much RAM
|
||||
|
||||
The difference between diffusers models (a folder containing multiple
|
||||
subfolders) and checkpoint models (a file ending with .safetensors or
|
||||
.ckpt) is that InvokeAI is able to load diffusers models into memory
|
||||
incrementally, while checkpoint models must be loaded all at
|
||||
once. With very large models, or systems with limited RAM, you may
|
||||
experience slowdowns and other memory-related issues when loading
|
||||
checkpoint models.
|
||||
|
||||
To solve this, go to the Model Manager tab (the cube), select the
|
||||
checkpoint model that's giving you trouble, and press the "Convert"
|
||||
button in the upper right of your browser window. This will conver the
|
||||
checkpoint into a diffusers model, after which loading should be
|
||||
faster and less memory-intensive.
|
||||
|
||||
## Memory Leak (Linux)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
set -eu
|
||||
|
||||
# Ensure we're in the correct folder in case user's CWD is somewhere else
|
||||
scriptdir=$(dirname "$0")
|
||||
scriptdir=$(dirname $(readlink -f "$0"))
|
||||
cd "$scriptdir"
|
||||
|
||||
. .venv/bin/activate
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import asyncio
|
||||
from logging import Logger
|
||||
|
||||
import torch
|
||||
@@ -31,6 +32,8 @@ from invokeai.app.services.session_processor.session_processor_default import (
|
||||
)
|
||||
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_disk import StylePresetImageFileStorageDisk
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
|
||||
from invokeai.app.services.urls.urls_default import LocalUrlService
|
||||
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
@@ -63,7 +66,12 @@ class ApiDependencies:
|
||||
invoker: Invoker
|
||||
|
||||
@staticmethod
|
||||
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger) -> None:
|
||||
def initialize(
|
||||
config: InvokeAIAppConfig,
|
||||
event_handler_id: int,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
logger: Logger = logger,
|
||||
) -> None:
|
||||
logger.info(f"InvokeAI version {__version__}")
|
||||
logger.info(f"Root directory = {str(config.root_path)}")
|
||||
|
||||
@@ -74,6 +82,7 @@ class ApiDependencies:
|
||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||
|
||||
model_images_folder = config.models_path
|
||||
style_presets_folder = config.style_presets_path
|
||||
|
||||
db = init_db(config=config, logger=logger, image_files=image_files)
|
||||
|
||||
@@ -84,7 +93,7 @@ class ApiDependencies:
|
||||
board_images = BoardImagesService()
|
||||
board_records = SqliteBoardRecordStorage(db=db)
|
||||
boards = BoardService()
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
events = FastAPIEventService(event_handler_id, loop=loop)
|
||||
bulk_download = BulkDownloadService()
|
||||
image_records = SqliteImageRecordStorage(db=db)
|
||||
images = ImageService()
|
||||
@@ -109,6 +118,8 @@ class ApiDependencies:
|
||||
session_queue = SqliteSessionQueue(db=db)
|
||||
urls = LocalUrlService()
|
||||
workflow_records = SqliteWorkflowRecordsStorage(db=db)
|
||||
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
|
||||
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
|
||||
|
||||
services = InvocationServices(
|
||||
board_image_records=board_image_records,
|
||||
@@ -134,6 +145,8 @@ class ApiDependencies:
|
||||
workflow_records=workflow_records,
|
||||
tensors=tensors,
|
||||
conditioning=conditioning,
|
||||
style_preset_records=style_preset_records,
|
||||
style_preset_image_files=style_preset_image_files,
|
||||
)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
||||
@@ -218,9 +218,8 @@ async def get_image_workflow(
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.api_route(
|
||||
@images_router.get(
|
||||
"/i/{image_name}/full",
|
||||
methods=["GET", "HEAD"],
|
||||
operation_id="get_image_full",
|
||||
response_class=Response,
|
||||
responses={
|
||||
@@ -231,6 +230,18 @@ async def get_image_workflow(
|
||||
404: {"description": "Image not found"},
|
||||
},
|
||||
)
|
||||
@images_router.head(
|
||||
"/i/{image_name}/full",
|
||||
operation_id="get_image_full_head",
|
||||
response_class=Response,
|
||||
responses={
|
||||
200: {
|
||||
"description": "Return the full-resolution image",
|
||||
"content": {"image/png": {}},
|
||||
},
|
||||
404: {"description": "Image not found"},
|
||||
},
|
||||
)
|
||||
async def get_image_full(
|
||||
image_name: str = Path(description="The name of full-resolution image file to get"),
|
||||
) -> Response:
|
||||
@@ -242,6 +253,7 @@ async def get_image_full(
|
||||
content = f.read()
|
||||
response = Response(content, media_type="image/png")
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
response.headers["Content-Disposition"] = f'inline; filename="{image_name}"'
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
@@ -3,8 +3,10 @@
|
||||
|
||||
import io
|
||||
import pathlib
|
||||
import shutil
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import List, Optional, Type
|
||||
|
||||
@@ -17,6 +19,7 @@ from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||
from invokeai.app.services.model_records import (
|
||||
@@ -31,6 +34,7 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats
|
||||
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
@@ -50,6 +54,13 @@ class ModelsList(BaseModel):
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
class CacheType(str, Enum):
|
||||
"""Cache type - one of vram or ram."""
|
||||
|
||||
RAM = "RAM"
|
||||
VRAM = "VRAM"
|
||||
|
||||
|
||||
def add_cover_image_to_model_config(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig:
|
||||
"""Add a cover image URL to a model configuration."""
|
||||
cover_image = dependencies.invoker.services.model_images.get_url(config.key)
|
||||
@@ -797,3 +808,83 @@ async def get_starter_models() -> list[StarterModel]:
|
||||
model.dependencies = missing_deps
|
||||
|
||||
return starter_models
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/model_cache",
|
||||
operation_id="get_cache_size",
|
||||
response_model=float,
|
||||
summary="Get maximum size of model manager RAM or VRAM cache.",
|
||||
)
|
||||
async def get_cache_size(cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM)) -> float:
|
||||
"""Return the current RAM or VRAM cache size setting (in GB)."""
|
||||
cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
|
||||
value = 0.0
|
||||
if cache_type == CacheType.RAM:
|
||||
value = cache.max_cache_size
|
||||
elif cache_type == CacheType.VRAM:
|
||||
value = cache.max_vram_cache_size
|
||||
return value
|
||||
|
||||
|
||||
@model_manager_router.put(
|
||||
"/model_cache",
|
||||
operation_id="set_cache_size",
|
||||
response_model=float,
|
||||
summary="Set maximum size of model manager RAM or VRAM cache, optionally writing new value out to invokeai.yaml config file.",
|
||||
)
|
||||
async def set_cache_size(
|
||||
value: float = Query(description="The new value for the maximum cache size"),
|
||||
cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM),
|
||||
persist: bool = Query(description="Write new value out to invokeai.yaml", default=False),
|
||||
) -> float:
|
||||
"""Set the current RAM or VRAM cache size setting (in GB). ."""
|
||||
cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
|
||||
app_config = get_config()
|
||||
# Record initial state.
|
||||
vram_old = app_config.vram
|
||||
ram_old = app_config.ram
|
||||
|
||||
# Prepare target state.
|
||||
vram_new = vram_old
|
||||
ram_new = ram_old
|
||||
if cache_type == CacheType.RAM:
|
||||
ram_new = value
|
||||
elif cache_type == CacheType.VRAM:
|
||||
vram_new = value
|
||||
else:
|
||||
raise ValueError(f"Unexpected {cache_type=}.")
|
||||
|
||||
config_path = app_config.config_file_path
|
||||
new_config_path = config_path.with_suffix(".yaml.new")
|
||||
|
||||
try:
|
||||
# Try to apply the target state.
|
||||
cache.max_vram_cache_size = vram_new
|
||||
cache.max_cache_size = ram_new
|
||||
app_config.ram = ram_new
|
||||
app_config.vram = vram_new
|
||||
if persist:
|
||||
app_config.write_file(new_config_path)
|
||||
shutil.move(new_config_path, config_path)
|
||||
except Exception as e:
|
||||
# If there was a failure, restore the initial state.
|
||||
cache.max_cache_size = ram_old
|
||||
cache.max_vram_cache_size = vram_old
|
||||
app_config.ram = ram_old
|
||||
app_config.vram = vram_old
|
||||
|
||||
raise RuntimeError("Failed to update cache size") from e
|
||||
return value
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/stats",
|
||||
operation_id="get_stats",
|
||||
response_model=Optional[CacheStats],
|
||||
summary="Get model manager RAM cache performance statistics.",
|
||||
)
|
||||
async def get_stats() -> Optional[CacheStats]:
|
||||
"""Return performance statistics on the model manager's RAM cache. Will return null if no models have been loaded."""
|
||||
|
||||
return ApiDependencies.invoker.services.model_manager.load.ram_cache.stats
|
||||
|
||||
274
invokeai/app/api/routers/style_presets.py
Normal file
@@ -0,0 +1,274 @@
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import traceback
|
||||
from typing import Optional
|
||||
|
||||
import pydantic
|
||||
from fastapi import APIRouter, File, Form, HTTPException, Path, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.api.routers.model_manager import IMAGE_MAX_AGE
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_common import StylePresetImageFileNotFoundException
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
||||
InvalidPresetImportDataError,
|
||||
PresetData,
|
||||
PresetType,
|
||||
StylePresetChanges,
|
||||
StylePresetNotFoundError,
|
||||
StylePresetRecordWithImage,
|
||||
StylePresetWithoutId,
|
||||
UnsupportedFileTypeError,
|
||||
parse_presets_from_file,
|
||||
)
|
||||
|
||||
|
||||
class StylePresetFormData(BaseModel):
|
||||
name: str = Field(description="Preset name")
|
||||
positive_prompt: str = Field(description="Positive prompt")
|
||||
negative_prompt: str = Field(description="Negative prompt")
|
||||
type: PresetType = Field(description="Preset type")
|
||||
|
||||
|
||||
style_presets_router = APIRouter(prefix="/v1/style_presets", tags=["style_presets"])
|
||||
|
||||
|
||||
@style_presets_router.get(
|
||||
"/i/{style_preset_id}",
|
||||
operation_id="get_style_preset",
|
||||
responses={
|
||||
200: {"model": StylePresetRecordWithImage},
|
||||
},
|
||||
)
|
||||
async def get_style_preset(
|
||||
style_preset_id: str = Path(description="The style preset to get"),
|
||||
) -> StylePresetRecordWithImage:
|
||||
"""Gets a style preset"""
|
||||
try:
|
||||
image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
|
||||
style_preset = ApiDependencies.invoker.services.style_preset_records.get(style_preset_id)
|
||||
return StylePresetRecordWithImage(image=image, **style_preset.model_dump())
|
||||
except StylePresetNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Style preset not found")
|
||||
|
||||
|
||||
@style_presets_router.patch(
|
||||
"/i/{style_preset_id}",
|
||||
operation_id="update_style_preset",
|
||||
responses={
|
||||
200: {"model": StylePresetRecordWithImage},
|
||||
},
|
||||
)
|
||||
async def update_style_preset(
|
||||
image: Optional[UploadFile] = File(description="The image file to upload", default=None),
|
||||
style_preset_id: str = Path(description="The id of the style preset to update"),
|
||||
data: str = Form(description="The data of the style preset to update"),
|
||||
) -> StylePresetRecordWithImage:
|
||||
"""Updates a style preset"""
|
||||
if image is not None:
|
||||
if not image.content_type or not image.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
contents = await image.read()
|
||||
try:
|
||||
pil_image = Image.open(io.BytesIO(contents))
|
||||
|
||||
except Exception:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.style_preset_image_files.save(style_preset_id, pil_image)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
else:
|
||||
try:
|
||||
ApiDependencies.invoker.services.style_preset_image_files.delete(style_preset_id)
|
||||
except StylePresetImageFileNotFoundException:
|
||||
pass
|
||||
|
||||
try:
|
||||
parsed_data = json.loads(data)
|
||||
validated_data = StylePresetFormData(**parsed_data)
|
||||
|
||||
name = validated_data.name
|
||||
type = validated_data.type
|
||||
positive_prompt = validated_data.positive_prompt
|
||||
negative_prompt = validated_data.negative_prompt
|
||||
|
||||
except pydantic.ValidationError:
|
||||
raise HTTPException(status_code=400, detail="Invalid preset data")
|
||||
|
||||
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
|
||||
changes = StylePresetChanges(name=name, preset_data=preset_data, type=type)
|
||||
|
||||
style_preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
|
||||
style_preset = ApiDependencies.invoker.services.style_preset_records.update(
|
||||
style_preset_id=style_preset_id, changes=changes
|
||||
)
|
||||
return StylePresetRecordWithImage(image=style_preset_image, **style_preset.model_dump())
|
||||
|
||||
|
||||
@style_presets_router.delete(
|
||||
"/i/{style_preset_id}",
|
||||
operation_id="delete_style_preset",
|
||||
)
|
||||
async def delete_style_preset(
|
||||
style_preset_id: str = Path(description="The style preset to delete"),
|
||||
) -> None:
|
||||
"""Deletes a style preset"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.style_preset_image_files.delete(style_preset_id)
|
||||
except StylePresetImageFileNotFoundException:
|
||||
pass
|
||||
|
||||
ApiDependencies.invoker.services.style_preset_records.delete(style_preset_id)
|
||||
|
||||
|
||||
@style_presets_router.post(
|
||||
"/",
|
||||
operation_id="create_style_preset",
|
||||
responses={
|
||||
200: {"model": StylePresetRecordWithImage},
|
||||
},
|
||||
)
|
||||
async def create_style_preset(
|
||||
image: Optional[UploadFile] = File(description="The image file to upload", default=None),
|
||||
data: str = Form(description="The data of the style preset to create"),
|
||||
) -> StylePresetRecordWithImage:
|
||||
"""Creates a style preset"""
|
||||
|
||||
try:
|
||||
parsed_data = json.loads(data)
|
||||
validated_data = StylePresetFormData(**parsed_data)
|
||||
|
||||
name = validated_data.name
|
||||
type = validated_data.type
|
||||
positive_prompt = validated_data.positive_prompt
|
||||
negative_prompt = validated_data.negative_prompt
|
||||
|
||||
except pydantic.ValidationError:
|
||||
raise HTTPException(status_code=400, detail="Invalid preset data")
|
||||
|
||||
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
|
||||
style_preset = StylePresetWithoutId(name=name, preset_data=preset_data, type=type)
|
||||
new_style_preset = ApiDependencies.invoker.services.style_preset_records.create(style_preset=style_preset)
|
||||
|
||||
if image is not None:
|
||||
if not image.content_type or not image.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
contents = await image.read()
|
||||
try:
|
||||
pil_image = Image.open(io.BytesIO(contents))
|
||||
|
||||
except Exception:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.style_preset_image_files.save(new_style_preset.id, pil_image)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(new_style_preset.id)
|
||||
return StylePresetRecordWithImage(image=preset_image, **new_style_preset.model_dump())
|
||||
|
||||
|
||||
@style_presets_router.get(
|
||||
"/",
|
||||
operation_id="list_style_presets",
|
||||
responses={
|
||||
200: {"model": list[StylePresetRecordWithImage]},
|
||||
},
|
||||
)
|
||||
async def list_style_presets() -> list[StylePresetRecordWithImage]:
|
||||
"""Gets a page of style presets"""
|
||||
style_presets_with_image: list[StylePresetRecordWithImage] = []
|
||||
style_presets = ApiDependencies.invoker.services.style_preset_records.get_many()
|
||||
for preset in style_presets:
|
||||
image = ApiDependencies.invoker.services.style_preset_image_files.get_url(preset.id)
|
||||
style_preset_with_image = StylePresetRecordWithImage(image=image, **preset.model_dump())
|
||||
style_presets_with_image.append(style_preset_with_image)
|
||||
|
||||
return style_presets_with_image
|
||||
|
||||
|
||||
@style_presets_router.get(
|
||||
"/i/{style_preset_id}/image",
|
||||
operation_id="get_style_preset_image",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The style preset image was fetched successfully",
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The style preset image could not be found"},
|
||||
},
|
||||
status_code=200,
|
||||
)
|
||||
async def get_style_preset_image(
|
||||
style_preset_id: str = Path(description="The id of the style preset image to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets an image file that previews the model"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.style_preset_image_files.get_path(style_preset_id)
|
||||
|
||||
response = FileResponse(
|
||||
path,
|
||||
media_type="image/png",
|
||||
filename=style_preset_id + ".png",
|
||||
content_disposition_type="inline",
|
||||
)
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@style_presets_router.get(
|
||||
"/export",
|
||||
operation_id="export_style_presets",
|
||||
responses={200: {"content": {"text/csv": {}}, "description": "A CSV file with the requested data."}},
|
||||
status_code=200,
|
||||
)
|
||||
async def export_style_presets():
|
||||
# Create an in-memory stream to store the CSV data
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
|
||||
# Write the header
|
||||
writer.writerow(["name", "prompt", "negative_prompt"])
|
||||
|
||||
style_presets = ApiDependencies.invoker.services.style_preset_records.get_many(type=PresetType.User)
|
||||
|
||||
for preset in style_presets:
|
||||
writer.writerow([preset.name, preset.preset_data.positive_prompt, preset.preset_data.negative_prompt])
|
||||
|
||||
csv_data = output.getvalue()
|
||||
output.close()
|
||||
|
||||
return Response(
|
||||
content=csv_data,
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": "attachment; filename=prompt_templates.csv"},
|
||||
)
|
||||
|
||||
|
||||
@style_presets_router.post(
|
||||
"/import",
|
||||
operation_id="import_style_presets",
|
||||
)
|
||||
async def import_style_presets(file: UploadFile = File(description="The file to import")):
|
||||
try:
|
||||
style_presets = await parse_presets_from_file(file)
|
||||
ApiDependencies.invoker.services.style_preset_records.create_many(style_presets)
|
||||
except InvalidPresetImportDataError as e:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except UnsupportedFileTypeError as e:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=415, detail=str(e))
|
||||
@@ -30,6 +30,7 @@ from invokeai.app.api.routers import (
|
||||
images,
|
||||
model_manager,
|
||||
session_queue,
|
||||
style_presets,
|
||||
utilities,
|
||||
workflows,
|
||||
)
|
||||
@@ -55,11 +56,13 @@ mimetypes.add_type("text/css", ".css")
|
||||
torch_device_name = TorchDevice.get_torch_device_name()
|
||||
logger.info(f"Using torch device: {torch_device_name}")
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Add startup event to load dependencies
|
||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, loop=loop, logger=logger)
|
||||
yield
|
||||
# Shut down threads
|
||||
ApiDependencies.shutdown()
|
||||
@@ -106,6 +109,7 @@ app.include_router(board_images.board_images_router, prefix="/api")
|
||||
app.include_router(app_info.app_router, prefix="/api")
|
||||
app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||
app.include_router(workflows.workflows_router, prefix="/api")
|
||||
app.include_router(style_presets.style_presets_router, prefix="/api")
|
||||
|
||||
app.openapi = get_openapi_func(app)
|
||||
|
||||
@@ -184,8 +188,6 @@ def invoke_api() -> None:
|
||||
|
||||
check_cudnn(logger)
|
||||
|
||||
# Start our own event loop for eventing usage
|
||||
loop = asyncio.new_event_loop()
|
||||
config = uvicorn.Config(
|
||||
app=app,
|
||||
host=app_config.host,
|
||||
|
||||
@@ -19,8 +19,8 @@ from invokeai.app.invocations.model import CLIPField
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.ti_utils import generate_ti_list
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.peft.lora import LoRAModelRaw
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
ConditioningFieldData,
|
||||
|
||||
@@ -21,6 +21,8 @@ from controlnet_aux import (
|
||||
from controlnet_aux.util import HWC3, ade_palette
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from transformers import pipeline
|
||||
from transformers.pipelines import DepthEstimationPipeline
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
@@ -44,13 +46,12 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
|
||||
from invokeai.backend.image_util.canny import get_canny_edges
|
||||
from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
|
||||
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
|
||||
from invokeai.backend.image_util.hed import HEDProcessor
|
||||
from invokeai.backend.image_util.lineart import LineartProcessor
|
||||
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
@@ -592,7 +593,14 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
||||
return color_map
|
||||
|
||||
|
||||
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
|
||||
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small", "small_v2"]
|
||||
# DepthAnything V2 Small model is licensed under Apache 2.0 but not the base and large models.
|
||||
DEPTH_ANYTHING_MODELS = {
|
||||
"large": "LiheYoung/depth-anything-large-hf",
|
||||
"base": "LiheYoung/depth-anything-base-hf",
|
||||
"small": "LiheYoung/depth-anything-small-hf",
|
||||
"small_v2": "depth-anything/Depth-Anything-V2-Small-hf",
|
||||
}
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -600,28 +608,33 @@ DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
|
||||
title="Depth Anything Processor",
|
||||
tags=["controlnet", "depth", "depth anything"],
|
||||
category="controlnet",
|
||||
version="1.1.2",
|
||||
version="1.1.3",
|
||||
)
|
||||
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates a depth map based on the Depth Anything algorithm"""
|
||||
|
||||
model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
|
||||
default="small", description="The size of the depth model to use"
|
||||
default="small_v2", description="The size of the depth model to use"
|
||||
)
|
||||
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
def loader(model_path: Path):
|
||||
return DepthAnythingDetector.load_model(
|
||||
model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
|
||||
)
|
||||
def load_depth_anything(model_path: Path):
|
||||
depth_anything_pipeline = pipeline(model=str(model_path), task="depth-estimation", local_files_only=True)
|
||||
assert isinstance(depth_anything_pipeline, DepthEstimationPipeline)
|
||||
return DepthAnythingPipeline(depth_anything_pipeline)
|
||||
|
||||
with self._context.models.load_remote_model(
|
||||
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
|
||||
) as model:
|
||||
depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
|
||||
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
|
||||
return processed_image
|
||||
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
|
||||
) as depth_anything_detector:
|
||||
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
|
||||
depth_map = depth_anything_detector.generate_depth(image)
|
||||
|
||||
# Resizing to user target specified size
|
||||
new_height = int(image.size[1] * (self.resolution / image.size[0]))
|
||||
depth_map = depth_map.resize((self.resolution, new_height))
|
||||
|
||||
return depth_map
|
||||
|
||||
|
||||
@invocation(
|
||||
|
||||
@@ -36,9 +36,9 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.peft.lora import LoRAModelRaw
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
||||
@@ -185,7 +185,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.mask,
|
||||
description=FieldDescriptions.denoise_mask,
|
||||
input=Input.Connection,
|
||||
ui_order=8,
|
||||
)
|
||||
|
||||
@@ -40,14 +40,18 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
|
||||
# region Model Field Types
|
||||
MainModel = "MainModelField"
|
||||
FluxMainModel = "FluxMainModelField"
|
||||
SDXLMainModel = "SDXLMainModelField"
|
||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||
ONNXModel = "ONNXModelField"
|
||||
VAEModel = "VAEModelField"
|
||||
FluxVAEModel = "FluxVAEModelField"
|
||||
LoRAModel = "LoRAModelField"
|
||||
ControlNetModel = "ControlNetModelField"
|
||||
IPAdapterModel = "IPAdapterModelField"
|
||||
T2IAdapterModel = "T2IAdapterModelField"
|
||||
T5EncoderModel = "T5EncoderModelField"
|
||||
CLIPEmbedModel = "CLIPEmbedModelField"
|
||||
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
|
||||
# endregion
|
||||
|
||||
@@ -125,13 +129,17 @@ class FieldDescriptions:
|
||||
negative_cond = "Negative conditioning tensor"
|
||||
noise = "Noise tensor"
|
||||
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
||||
t5_encoder = "T5 tokenizer and text encoder"
|
||||
clip_embed_model = "CLIP Embed loader"
|
||||
unet = "UNet (scheduler, LoRAs)"
|
||||
transformer = "Transformer"
|
||||
vae = "VAE"
|
||||
cond = "Conditioning tensor"
|
||||
controlnet_model = "ControlNet model to load"
|
||||
vae_model = "VAE model to load"
|
||||
lora_model = "LoRA model to load"
|
||||
main_model = "Main model (UNet, VAE, CLIP) to load"
|
||||
flux_model = "Flux model (Transformer) to load"
|
||||
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||
@@ -173,7 +181,7 @@ class FieldDescriptions:
|
||||
)
|
||||
num_1 = "The first number"
|
||||
num_2 = "The second number"
|
||||
mask = "The mask to use for the operation"
|
||||
denoise_mask = "A mask of the region to apply the denoising process to."
|
||||
board = "The board to save the image to"
|
||||
image = "The image to process"
|
||||
tile_size = "Tile size"
|
||||
@@ -231,6 +239,12 @@ class ColorField(BaseModel):
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
|
||||
class FluxConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
|
||||
292
invokeai/app/invocations/flux_denoise.py
Normal file
@@ -0,0 +1,292 @@
|
||||
from typing import Callable, Iterator, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv_transforms
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import TransformerField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.denoise import denoise
|
||||
from invokeai.backend.flux.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.sampling_utils import (
|
||||
clip_timestep_schedule,
|
||||
generate_img_ids,
|
||||
get_noise,
|
||||
get_schedule,
|
||||
pack,
|
||||
unpack,
|
||||
)
|
||||
from invokeai.backend.peft.lora import LoRAModelRaw
|
||||
from invokeai.backend.peft.peft_patcher import PeftPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_denoise",
|
||||
title="FLUX Denoise",
|
||||
tags=["image", "flux"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Run denoising process with a FLUX transformer model."""
|
||||
|
||||
# If latents is provided, this means we are doing image-to-image.
|
||||
latents: Optional[LatentsField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
# denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.denoise_mask,
|
||||
input=Input.Connection,
|
||||
)
|
||||
denoising_start: float = InputField(
|
||||
default=0.0,
|
||||
ge=0,
|
||||
le=1,
|
||||
description=FieldDescriptions.denoising_start,
|
||||
)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
transformer: TransformerField = InputField(
|
||||
description=FieldDescriptions.flux_model,
|
||||
input=Input.Connection,
|
||||
title="Transformer",
|
||||
)
|
||||
positive_text_conditioning: FluxConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
|
||||
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
|
||||
num_steps: int = InputField(
|
||||
default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
|
||||
)
|
||||
guidance: float = InputField(
|
||||
default=4.0,
|
||||
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
|
||||
)
|
||||
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = self._run_diffusion(context)
|
||||
latents = latents.detach().to("cpu")
|
||||
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
def _run_diffusion(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
):
|
||||
inference_dtype = torch.bfloat16
|
||||
|
||||
# Load the conditioning data.
|
||||
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
|
||||
assert len(cond_data.conditionings) == 1
|
||||
flux_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||
flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
|
||||
t5_embeddings = flux_conditioning.t5_embeds
|
||||
clip_embeddings = flux_conditioning.clip_embeds
|
||||
|
||||
# Load the input latents, if provided.
|
||||
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
|
||||
if init_latents is not None:
|
||||
init_latents = init_latents.to(device=TorchDevice.choose_torch_device(), dtype=inference_dtype)
|
||||
|
||||
# Prepare input noise.
|
||||
noise = get_noise(
|
||||
num_samples=1,
|
||||
height=self.height,
|
||||
width=self.width,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
dtype=inference_dtype,
|
||||
seed=self.seed,
|
||||
)
|
||||
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
is_schnell = "schnell" in transformer_info.config.config_path
|
||||
|
||||
# Calculate the timestep schedule.
|
||||
image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
|
||||
timesteps = get_schedule(
|
||||
num_steps=self.num_steps,
|
||||
image_seq_len=image_seq_len,
|
||||
shift=not is_schnell,
|
||||
)
|
||||
|
||||
# Clip the timesteps schedule based on denoising_start and denoising_end.
|
||||
timesteps = clip_timestep_schedule(timesteps, self.denoising_start, self.denoising_end)
|
||||
|
||||
# Prepare input latent image.
|
||||
if init_latents is not None:
|
||||
# If init_latents is provided, we are doing image-to-image.
|
||||
|
||||
if is_schnell:
|
||||
context.logger.warning(
|
||||
"Running image-to-image with a FLUX schnell model. This is not recommended. The results are likely "
|
||||
"to be poor. Consider using a FLUX dev model instead."
|
||||
)
|
||||
|
||||
# Noise the orig_latents by the appropriate amount for the first timestep.
|
||||
t_0 = timesteps[0]
|
||||
x = t_0 * noise + (1.0 - t_0) * init_latents
|
||||
else:
|
||||
# init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
|
||||
if self.denoising_start > 1e-5:
|
||||
raise ValueError("denoising_start should be 0 when initial latents are not provided.")
|
||||
|
||||
x = noise
|
||||
|
||||
# If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
|
||||
# denoising steps.
|
||||
if len(timesteps) <= 1:
|
||||
return x
|
||||
|
||||
inpaint_mask = self._prep_inpaint_mask(context, x)
|
||||
|
||||
b, _c, h, w = x.shape
|
||||
img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)
|
||||
|
||||
bs, t5_seq_len, _ = t5_embeddings.shape
|
||||
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
|
||||
|
||||
# Pack all latent tensors.
|
||||
init_latents = pack(init_latents) if init_latents is not None else None
|
||||
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
|
||||
noise = pack(noise)
|
||||
x = pack(x)
|
||||
|
||||
# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly.
|
||||
assert image_seq_len == x.shape[1]
|
||||
|
||||
# Prepare inpaint extension.
|
||||
inpaint_extension: InpaintExtension | None = None
|
||||
if inpaint_mask is not None:
|
||||
assert init_latents is not None
|
||||
inpaint_extension = InpaintExtension(
|
||||
init_latents=init_latents,
|
||||
inpaint_mask=inpaint_mask,
|
||||
noise=noise,
|
||||
)
|
||||
|
||||
with (
|
||||
transformer_info.model_on_device() as (cached_weights, transformer),
|
||||
# Apply the LoRA after transformer has been moved to its target device for faster patching.
|
||||
PeftPatcher.apply_peft_patches(
|
||||
model=transformer,
|
||||
patches=self._lora_iterator(context),
|
||||
prefix="",
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
):
|
||||
assert isinstance(transformer, Flux)
|
||||
|
||||
x = denoise(
|
||||
model=transformer,
|
||||
img=x,
|
||||
img_ids=img_ids,
|
||||
txt=t5_embeddings,
|
||||
txt_ids=txt_ids,
|
||||
vec=clip_embeddings,
|
||||
timesteps=timesteps,
|
||||
step_callback=self._build_step_callback(context),
|
||||
guidance=self.guidance,
|
||||
inpaint_extension=inpaint_extension,
|
||||
)
|
||||
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
return x
|
||||
|
||||
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
|
||||
"""Prepare the inpaint mask.
|
||||
|
||||
- Loads the mask
|
||||
- Resizes if necessary
|
||||
- Casts to same device/dtype as latents
|
||||
- Expands mask to the same shape as latents so that they line up after 'packing'
|
||||
|
||||
Args:
|
||||
context (InvocationContext): The invocation context, for loading the inpaint mask.
|
||||
latents (torch.Tensor): A latent image tensor. In 'unpacked' format. Used to determine the target shape,
|
||||
device, and dtype for the inpaint mask.
|
||||
|
||||
Returns:
|
||||
torch.Tensor | None: Inpaint mask.
|
||||
"""
|
||||
if self.denoise_mask is None:
|
||||
return None
|
||||
|
||||
mask = context.tensors.load(self.denoise_mask.mask_name)
|
||||
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
mask = tv_resize(
|
||||
img=mask,
|
||||
size=[latent_height, latent_width],
|
||||
interpolation=tv_transforms.InterpolationMode.BILINEAR,
|
||||
antialias=False,
|
||||
)
|
||||
|
||||
mask = mask.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
# Expand the inpaint mask to the same shape as `latents` so that when we 'pack' `mask` it lines up with
|
||||
# `latents`.
|
||||
return mask.expand_as(latents)
|
||||
|
||||
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.transformer.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
def _build_step_callback(self, context: InvocationContext) -> Callable[[], None]:
|
||||
def step_callback() -> None:
|
||||
if context.util.is_canceled():
|
||||
raise CanceledException
|
||||
|
||||
# TODO: Make this look like the image before re-enabling
|
||||
# latent_image = unpack(img.float(), self.height, self.width)
|
||||
# latent_image = latent_image.squeeze() # Remove unnecessary dimensions
|
||||
# flattened_tensor = latent_image.reshape(-1) # Flatten to shape [48*128*128]
|
||||
|
||||
# # Create a new tensor of the required shape [255, 255, 3]
|
||||
# latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3) # Reshape to RGB format
|
||||
|
||||
# # Convert to a NumPy array and then to a PIL Image
|
||||
# image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))
|
||||
|
||||
# (width, height) = image.size
|
||||
# width *= 8
|
||||
# height *= 8
|
||||
|
||||
# dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
# # TODO: move this whole function to invocation context to properly reference these variables
|
||||
# context._services.events.emit_invocation_denoise_progress(
|
||||
# context._data.queue_item,
|
||||
# context._data.invocation,
|
||||
# state,
|
||||
# ProgressImage(dataURL=dataURL, width=width, height=height),
|
||||
# )
|
||||
|
||||
return step_callback
|
||||
53
invokeai/app/invocations/flux_lora_loader.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@invocation_output("flux_lora_loader_output")
|
||||
class FluxLoRALoaderOutput(BaseInvocationOutput):
|
||||
"""FLUX LoRA Loader Output"""
|
||||
|
||||
transformer: TransformerField = OutputField(
|
||||
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_lora_loader",
|
||||
title="FLUX LoRA",
|
||||
tags=["lora", "model", "flux"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxLoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply a LoRA model to a FLUX transformer."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
|
||||
)
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
transformer: TransformerField = InputField(
|
||||
description=FieldDescriptions.transformer,
|
||||
input=Input.Connection,
|
||||
title="FLUX Transformer",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.models.exists(lora_key):
|
||||
raise ValueError(f"Unknown lora: {lora_key}!")
|
||||
|
||||
if any(lora.lora.key == lora_key for lora in self.transformer.loras):
|
||||
raise Exception(f'LoRA "{lora_key}" already applied to transformer.')
|
||||
|
||||
transformer = self.transformer.model_copy(deep=True)
|
||||
transformer.loras.append(
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
return FluxLoRALoaderOutput(transformer=transformer)
|
||||
92
invokeai/app/invocations/flux_text_encoder.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
||||
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
||||
from invokeai.app.invocations.primitives import FluxConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_text_encoder",
|
||||
title="FLUX Text Encoding",
|
||||
tags=["prompt", "conditioning", "flux"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxTextEncoderInvocation(BaseInvocation):
|
||||
"""Encodes and preps a prompt for a flux image."""
|
||||
|
||||
clip: CLIPField = InputField(
|
||||
title="CLIP",
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
)
|
||||
t5_encoder: T5EncoderField = InputField(
|
||||
title="T5Encoder",
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
input=Input.Connection,
|
||||
)
|
||||
t5_max_seq_len: Literal[256, 512] = InputField(
|
||||
description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
|
||||
)
|
||||
prompt: str = InputField(description="Text prompt to encode.")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
|
||||
# Note: The T5 and CLIP encoding are done in separate functions to ensure that all model references are locally
|
||||
# scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
|
||||
t5_embeddings = self._t5_encode(context)
|
||||
clip_embeddings = self._clip_encode(context)
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
|
||||
)
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
return FluxConditioningOutput.build(conditioning_name)
|
||||
|
||||
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
|
||||
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
|
||||
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
|
||||
|
||||
prompt = [self.prompt]
|
||||
|
||||
with (
|
||||
t5_text_encoder_info as t5_text_encoder,
|
||||
t5_tokenizer_info as t5_tokenizer,
|
||||
):
|
||||
assert isinstance(t5_text_encoder, T5EncoderModel)
|
||||
assert isinstance(t5_tokenizer, T5Tokenizer)
|
||||
|
||||
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
|
||||
|
||||
prompt_embeds = t5_encoder(prompt)
|
||||
|
||||
assert isinstance(prompt_embeds, torch.Tensor)
|
||||
return prompt_embeds
|
||||
|
||||
def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
|
||||
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
|
||||
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||
|
||||
prompt = [self.prompt]
|
||||
|
||||
with (
|
||||
clip_text_encoder_info as clip_text_encoder,
|
||||
clip_tokenizer_info as clip_tokenizer,
|
||||
):
|
||||
assert isinstance(clip_text_encoder, CLIPTextModel)
|
||||
assert isinstance(clip_tokenizer, CLIPTokenizer)
|
||||
|
||||
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
|
||||
|
||||
pooled_prompt_embeds = clip_encoder(prompt)
|
||||
|
||||
assert isinstance(pooled_prompt_embeds, torch.Tensor)
|
||||
return pooled_prompt_embeds
|
||||
60
invokeai/app/invocations/flux_vae_decode.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_vae_decode",
|
||||
title="FLUX Latents to Image",
|
||||
tags=["latents", "image", "vae", "l2i", "flux"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
latents: LatentsField = InputField(
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype())
|
||||
img = vae.decode(latents)
|
||||
|
||||
img = img.clamp(-1, 1)
|
||||
img = rearrange(img[0], "c h w -> h w c") # noqa: F821
|
||||
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
|
||||
return img_pil
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
image = self._vae_decode(vae_info=vae_info, latents=latents)
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
image_dto = context.images.save(image=image)
|
||||
return ImageOutput.build(image_dto)
|
||||
67
invokeai/app/invocations/flux_vae_encode.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import einops
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
)
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.model_manager import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_vae_encode",
|
||||
title="FLUX Image to Latents",
|
||||
tags=["latents", "image", "vae", "i2l", "flux"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxVaeEncodeInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
|
||||
image: ImageField = InputField(
|
||||
description="The image to encode.",
|
||||
)
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
# TODO(ryand): Expose seed parameter at the invocation level.
|
||||
# TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
|
||||
# There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
|
||||
# should be used for VAE encode sampling.
|
||||
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
image_tensor = image_tensor.to(
|
||||
device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
|
||||
)
|
||||
latents = vae.encode(image_tensor, sample=True, generator=generator)
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
@@ -126,7 +126,7 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
|
||||
title="Tensor Mask to Image",
|
||||
tags=["mask"],
|
||||
category="mask",
|
||||
version="1.0.0",
|
||||
version="1.1.0",
|
||||
)
|
||||
class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Convert a mask tensor to an image."""
|
||||
@@ -135,6 +135,11 @@ class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
mask = context.tensors.load(self.mask.tensor_name)
|
||||
|
||||
# Squeeze the channel dimension if it exists.
|
||||
if mask.dim() == 3:
|
||||
mask = mask.squeeze(0)
|
||||
|
||||
# Ensure that the mask is binary.
|
||||
if mask.dtype != torch.bool:
|
||||
mask = mask > 0.5
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import copy
|
||||
from typing import List, Optional
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -13,7 +13,14 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.flux.util import max_seq_lengths
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
CheckpointConfigBase,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
|
||||
|
||||
class ModelIdentifierField(BaseModel):
|
||||
@@ -60,6 +67,16 @@ class CLIPField(BaseModel):
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
|
||||
|
||||
class TransformerField(BaseModel):
|
||||
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
|
||||
|
||||
class T5EncoderField(BaseModel):
|
||||
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
|
||||
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
|
||||
|
||||
|
||||
class VAEField(BaseModel):
|
||||
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
|
||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||
@@ -122,6 +139,78 @@ class ModelIdentifierInvocation(BaseInvocation):
|
||||
return ModelIdentifierOutput(model=self.model)
|
||||
|
||||
|
||||
@invocation_output("flux_model_loader_output")
|
||||
class FluxModelLoaderOutput(BaseInvocationOutput):
|
||||
"""Flux base model loader output"""
|
||||
|
||||
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
|
||||
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
max_seq_len: Literal[256, 512] = OutputField(
|
||||
description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
|
||||
title="Max Seq Length",
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_model_loader",
|
||||
title="Flux Main Model",
|
||||
tags=["model", "flux"],
|
||||
category="model",
|
||||
version="1.0.4",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a flux base model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.flux_model,
|
||||
ui_type=UIType.FluxMainModel,
|
||||
input=Input.Direct,
|
||||
)
|
||||
|
||||
t5_encoder_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
|
||||
)
|
||||
|
||||
clip_embed_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.clip_embed_model,
|
||||
ui_type=UIType.CLIPEmbedModel,
|
||||
input=Input.Direct,
|
||||
title="CLIP Embed",
|
||||
)
|
||||
|
||||
vae_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
|
||||
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
|
||||
if not context.models.exists(key):
|
||||
raise ValueError(f"Unknown model: {key}")
|
||||
|
||||
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
|
||||
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
|
||||
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
|
||||
transformer_config = context.models.get_config(transformer)
|
||||
assert isinstance(transformer_config, CheckpointConfigBase)
|
||||
|
||||
return FluxModelLoaderOutput(
|
||||
transformer=TransformerField(transformer=transformer, loras=[]),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
|
||||
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
|
||||
vae=VAEField(vae=vae),
|
||||
max_seq_len=max_seq_lengths[transformer_config.config_path],
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"main_model_loader",
|
||||
title="Main Model",
|
||||
|
||||
@@ -12,6 +12,7 @@ from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
@@ -414,6 +415,17 @@ class MaskOutput(BaseInvocationOutput):
|
||||
height: int = OutputField(description="The height of the mask in pixels.")
|
||||
|
||||
|
||||
@invocation_output("flux_conditioning_output")
|
||||
class FluxConditioningOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single conditioning tensor"""
|
||||
|
||||
conditioning: FluxConditioningField = OutputField(description=FieldDescriptions.cond)
|
||||
|
||||
@classmethod
|
||||
def build(cls, conditioning_name: str) -> "FluxConditioningOutput":
|
||||
return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
|
||||
|
||||
|
||||
@invocation_output("conditioning_output")
|
||||
class ConditioningOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single conditioning tensor"""
|
||||
|
||||
@@ -22,8 +22,8 @@ from invokeai.app.invocations.fields import (
|
||||
from invokeai.app.invocations.model import UNetField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.peft.lora import LoRAModelRaw
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
|
||||
MultiDiffusionPipeline,
|
||||
|
||||
@@ -91,6 +91,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
db_dir: Path to InvokeAI databases directory.
|
||||
outputs_dir: Path to directory for outputs.
|
||||
custom_nodes_dir: Path to directory for custom nodes.
|
||||
style_presets_dir: Path to directory for style presets.
|
||||
log_handlers: Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".
|
||||
log_format: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.<br>Valid values: `plain`, `color`, `syslog`, `legacy`
|
||||
log_level: Emit logging messages at this level or higher.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
|
||||
@@ -153,6 +154,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
|
||||
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
|
||||
custom_nodes_dir: Path = Field(default=Path("nodes"), description="Path to directory for custom nodes.")
|
||||
style_presets_dir: Path = Field(default=Path("style_presets"), description="Path to directory for style presets.")
|
||||
|
||||
# LOGGING
|
||||
log_handlers: list[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".')
|
||||
@@ -300,6 +302,11 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
"""Path to the models directory, resolved to an absolute path.."""
|
||||
return self._resolve(self.models_dir)
|
||||
|
||||
@property
|
||||
def style_presets_path(self) -> Path:
|
||||
"""Path to the style presets directory, resolved to an absolute path.."""
|
||||
return self._resolve(self.style_presets_dir)
|
||||
|
||||
@property
|
||||
def convert_cache_path(self) -> Path:
|
||||
"""Path to the converted cache models directory, resolved to an absolute path.."""
|
||||
|
||||
@@ -1,46 +1,44 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from queue import Empty, Queue
|
||||
|
||||
from fastapi_events.dispatcher import dispatch
|
||||
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.events.events_common import (
|
||||
EventBase,
|
||||
)
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
|
||||
|
||||
class FastAPIEventService(EventServiceBase):
|
||||
def __init__(self, event_handler_id: int) -> None:
|
||||
def __init__(self, event_handler_id: int, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self.event_handler_id = event_handler_id
|
||||
self._queue = Queue[EventBase | None]()
|
||||
self._queue = asyncio.Queue[EventBase | None]()
|
||||
self._stop_event = threading.Event()
|
||||
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
|
||||
self._loop = loop
|
||||
|
||||
# We need to store a reference to the task so it doesn't get GC'd
|
||||
# See: https://docs.python.org/3/library/asyncio-task.html#creating-tasks
|
||||
self._background_tasks: set[asyncio.Task[None]] = set()
|
||||
task = self._loop.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.remove)
|
||||
|
||||
super().__init__()
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
self._stop_event.set()
|
||||
self._queue.put(None)
|
||||
self._loop.call_soon_threadsafe(self._queue.put_nowait, None)
|
||||
|
||||
def dispatch(self, event: EventBase) -> None:
|
||||
self._queue.put(event)
|
||||
self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
|
||||
|
||||
async def _dispatch_from_queue(self, stop_event: threading.Event):
|
||||
"""Get events on from the queue and dispatch them, from the correct thread"""
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
event = self._queue.get(block=False)
|
||||
event = await self._queue.get()
|
||||
if not event: # Probably stopping
|
||||
continue
|
||||
# Leave the payloads as live pydantic models
|
||||
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
|
||||
|
||||
except Empty:
|
||||
await asyncio.sleep(0.1)
|
||||
pass
|
||||
|
||||
except asyncio.CancelledError as e:
|
||||
raise e # Raise a proper error
|
||||
|
||||
@@ -4,6 +4,8 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_base import StylePresetImageFileStorageBase
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
@@ -61,6 +63,8 @@ class InvocationServices:
|
||||
workflow_records: "WorkflowRecordsStorageBase",
|
||||
tensors: "ObjectSerializerBase[torch.Tensor]",
|
||||
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
|
||||
style_preset_records: "StylePresetRecordsStorageBase",
|
||||
style_preset_image_files: "StylePresetImageFileStorageBase",
|
||||
):
|
||||
self.board_images = board_images
|
||||
self.board_image_records = board_image_records
|
||||
@@ -85,3 +89,5 @@ class InvocationServices:
|
||||
self.workflow_records = workflow_records
|
||||
self.tensors = tensors
|
||||
self.conditioning = conditioning
|
||||
self.style_preset_records = style_preset_records
|
||||
self.style_preset_image_files = style_preset_image_files
|
||||
|
||||
@@ -103,7 +103,7 @@ class HFModelSource(StringLikeSource):
|
||||
if self.variant:
|
||||
base += f":{self.variant or ''}"
|
||||
if self.subfolder:
|
||||
base += f":{self.subfolder}"
|
||||
base += f"::{self.subfolder.as_posix()}"
|
||||
return base
|
||||
|
||||
|
||||
|
||||
@@ -783,8 +783,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
|
||||
if subfolder:
|
||||
top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/"
|
||||
path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/
|
||||
path_to_add = Path(f"{top}_{subfolder}")
|
||||
path_to_remove = top / subfolder # sdxl-turbo/vae/
|
||||
subfolder_rename = subfolder.name.replace("/", "_").replace("\\", "_")
|
||||
path_to_add = Path(f"{top}_{subfolder_rename}")
|
||||
else:
|
||||
path_to_remove = Path(".")
|
||||
path_to_add = Path(".")
|
||||
|
||||
@@ -77,6 +77,7 @@ class ModelRecordChanges(BaseModelExcludeNull):
|
||||
type: Optional[ModelType] = Field(description="Type of model", default=None)
|
||||
key: Optional[str] = Field(description="Database ID for this model", default=None)
|
||||
hash: Optional[str] = Field(description="hash of model file", default=None)
|
||||
format: Optional[str] = Field(description="format of model file", default=None)
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
|
||||
@@ -16,6 +16,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@@ -49,6 +50,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator.register_migration(build_migration_11(app_config=config, logger=logger))
|
||||
migrator.register_migration(build_migration_12(app_config=config))
|
||||
migrator.register_migration(build_migration_13())
|
||||
migrator.register_migration(build_migration_14())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration14Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._create_style_presets(cursor)
|
||||
|
||||
def _create_style_presets(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Create the table used to store style presets."""
|
||||
tables = [
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS style_presets (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
preset_data TEXT NOT NULL,
|
||||
type TEXT NOT NULL DEFAULT "user",
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
|
||||
);
|
||||
"""
|
||||
]
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
triggers = [
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS style_presets
|
||||
AFTER UPDATE
|
||||
ON style_presets FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE style_presets SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE id = old.id;
|
||||
END;
|
||||
"""
|
||||
]
|
||||
|
||||
# Add indexes for searchable fields
|
||||
indices = [
|
||||
"CREATE INDEX IF NOT EXISTS idx_style_presets_name ON style_presets(name);",
|
||||
]
|
||||
|
||||
for stmt in tables + indices + triggers:
|
||||
cursor.execute(stmt)
|
||||
|
||||
|
||||
def build_migration_14() -> Migration:
|
||||
"""
|
||||
Build the migration from database version 13 to 14..
|
||||
|
||||
This migration does the following:
|
||||
- Create the table used to store style presets.
|
||||
"""
|
||||
migration_14 = Migration(
|
||||
from_version=13,
|
||||
to_version=14,
|
||||
callback=Migration14Callback(),
|
||||
)
|
||||
|
||||
return migration_14
|
||||
|
After Width: | Height: | Size: 98 KiB |
|
After Width: | Height: | Size: 138 KiB |
|
After Width: | Height: | Size: 122 KiB |
|
After Width: | Height: | Size: 123 KiB |
|
After Width: | Height: | Size: 160 KiB |
|
After Width: | Height: | Size: 146 KiB |
|
After Width: | Height: | Size: 119 KiB |
|
After Width: | Height: | Size: 117 KiB |
|
After Width: | Height: | Size: 110 KiB |
|
After Width: | Height: | Size: 46 KiB |
|
After Width: | Height: | Size: 79 KiB |
|
After Width: | Height: | Size: 156 KiB |
|
After Width: | Height: | Size: 141 KiB |
|
After Width: | Height: | Size: 96 KiB |
|
After Width: | Height: | Size: 91 KiB |
|
After Width: | Height: | Size: 88 KiB |
|
After Width: | Height: | Size: 107 KiB |
|
After Width: | Height: | Size: 132 KiB |
@@ -0,0 +1,33 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
|
||||
class StylePresetImageFileStorageBase(ABC):
|
||||
"""Low-level service responsible for storing and retrieving image files."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, style_preset_id: str) -> PILImageType:
|
||||
"""Retrieves a style preset image as PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, style_preset_id: str) -> Path:
|
||||
"""Gets the internal path to a style preset image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_url(self, style_preset_id: str) -> str | None:
|
||||
"""Gets the URL to fetch a style preset image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, style_preset_id: str, image: PILImageType) -> None:
|
||||
"""Saves a style preset image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, style_preset_id: str) -> None:
|
||||
"""Deletes a style preset image."""
|
||||
pass
|
||||
@@ -0,0 +1,19 @@
|
||||
class StylePresetImageFileNotFoundException(Exception):
|
||||
"""Raised when an image file is not found in storage."""
|
||||
|
||||
def __init__(self, message: str = "Style preset image file not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class StylePresetImageFileSaveException(Exception):
|
||||
"""Raised when an image cannot be saved."""
|
||||
|
||||
def __init__(self, message: str = "Style preset image file not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class StylePresetImageFileDeleteException(Exception):
|
||||
"""Raised when an image cannot be deleted."""
|
||||
|
||||
def __init__(self, message: str = "Style preset image file not deleted"):
|
||||
super().__init__(message)
|
||||
@@ -0,0 +1,88 @@
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_base import StylePresetImageFileStorageBase
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_common import (
|
||||
StylePresetImageFileDeleteException,
|
||||
StylePresetImageFileNotFoundException,
|
||||
StylePresetImageFileSaveException,
|
||||
)
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_common import PresetType
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.app.util.thumbnails import make_thumbnail
|
||||
|
||||
|
||||
class StylePresetImageFileStorageDisk(StylePresetImageFileStorageBase):
|
||||
"""Stores images on disk"""
|
||||
|
||||
def __init__(self, style_preset_images_folder: Path):
|
||||
self._style_preset_images_folder = style_preset_images_folder
|
||||
self._validate_storage_folders()
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
|
||||
def get(self, style_preset_id: str) -> PILImageType:
|
||||
try:
|
||||
path = self.get_path(style_preset_id)
|
||||
|
||||
return Image.open(path)
|
||||
except FileNotFoundError as e:
|
||||
raise StylePresetImageFileNotFoundException from e
|
||||
|
||||
def save(self, style_preset_id: str, image: PILImageType) -> None:
|
||||
try:
|
||||
self._validate_storage_folders()
|
||||
image_path = self._style_preset_images_folder / (style_preset_id + ".webp")
|
||||
thumbnail = make_thumbnail(image, 256)
|
||||
thumbnail.save(image_path, format="webp")
|
||||
|
||||
except Exception as e:
|
||||
raise StylePresetImageFileSaveException from e
|
||||
|
||||
def get_path(self, style_preset_id: str) -> Path:
|
||||
style_preset = self._invoker.services.style_preset_records.get(style_preset_id)
|
||||
if style_preset.type is PresetType.Default:
|
||||
default_images_dir = Path(__file__).parent / Path("default_style_preset_images")
|
||||
path = default_images_dir / (style_preset.name + ".png")
|
||||
else:
|
||||
path = self._style_preset_images_folder / (style_preset_id + ".webp")
|
||||
|
||||
return path
|
||||
|
||||
def get_url(self, style_preset_id: str) -> str | None:
|
||||
path = self.get_path(style_preset_id)
|
||||
if not self._validate_path(path):
|
||||
return
|
||||
|
||||
url = self._invoker.services.urls.get_style_preset_image_url(style_preset_id)
|
||||
|
||||
# The image URL never changes, so we must add random query string to it to prevent caching
|
||||
url += f"?{uuid_string()}"
|
||||
|
||||
return url
|
||||
|
||||
def delete(self, style_preset_id: str) -> None:
|
||||
try:
|
||||
path = self.get_path(style_preset_id)
|
||||
|
||||
if not self._validate_path(path):
|
||||
raise StylePresetImageFileNotFoundException
|
||||
|
||||
path.unlink()
|
||||
|
||||
except StylePresetImageFileNotFoundException as e:
|
||||
raise StylePresetImageFileNotFoundException from e
|
||||
except Exception as e:
|
||||
raise StylePresetImageFileDeleteException from e
|
||||
|
||||
def _validate_path(self, path: Path) -> bool:
|
||||
"""Validates the path given for an image."""
|
||||
return path.exists()
|
||||
|
||||
def _validate_storage_folders(self) -> None:
|
||||
"""Checks if the required folders exist and create them if they don't"""
|
||||
self._style_preset_images_folder.mkdir(parents=True, exist_ok=True)
|
||||
@@ -0,0 +1,146 @@
|
||||
[
|
||||
{
|
||||
"name": "Photography (General)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt}. photography. f/2.8 macro photo, bokeh, photorealism",
|
||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Photography (Studio Lighting)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt}, photography. f/8 photo. centered subject, studio lighting.",
|
||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Photography (Landscape)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt}, landscape photograph, f/12, lifelike, highly detailed.",
|
||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Photography (Portrait)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt}. photography. portraiture. catch light in eyes. one flash. rembrandt lighting. Soft box. dark shadows. High contrast. 80mm lens. F2.8.",
|
||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Photography (Black and White)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} photography. natural light. 80mm lens. F1.4. strong contrast, hard light. dark contrast. blurred background. black and white",
|
||||
"negative_prompt": "painting, digital art. sketch, colour+"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Architectural Visualization",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt}. architectural photography, f/12, luxury, aesthetically pleasing form and function.",
|
||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Concept Art (Fantasy)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "concept artwork of a {prompt}. (digital painterly art style)++, mythological, (textured 2d dry media brushpack)++, glazed brushstrokes, otherworldly. painting+, illustration+",
|
||||
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Concept Art (Sci-Fi)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "(concept art)++, {prompt}, (sleek futurism)++, (textured 2d dry media)++, metallic highlights, digital painting style",
|
||||
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Concept Art (Character)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "(character concept art)++, stylized painterly digital painting of {prompt}, (painterly, impasto. Dry brush.)++",
|
||||
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Concept Art (Painterly)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} oil painting. high contrast. impasto. sfumato. chiaroscuro. Palette knife.",
|
||||
"negative_prompt": "photo. smooth. border. frame"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Environment Art",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} environment artwork, hyper-realistic digital painting style with cinematic composition, atmospheric, depth and detail, voluminous. textured dry brush 2d media",
|
||||
"negative_prompt": "photo, distorted, blurry, out of focus. sketch."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Interior Design (Visualization)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} interior design photo, gentle shadows, light mid-tones, dimension, mix of smooth and textured surfaces, focus on negative space and clean lines, focus",
|
||||
"negative_prompt": "photo, distorted. sketch."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Product Rendering",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} high quality product photography, 3d rendering with key lighting, shallow depth of field, simple plain background, studio lighting.",
|
||||
"negative_prompt": "blurry, sketch, messy, dirty. unfinished."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Sketch",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} black and white pencil drawing, off-center composition, cross-hatching for shadows, bold strokes, textured paper. sketch+++",
|
||||
"negative_prompt": "blurry, photo, painting, color. messy, dirty. unfinished. frame, borders."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Line Art",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} Line art. bold outline. simplistic. white background. 2d",
|
||||
"negative_prompt": "photo. digital art. greyscale. solid black. painting"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Anime",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} anime++, bold outline, cel-shaded coloring, shounen, seinen",
|
||||
"negative_prompt": "(photo)+++. greyscale. solid black. painting"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Illustration",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} illustration, bold linework, illustrative details, vector art style, flat coloring",
|
||||
"negative_prompt": "(photo)+++. greyscale. painting, black and white."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Vehicles",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "A weird futuristic normal auto, {prompt} elegant design, nice color, nice wheels",
|
||||
"negative_prompt": "sketch. digital art. greyscale. painting"
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,42 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
||||
PresetType,
|
||||
StylePresetChanges,
|
||||
StylePresetRecordDTO,
|
||||
StylePresetWithoutId,
|
||||
)
|
||||
|
||||
|
||||
class StylePresetRecordsStorageBase(ABC):
|
||||
"""Base class for style preset storage services."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
|
||||
"""Get style preset by id."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
|
||||
"""Creates a style preset."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
|
||||
"""Creates many style presets."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
||||
"""Updates a style preset."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, style_preset_id: str) -> None:
|
||||
"""Deletes a style preset."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
|
||||
"""Gets many workflows."""
|
||||
pass
|
||||
@@ -0,0 +1,139 @@
|
||||
import codecs
|
||||
import csv
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
import pydantic
|
||||
from fastapi import UploadFile
|
||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter
|
||||
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
|
||||
|
||||
class StylePresetNotFoundError(Exception):
|
||||
"""Raised when a style preset is not found"""
|
||||
|
||||
|
||||
class PresetData(BaseModel, extra="forbid"):
|
||||
positive_prompt: str = Field(description="Positive prompt")
|
||||
negative_prompt: str = Field(description="Negative prompt")
|
||||
|
||||
|
||||
PresetDataValidator = TypeAdapter(PresetData)
|
||||
|
||||
|
||||
class PresetType(str, Enum, metaclass=MetaEnum):
|
||||
User = "user"
|
||||
Default = "default"
|
||||
Project = "project"
|
||||
|
||||
|
||||
class StylePresetChanges(BaseModel, extra="forbid"):
|
||||
name: Optional[str] = Field(default=None, description="The style preset's new name.")
|
||||
preset_data: Optional[PresetData] = Field(default=None, description="The updated data for style preset.")
|
||||
type: Optional[PresetType] = Field(description="The updated type of the style preset")
|
||||
|
||||
|
||||
class StylePresetWithoutId(BaseModel):
|
||||
name: str = Field(description="The name of the style preset.")
|
||||
preset_data: PresetData = Field(description="The preset data")
|
||||
type: PresetType = Field(description="The type of style preset")
|
||||
|
||||
|
||||
class StylePresetRecordDTO(StylePresetWithoutId):
|
||||
id: str = Field(description="The style preset ID.")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "StylePresetRecordDTO":
|
||||
data["preset_data"] = PresetDataValidator.validate_json(data.get("preset_data", ""))
|
||||
return StylePresetRecordDTOValidator.validate_python(data)
|
||||
|
||||
|
||||
StylePresetRecordDTOValidator = TypeAdapter(StylePresetRecordDTO)
|
||||
|
||||
|
||||
class StylePresetRecordWithImage(StylePresetRecordDTO):
|
||||
image: Optional[str] = Field(description="The path for image")
|
||||
|
||||
|
||||
class StylePresetImportRow(BaseModel):
|
||||
name: str = Field(min_length=1, description="The name of the preset.")
|
||||
positive_prompt: str = Field(
|
||||
default="",
|
||||
description="The positive prompt for the preset.",
|
||||
validation_alias=AliasChoices("positive_prompt", "prompt"),
|
||||
)
|
||||
negative_prompt: str = Field(default="", description="The negative prompt for the preset.")
|
||||
|
||||
model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")
|
||||
|
||||
|
||||
StylePresetImportList = list[StylePresetImportRow]
|
||||
StylePresetImportListTypeAdapter = TypeAdapter(StylePresetImportList)
|
||||
|
||||
|
||||
class UnsupportedFileTypeError(ValueError):
|
||||
"""Raised when an unsupported file type is encountered"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidPresetImportDataError(ValueError):
|
||||
"""Raised when invalid preset import data is encountered"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
async def parse_presets_from_file(file: UploadFile) -> list[StylePresetWithoutId]:
|
||||
"""Parses style presets from a file. The file must be a CSV or JSON file.
|
||||
|
||||
If CSV, the file must have the following columns:
|
||||
- name
|
||||
- prompt (or positive_prompt)
|
||||
- negative_prompt
|
||||
|
||||
If JSON, the file must be a list of objects with the following keys:
|
||||
- name
|
||||
- prompt (or positive_prompt)
|
||||
- negative_prompt
|
||||
|
||||
Args:
|
||||
file (UploadFile): The file to parse.
|
||||
|
||||
Returns:
|
||||
list[StylePresetWithoutId]: The parsed style presets.
|
||||
|
||||
Raises:
|
||||
UnsupportedFileTypeError: If the file type is not supported.
|
||||
InvalidPresetImportDataError: If the data in the file is invalid.
|
||||
"""
|
||||
if file.content_type not in ["text/csv", "application/json"]:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
if file.content_type == "text/csv":
|
||||
csv_reader = csv.DictReader(codecs.iterdecode(file.file, "utf-8"))
|
||||
data = list(csv_reader)
|
||||
else: # file.content_type == "application/json":
|
||||
json_data = await file.read()
|
||||
data = json.loads(json_data)
|
||||
|
||||
try:
|
||||
imported_presets = StylePresetImportListTypeAdapter.validate_python(data)
|
||||
|
||||
style_presets: list[StylePresetWithoutId] = []
|
||||
|
||||
for imported in imported_presets:
|
||||
preset_data = PresetData(positive_prompt=imported.positive_prompt, negative_prompt=imported.negative_prompt)
|
||||
style_preset = StylePresetWithoutId(name=imported.name, preset_data=preset_data, type=PresetType.User)
|
||||
style_presets.append(style_preset)
|
||||
except pydantic.ValidationError as e:
|
||||
if file.content_type == "text/csv":
|
||||
msg = "Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt' and name cannot be blank"
|
||||
else: # file.content_type == "application/json":
|
||||
msg = "Invalid JSON format: must be a list of objects with keys 'name', 'prompt', and 'negative_prompt' and name cannot be blank"
|
||||
raise InvalidPresetImportDataError(msg) from e
|
||||
finally:
|
||||
file.file.close()
|
||||
|
||||
return style_presets
|
||||
@@ -0,0 +1,215 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
||||
PresetType,
|
||||
StylePresetChanges,
|
||||
StylePresetNotFoundError,
|
||||
StylePresetRecordDTO,
|
||||
StylePresetWithoutId,
|
||||
)
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
|
||||
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._lock = db.lock
|
||||
self._conn = db.conn
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
self._sync_default_style_presets()
|
||||
|
||||
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
|
||||
"""Gets a style preset by ID."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM style_presets
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(style_preset_id,),
|
||||
)
|
||||
row = self._cursor.fetchone()
|
||||
if row is None:
|
||||
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
|
||||
return StylePresetRecordDTO.from_dict(dict(row))
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
|
||||
style_preset_id = uuid_string()
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO style_presets (
|
||||
id,
|
||||
name,
|
||||
preset_data,
|
||||
type
|
||||
)
|
||||
VALUES (?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
style_preset_id,
|
||||
style_preset.name,
|
||||
style_preset.preset_data.model_dump_json(),
|
||||
style_preset.type,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(style_preset_id)
|
||||
|
||||
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
|
||||
style_preset_ids = []
|
||||
try:
|
||||
self._lock.acquire()
|
||||
for style_preset in style_presets:
|
||||
style_preset_id = uuid_string()
|
||||
style_preset_ids.append(style_preset_id)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO style_presets (
|
||||
id,
|
||||
name,
|
||||
preset_data,
|
||||
type
|
||||
)
|
||||
VALUES (?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
style_preset_id,
|
||||
style_preset.name,
|
||||
style_preset.preset_data.model_dump_json(),
|
||||
style_preset.type,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
return None
|
||||
|
||||
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Change the name of a style preset
|
||||
if changes.name is not None:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE style_presets
|
||||
SET name = ?
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(changes.name, style_preset_id),
|
||||
)
|
||||
|
||||
# Change the preset data for a style preset
|
||||
if changes.preset_data is not None:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE style_presets
|
||||
SET preset_data = ?
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(changes.preset_data.model_dump_json(), style_preset_id),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(style_preset_id)
|
||||
|
||||
def delete(self, style_preset_id: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE from style_presets
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(style_preset_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return None
|
||||
|
||||
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
main_query = """
|
||||
SELECT
|
||||
*
|
||||
FROM style_presets
|
||||
"""
|
||||
|
||||
if type is not None:
|
||||
main_query += "WHERE type = ? "
|
||||
|
||||
main_query += "ORDER BY LOWER(name) ASC"
|
||||
|
||||
if type is not None:
|
||||
self._cursor.execute(main_query, (type,))
|
||||
else:
|
||||
self._cursor.execute(main_query)
|
||||
|
||||
rows = self._cursor.fetchall()
|
||||
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
|
||||
|
||||
return style_presets
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _sync_default_style_presets(self) -> None:
|
||||
"""Syncs default style presets to the database. Internal use only."""
|
||||
|
||||
# First delete all existing default style presets
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM style_presets
|
||||
WHERE type = "default";
|
||||
"""
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
# Next, parse and create the default style presets
|
||||
with self._lock, open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
|
||||
presets = json.load(file)
|
||||
for preset in presets:
|
||||
style_preset = StylePresetWithoutId.model_validate(preset)
|
||||
self.create(style_preset)
|
||||
@@ -13,3 +13,8 @@ class UrlServiceBase(ABC):
|
||||
def get_model_image_url(self, model_key: str) -> str:
|
||||
"""Gets the URL for a model image"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_style_preset_image_url(self, style_preset_id: str) -> str:
|
||||
"""Gets the URL for a style preset image"""
|
||||
pass
|
||||
|
||||
@@ -19,3 +19,6 @@ class LocalUrlService(UrlServiceBase):
|
||||
|
||||
def get_model_image_url(self, model_key: str) -> str:
|
||||
return f"{self._base_url_v2}/models/i/{model_key}/image"
|
||||
|
||||
def get_style_preset_image_url(self, style_preset_id: str) -> str:
|
||||
return f"{self._base_url}/style_presets/i/{style_preset_id}/image"
|
||||
|
||||
@@ -0,0 +1,407 @@
|
||||
{
|
||||
"name": "FLUX Image to Image",
|
||||
"author": "InvokeAI",
|
||||
"description": "A simple image-to-image workflow using a FLUX dev model. ",
|
||||
"version": "1.0.4",
|
||||
"contact": "",
|
||||
"tags": "image2image, flux, image-to-image",
|
||||
"notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend using FLUX dev models for image-to-image workflows. The image-to-image performance with FLUX schnell models is poor.",
|
||||
"exposedFields": [
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "t5_encoder_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "clip_embed_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "vae_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"fieldName": "denoising_start"
|
||||
},
|
||||
{
|
||||
"nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"fieldName": "prompt"
|
||||
},
|
||||
{
|
||||
"nodeId": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"fieldName": "num_steps"
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
"version": "3.0.0",
|
||||
"category": "default"
|
||||
},
|
||||
"nodes": [
|
||||
{
|
||||
"id": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"type": "flux_vae_encode",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"image": {
|
||||
"name": "image",
|
||||
"label": "",
|
||||
"value": {
|
||||
"image_name": "8a5c62aa-9335-45d2-9c71-89af9fc1f8d4.png"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"name": "vae",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 732.7680166609682,
|
||||
"y": -24.37398171806909
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"type": "flux_denoise",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"latents": {
|
||||
"name": "latents",
|
||||
"label": ""
|
||||
},
|
||||
"denoise_mask": {
|
||||
"name": "denoise_mask",
|
||||
"label": ""
|
||||
},
|
||||
"denoising_start": {
|
||||
"name": "denoising_start",
|
||||
"label": "",
|
||||
"value": 0.04
|
||||
},
|
||||
"denoising_end": {
|
||||
"name": "denoising_end",
|
||||
"label": "",
|
||||
"value": 1
|
||||
},
|
||||
"transformer": {
|
||||
"name": "transformer",
|
||||
"label": ""
|
||||
},
|
||||
"positive_text_conditioning": {
|
||||
"name": "positive_text_conditioning",
|
||||
"label": ""
|
||||
},
|
||||
"width": {
|
||||
"name": "width",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"height": {
|
||||
"name": "height",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"num_steps": {
|
||||
"name": "num_steps",
|
||||
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
|
||||
"value": 30
|
||||
},
|
||||
"guidance": {
|
||||
"name": "guidance",
|
||||
"label": "",
|
||||
"value": 4
|
||||
},
|
||||
"seed": {
|
||||
"name": "seed",
|
||||
"label": "",
|
||||
"value": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1182.8836633018684,
|
||||
"y": -251.38882958913183
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"type": "flux_vae_decode",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": false,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"latents": {
|
||||
"name": "latents",
|
||||
"label": ""
|
||||
},
|
||||
"vae": {
|
||||
"name": "vae",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1575.5797431839133,
|
||||
"y": -209.00150975507415
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "flux_model_loader",
|
||||
"version": "1.0.4",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": false,
|
||||
"inputs": {
|
||||
"model": {
|
||||
"name": "model",
|
||||
"label": "Model (dev variant recommended for Image-to-Image)"
|
||||
},
|
||||
"t5_encoder_model": {
|
||||
"name": "t5_encoder_model",
|
||||
"label": ""
|
||||
},
|
||||
"clip_embed_model": {
|
||||
"name": "clip_embed_model",
|
||||
"label": "",
|
||||
"value": {
|
||||
"key": "fa23a584-b623-415d-832a-21b5098ff1a1",
|
||||
"hash": "blake3:17c19f0ef941c3b7609a9c94a659ca5364de0be364a91d4179f0e39ba17c3b70",
|
||||
"name": "clip-vit-large-patch14",
|
||||
"base": "any",
|
||||
"type": "clip_embed"
|
||||
}
|
||||
},
|
||||
"vae_model": {
|
||||
"name": "vae_model",
|
||||
"label": "",
|
||||
"value": {
|
||||
"key": "74fc82ba-c0a8-479d-a890-2126f82da758",
|
||||
"hash": "blake3:ce21cb76364aa6e2421311cf4a4b5eb052a76c4f1cd207b50703d8978198a068",
|
||||
"name": "FLUX.1-schnell_ae",
|
||||
"base": "flux",
|
||||
"type": "vae"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 328.1809894659957,
|
||||
"y": -90.2241133566946
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"type": "flux_text_encoder",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"clip": {
|
||||
"name": "clip",
|
||||
"label": ""
|
||||
},
|
||||
"t5_encoder": {
|
||||
"name": "t5_encoder",
|
||||
"label": ""
|
||||
},
|
||||
"t5_max_seq_len": {
|
||||
"name": "t5_max_seq_len",
|
||||
"label": "T5 Max Seq Len",
|
||||
"value": 256
|
||||
},
|
||||
"prompt": {
|
||||
"name": "prompt",
|
||||
"label": "",
|
||||
"value": "a cat wearing a birthday hat"
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 745.8823365057267,
|
||||
"y": -299.60249175851914
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"type": "rand_int",
|
||||
"version": "1.0.1",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": false,
|
||||
"inputs": {
|
||||
"low": {
|
||||
"name": "low",
|
||||
"label": "",
|
||||
"value": 0
|
||||
},
|
||||
"high": {
|
||||
"name": "high",
|
||||
"label": "",
|
||||
"value": 2147483647
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 725.834098928012,
|
||||
"y": 496.2710031089931
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912bheight-ace0258f-67d7-4eee-a218-6fff27065214height",
|
||||
"type": "default",
|
||||
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "height",
|
||||
"targetHandle": "height"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912bwidth-ace0258f-67d7-4eee-a218-6fff27065214width",
|
||||
"type": "default",
|
||||
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "width",
|
||||
"targetHandle": "width"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912blatents-ace0258f-67d7-4eee-a218-6fff27065214latents",
|
||||
"type": "default",
|
||||
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "latents",
|
||||
"targetHandle": "latents"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-2981a67c-480f-4237-9384-26b68dbf912bvae",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-ace0258f-67d7-4eee-a218-6fff27065214latents-7e5172eb-48c1-44db-a770-8fd83e1435d1latents",
|
||||
"type": "default",
|
||||
"source": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"sourceHandle": "latents",
|
||||
"targetHandle": "latents"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-ace0258f-67d7-4eee-a218-6fff27065214seed",
|
||||
"type": "default",
|
||||
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "value",
|
||||
"targetHandle": "seed"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-ace0258f-67d7-4eee-a218-6fff27065214transformer",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "transformer",
|
||||
"targetHandle": "transformer"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-ace0258f-67d7-4eee-a218-6fff27065214positive_text_conditioning",
|
||||
"type": "default",
|
||||
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "positive_text_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-7e5172eb-48c1-44db-a770-8fd83e1435d1vae",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "max_seq_len",
|
||||
"targetHandle": "t5_max_seq_len"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "t5_encoder",
|
||||
"targetHandle": "t5_encoder"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "clip",
|
||||
"targetHandle": "clip"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -0,0 +1,326 @@
|
||||
{
|
||||
"name": "FLUX Text to Image",
|
||||
"author": "InvokeAI",
|
||||
"description": "A simple text-to-image workflow using FLUX dev or schnell models.",
|
||||
"version": "1.0.4",
|
||||
"contact": "",
|
||||
"tags": "text2image, flux",
|
||||
"notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
|
||||
"exposedFields": [
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "t5_encoder_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "clip_embed_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "vae_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"fieldName": "prompt"
|
||||
},
|
||||
{
|
||||
"nodeId": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"fieldName": "num_steps"
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
"version": "3.0.0",
|
||||
"category": "default"
|
||||
},
|
||||
"nodes": [
|
||||
{
|
||||
"id": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"type": "flux_denoise",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"latents": {
|
||||
"name": "latents",
|
||||
"label": ""
|
||||
},
|
||||
"denoise_mask": {
|
||||
"name": "denoise_mask",
|
||||
"label": ""
|
||||
},
|
||||
"denoising_start": {
|
||||
"name": "denoising_start",
|
||||
"label": "",
|
||||
"value": 0
|
||||
},
|
||||
"denoising_end": {
|
||||
"name": "denoising_end",
|
||||
"label": "",
|
||||
"value": 1
|
||||
},
|
||||
"transformer": {
|
||||
"name": "transformer",
|
||||
"label": ""
|
||||
},
|
||||
"positive_text_conditioning": {
|
||||
"name": "positive_text_conditioning",
|
||||
"label": ""
|
||||
},
|
||||
"width": {
|
||||
"name": "width",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"height": {
|
||||
"name": "height",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"num_steps": {
|
||||
"name": "num_steps",
|
||||
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
|
||||
"value": 30
|
||||
},
|
||||
"guidance": {
|
||||
"name": "guidance",
|
||||
"label": "",
|
||||
"value": 4
|
||||
},
|
||||
"seed": {
|
||||
"name": "seed",
|
||||
"label": "",
|
||||
"value": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1186.1868226120378,
|
||||
"y": -214.9459927686657
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"type": "flux_vae_decode",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": false,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"latents": {
|
||||
"name": "latents",
|
||||
"label": ""
|
||||
},
|
||||
"vae": {
|
||||
"name": "vae",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1575.5797431839133,
|
||||
"y": -209.00150975507415
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "flux_model_loader",
|
||||
"version": "1.0.4",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": false,
|
||||
"inputs": {
|
||||
"model": {
|
||||
"name": "model",
|
||||
"label": ""
|
||||
},
|
||||
"t5_encoder_model": {
|
||||
"name": "t5_encoder_model",
|
||||
"label": ""
|
||||
},
|
||||
"clip_embed_model": {
|
||||
"name": "clip_embed_model",
|
||||
"label": ""
|
||||
},
|
||||
"vae_model": {
|
||||
"name": "vae_model",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 381.1882713063478,
|
||||
"y": -95.89663532854017
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"type": "flux_text_encoder",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"clip": {
|
||||
"name": "clip",
|
||||
"label": ""
|
||||
},
|
||||
"t5_encoder": {
|
||||
"name": "t5_encoder",
|
||||
"label": ""
|
||||
},
|
||||
"t5_max_seq_len": {
|
||||
"name": "t5_max_seq_len",
|
||||
"label": "T5 Max Seq Len",
|
||||
"value": 256
|
||||
},
|
||||
"prompt": {
|
||||
"name": "prompt",
|
||||
"label": "",
|
||||
"value": "a cat"
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 778.4899149328337,
|
||||
"y": -100.36469216659502
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"type": "rand_int",
|
||||
"version": "1.0.1",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": false,
|
||||
"inputs": {
|
||||
"low": {
|
||||
"name": "low",
|
||||
"label": "",
|
||||
"value": 0
|
||||
},
|
||||
"high": {
|
||||
"name": "high",
|
||||
"label": "",
|
||||
"value": 2147483647
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 800.9667463219505,
|
||||
"y": 285.8297267547506
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-4fe24f07-f906-4f55-ab2c-9beee56ef5bdtransformer",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"sourceHandle": "transformer",
|
||||
"targetHandle": "transformer"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-4fe24f07-f906-4f55-ab2c-9beee56ef5bdpositive_text_conditioning",
|
||||
"type": "default",
|
||||
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "positive_text_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-4fe24f07-f906-4f55-ab2c-9beee56ef5bdseed",
|
||||
"type": "default",
|
||||
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"sourceHandle": "value",
|
||||
"targetHandle": "seed"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4fe24f07-f906-4f55-ab2c-9beee56ef5bdlatents-7e5172eb-48c1-44db-a770-8fd83e1435d1latents",
|
||||
"type": "default",
|
||||
"source": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"sourceHandle": "latents",
|
||||
"targetHandle": "latents"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-7e5172eb-48c1-44db-a770-8fd83e1435d1vae",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "max_seq_len",
|
||||
"targetHandle": "t5_max_seq_len"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "t5_encoder",
|
||||
"targetHandle": "t5_encoder"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "clip",
|
||||
"targetHandle": "clip"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -81,7 +81,7 @@ def get_openapi_func(
|
||||
# Add the output map to the schema
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
|
||||
"type": "object",
|
||||
"properties": invocation_output_map_properties,
|
||||
"properties": dict(sorted(invocation_output_map_properties.items())),
|
||||
"required": invocation_output_map_required,
|
||||
}
|
||||
|
||||
|
||||
45
invokeai/backend/flux/denoise.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.flux.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.flux.model import Flux
|
||||
|
||||
|
||||
def denoise(
|
||||
model: Flux,
|
||||
# model input
|
||||
img: torch.Tensor,
|
||||
img_ids: torch.Tensor,
|
||||
txt: torch.Tensor,
|
||||
txt_ids: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
# sampling parameters
|
||||
timesteps: list[float],
|
||||
step_callback: Callable[[], None],
|
||||
guidance: float,
|
||||
inpaint_extension: InpaintExtension | None,
|
||||
):
|
||||
# guidance_vec is ignored for schnell.
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
|
||||
if inpaint_extension is not None:
|
||||
img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
|
||||
|
||||
step_callback()
|
||||
|
||||
return img
|
||||
35
invokeai/backend/flux/inpaint_extension.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import torch
|
||||
|
||||
|
||||
class InpaintExtension:
|
||||
"""A class for managing inpainting with FLUX."""
|
||||
|
||||
def __init__(self, init_latents: torch.Tensor, inpaint_mask: torch.Tensor, noise: torch.Tensor):
|
||||
"""Initialize InpaintExtension.
|
||||
|
||||
Args:
|
||||
init_latents (torch.Tensor): The initial latents (i.e. un-noised at timestep 0). In 'packed' format.
|
||||
inpaint_mask (torch.Tensor): A mask specifying which elements to inpaint. Range [0, 1]. Values of 1 will be
|
||||
re-generated. Values of 0 will remain unchanged. Values between 0 and 1 can be used to blend the
|
||||
inpainted region with the background. In 'packed' format.
|
||||
noise (torch.Tensor): The noise tensor used to noise the init_latents. In 'packed' format.
|
||||
"""
|
||||
assert init_latents.shape == inpaint_mask.shape == noise.shape
|
||||
self._init_latents = init_latents
|
||||
self._inpaint_mask = inpaint_mask
|
||||
self._noise = noise
|
||||
|
||||
def merge_intermediate_latents_with_init_latents(
|
||||
self, intermediate_latents: torch.Tensor, timestep: float
|
||||
) -> torch.Tensor:
|
||||
"""Merge the intermediate latents with the initial latents for the current timestep using the inpaint mask. I.e.
|
||||
update the intermediate latents to keep the regions that are not being inpainted on the correct noise
|
||||
trajectory.
|
||||
|
||||
This function should be called after each denoising step.
|
||||
"""
|
||||
# Noise the init latents for the current timestep.
|
||||
noised_init_latents = self._noise * timestep + (1.0 - timestep) * self._init_latents
|
||||
|
||||
# Merge the intermediate latents with the noised_init_latents using the inpaint_mask.
|
||||
return intermediate_latents * self._inpaint_mask + noised_init_latents * (1.0 - self._inpaint_mask)
|
||||
32
invokeai/backend/flux/math.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
||||
q, k = apply_rope(q, k, pe)
|
||||
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
x = rearrange(x, "B H L D -> B L (H D)")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
assert dim % 2 == 0
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||
omega = 1.0 / (theta**scale)
|
||||
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
||||
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||
return out.float()
|
||||
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
117
invokeai/backend/flux/model.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from invokeai.backend.flux.modules.layers import (
|
||||
DoubleStreamBlock,
|
||||
EmbedND,
|
||||
LastLayer,
|
||||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FluxParams:
|
||||
in_channels: int
|
||||
vec_in_dim: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
mlp_ratio: float
|
||||
num_heads: int
|
||||
depth: int
|
||||
depth_single_blocks: int
|
||||
axes_dim: list[int]
|
||||
theta: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
|
||||
|
||||
class Flux(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, params: FluxParams):
|
||||
super().__init__()
|
||||
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = self.in_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
|
||||
)
|
||||
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256))
|
||||
if self.params.guidance_embed:
|
||||
if guidance is None:
|
||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
||||
vec = vec + self.vector_in(y)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
for block in self.single_blocks:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
324
invokeai/backend/flux/modules/autoencoder.py
Normal file
@@ -0,0 +1,324 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoEncoderParams:
|
||||
resolution: int
|
||||
in_channels: int
|
||||
ch: int
|
||||
out_ch: int
|
||||
ch_mult: list[int]
|
||||
num_res_blocks: int
|
||||
z_channels: int
|
||||
scale_factor: float
|
||||
shift_factor: float
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
|
||||
def attention(self, h_: Tensor) -> Tensor:
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
b, c, h, w = q.shape
|
||||
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
||||
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
||||
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
||||
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return x + self.proj_out(self.attention(x))
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = torch.nn.functional.silu(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
h = self.norm2(h)
|
||||
h = torch.nn.functional.silu(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
pad = (0, 1, 0, 1)
|
||||
x = nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
resolution: int,
|
||||
in_channels: int,
|
||||
ch: int,
|
||||
ch_mult: list[int],
|
||||
num_res_blocks: int,
|
||||
z_channels: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
# downsampling
|
||||
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = nn.ModuleList()
|
||||
block_in = self.ch
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for _ in range(self.num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
||||
block_in = block_out
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
|
||||
# end
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
||||
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1])
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = torch.nn.functional.silu(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ch: int,
|
||||
out_ch: int,
|
||||
ch_mult: list[int],
|
||||
num_res_blocks: int,
|
||||
in_channels: int,
|
||||
resolution: int,
|
||||
z_channels: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.ffactor = 2 ** (self.num_resolutions - 1)
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for _ in range(self.num_res_blocks + 1):
|
||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
||||
block_in = block_out
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
||||
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z: Tensor) -> Tensor:
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = torch.nn.functional.silu(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class DiagonalGaussian(nn.Module):
|
||||
def __init__(self, chunk_dim: int = 1):
|
||||
super().__init__()
|
||||
self.chunk_dim = chunk_dim
|
||||
|
||||
def forward(self, z: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
|
||||
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
|
||||
if sample:
|
||||
std = torch.exp(0.5 * logvar)
|
||||
# Unfortunately, torch.randn_like(...) does not accept a generator argument at the time of writing, so we
|
||||
# have to use torch.randn(...) instead.
|
||||
return mean + std * torch.randn(size=mean.size(), generator=generator, dtype=mean.dtype, device=mean.device)
|
||||
else:
|
||||
return mean
|
||||
|
||||
|
||||
class AutoEncoder(nn.Module):
|
||||
def __init__(self, params: AutoEncoderParams):
|
||||
super().__init__()
|
||||
self.encoder = Encoder(
|
||||
resolution=params.resolution,
|
||||
in_channels=params.in_channels,
|
||||
ch=params.ch,
|
||||
ch_mult=params.ch_mult,
|
||||
num_res_blocks=params.num_res_blocks,
|
||||
z_channels=params.z_channels,
|
||||
)
|
||||
self.decoder = Decoder(
|
||||
resolution=params.resolution,
|
||||
in_channels=params.in_channels,
|
||||
ch=params.ch,
|
||||
out_ch=params.out_ch,
|
||||
ch_mult=params.ch_mult,
|
||||
num_res_blocks=params.num_res_blocks,
|
||||
z_channels=params.z_channels,
|
||||
)
|
||||
self.reg = DiagonalGaussian()
|
||||
|
||||
self.scale_factor = params.scale_factor
|
||||
self.shift_factor = params.shift_factor
|
||||
|
||||
def encode(self, x: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
|
||||
"""Run VAE encoding on input tensor x.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input image tensor. Shape: (batch_size, in_channels, height, width).
|
||||
sample (bool, optional): If True, sample from the encoded distribution, else, return the distribution mean.
|
||||
Defaults to True.
|
||||
generator (torch.Generator | None, optional): Optional random number generator for reproducibility.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
Tensor: Encoded latent tensor. Shape: (batch_size, z_channels, latent_height, latent_width).
|
||||
"""
|
||||
|
||||
z = self.reg(self.encoder(x), sample=sample, generator=generator)
|
||||
z = self.scale_factor * (z - self.shift_factor)
|
||||
return z
|
||||
|
||||
def decode(self, z: Tensor) -> Tensor:
|
||||
z = z / self.scale_factor + self.shift_factor
|
||||
return self.decoder(z)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.decode(self.encode(x))
|
||||
33
invokeai/backend/flux/modules/conditioner.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
from torch import Tensor, nn
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
|
||||
class HFEncoder(nn.Module):
|
||||
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
|
||||
super().__init__()
|
||||
self.max_length = max_length
|
||||
self.is_clip = is_clip
|
||||
self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
|
||||
self.tokenizer = tokenizer
|
||||
self.hf_module = encoder
|
||||
self.hf_module = self.hf_module.eval().requires_grad_(False)
|
||||
|
||||
def forward(self, text: list[str]) -> Tensor:
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=False,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
outputs = self.hf_module(
|
||||
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
|
||||
attention_mask=None,
|
||||
output_hidden_states=False,
|
||||
)
|
||||
return outputs[self.output_key]
|
||||
253
invokeai/backend/flux/modules/layers.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor, nn
|
||||
|
||||
from invokeai.backend.flux.math import attention, rope
|
||||
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: Tensor) -> Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat(
|
||||
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
||||
dim=-3,
|
||||
)
|
||||
|
||||
return emb.unsqueeze(1)
|
||||
|
||||
|
||||
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
t = time_factor * t
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
|
||||
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
if torch.is_floating_point(t):
|
||||
embedding = embedding.to(t)
|
||||
return embedding
|
||||
|
||||
|
||||
class MLPEmbedder(nn.Module):
|
||||
def __init__(self, in_dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
||||
self.silu = nn.SiLU()
|
||||
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.out_layer(self.silu(self.in_layer(x)))
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.scale = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
x_dtype = x.dtype
|
||||
x = x.float()
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||
return (x * rrms).to(dtype=x_dtype) * self.scale
|
||||
|
||||
|
||||
class QKNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.query_norm = RMSNorm(dim)
|
||||
self.key_norm = RMSNorm(dim)
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
||||
q = self.query_norm(q)
|
||||
k = self.key_norm(k)
|
||||
return q.to(v), k.to(v)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.norm = QKNorm(head_dim)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
q, k = self.norm(q, k, v)
|
||||
x = attention(q, k, v, pe=pe)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModulationOut:
|
||||
shift: Tensor
|
||||
scale: Tensor
|
||||
gate: Tensor
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, dim: int, double: bool):
|
||||
super().__init__()
|
||||
self.is_double = double
|
||||
self.multiplier = 6 if double else 3
|
||||
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
||||
|
||||
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
|
||||
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
||||
|
||||
return (
|
||||
ModulationOut(*out[:3]),
|
||||
ModulationOut(*out[3:]) if self.is_double else None,
|
||||
)
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_mod = Modulation(hidden_size, double=True)
|
||||
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
||||
|
||||
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.img_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
self.txt_mod = Modulation(hidden_size, double=True)
|
||||
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
||||
|
||||
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
# run actual attention
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img bloks
|
||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
return img, txt
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
"""
|
||||
A DiT block with parallel linear layers as described in
|
||||
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
head_dim = hidden_size // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
# qkv and mlp_in
|
||||
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
||||
# proj and mlp_out
|
||||
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
||||
|
||||
self.norm = QKNorm(head_dim)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
self.modulation = Modulation(hidden_size, double=False)
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
return x + mod.gate * output
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
||||
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
135
invokeai/backend/flux/sampling_utils.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
def get_noise(
|
||||
num_samples: int,
|
||||
height: int,
|
||||
width: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
):
|
||||
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
|
||||
rand_device = "cpu"
|
||||
rand_dtype = torch.float16
|
||||
return torch.randn(
|
||||
num_samples,
|
||||
16,
|
||||
# allow for packing
|
||||
2 * math.ceil(height / 16),
|
||||
2 * math.ceil(width / 16),
|
||||
device=rand_device,
|
||||
dtype=rand_dtype,
|
||||
generator=torch.Generator(device=rand_device).manual_seed(seed),
|
||||
).to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
|
||||
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
|
||||
m = (y2 - y1) / (x2 - x1)
|
||||
b = y1 - m * x1
|
||||
return lambda x: m * x + b
|
||||
|
||||
|
||||
def get_schedule(
|
||||
num_steps: int,
|
||||
image_seq_len: int,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
shift: bool = True,
|
||||
) -> list[float]:
|
||||
# extra step for zero
|
||||
timesteps = torch.linspace(1, 0, num_steps + 1)
|
||||
|
||||
# shifting the schedule to favor high timesteps for higher signal images
|
||||
if shift:
|
||||
# estimate mu based on linear estimation between two points
|
||||
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
||||
timesteps = time_shift(mu, 1.0, timesteps)
|
||||
|
||||
return timesteps.tolist()
|
||||
|
||||
|
||||
def _find_last_index_ge_val(timesteps: list[float], val: float, eps: float = 1e-6) -> int:
|
||||
"""Find the last index in timesteps that is >= val.
|
||||
|
||||
We use epsilon-close equality to avoid potential floating point errors.
|
||||
"""
|
||||
idx = len(list(filter(lambda t: t >= (val - eps), timesteps))) - 1
|
||||
assert idx >= 0
|
||||
return idx
|
||||
|
||||
|
||||
def clip_timestep_schedule(timesteps: list[float], denoising_start: float, denoising_end: float) -> list[float]:
|
||||
"""Clip the timestep schedule to the denoising range.
|
||||
|
||||
Args:
|
||||
timesteps (list[float]): The original timestep schedule: [1.0, ..., 0.0].
|
||||
denoising_start (float): A value in [0, 1] specifying the start of the denoising process. E.g. a value of 0.2
|
||||
would mean that the denoising process start at the last timestep in the schedule >= 0.8.
|
||||
denoising_end (float): A value in [0, 1] specifying the end of the denoising process. E.g. a value of 0.8 would
|
||||
mean that the denoising process end at the last timestep in the schedule >= 0.2.
|
||||
|
||||
Returns:
|
||||
list[float]: The clipped timestep schedule.
|
||||
"""
|
||||
assert 0.0 <= denoising_start <= 1.0
|
||||
assert 0.0 <= denoising_end <= 1.0
|
||||
assert denoising_start <= denoising_end
|
||||
|
||||
t_start_val = 1.0 - denoising_start
|
||||
t_end_val = 1.0 - denoising_end
|
||||
|
||||
t_start_idx = _find_last_index_ge_val(timesteps, t_start_val)
|
||||
t_end_idx = _find_last_index_ge_val(timesteps, t_end_val)
|
||||
|
||||
clipped_timesteps = timesteps[t_start_idx : t_end_idx + 1]
|
||||
|
||||
return clipped_timesteps
|
||||
|
||||
|
||||
def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""Unpack flat array of patch embeddings to latent image."""
|
||||
return rearrange(
|
||||
x,
|
||||
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
||||
h=math.ceil(height / 16),
|
||||
w=math.ceil(width / 16),
|
||||
ph=2,
|
||||
pw=2,
|
||||
)
|
||||
|
||||
|
||||
def pack(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Pack latent image to flattented array of patch embeddings."""
|
||||
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
|
||||
return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
|
||||
|
||||
def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
"""Generate tensor of image position ids.
|
||||
|
||||
Args:
|
||||
h (int): Height of image in latent space.
|
||||
w (int): Width of image in latent space.
|
||||
batch_size (int): Batch size.
|
||||
device (torch.device): Device.
|
||||
dtype (torch.dtype): dtype.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Image position ids.
|
||||
"""
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
|
||||
return img_ids
|
||||
71
invokeai/backend/flux/util.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Literal
|
||||
|
||||
from invokeai.backend.flux.model import FluxParams
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelSpec:
|
||||
params: FluxParams
|
||||
ae_params: AutoEncoderParams
|
||||
ckpt_path: str | None
|
||||
ae_path: str | None
|
||||
repo_id: str | None
|
||||
repo_flow: str | None
|
||||
repo_ae: str | None
|
||||
|
||||
|
||||
max_seq_lengths: Dict[str, Literal[256, 512]] = {
|
||||
"flux-dev": 512,
|
||||
"flux-schnell": 256,
|
||||
}
|
||||
|
||||
|
||||
ae_params = {
|
||||
"flux": AutoEncoderParams(
|
||||
resolution=256,
|
||||
in_channels=3,
|
||||
ch=128,
|
||||
out_ch=3,
|
||||
ch_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
z_channels=16,
|
||||
scale_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
params = {
|
||||
"flux-dev": FluxParams(
|
||||
in_channels=64,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10_000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=True,
|
||||
),
|
||||
"flux-schnell": FluxParams(
|
||||
in_channels=64,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10_000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=False,
|
||||
),
|
||||
}
|
||||
@@ -1,90 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import repeat
|
||||
from PIL import Image
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
|
||||
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
config = get_config()
|
||||
logger = InvokeAILogger.get_logger(config=config)
|
||||
|
||||
DEPTH_ANYTHING_MODELS = {
|
||||
"large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
|
||||
"base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
|
||||
"small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
|
||||
}
|
||||
|
||||
|
||||
transform = Compose(
|
||||
[
|
||||
Resize(
|
||||
width=518,
|
||||
height=518,
|
||||
resize_target=False,
|
||||
keep_aspect_ratio=True,
|
||||
ensure_multiple_of=14,
|
||||
resize_method="lower_bound",
|
||||
image_interpolation_method=cv2.INTER_CUBIC,
|
||||
),
|
||||
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
PrepareForNet(),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class DepthAnythingDetector:
|
||||
def __init__(self, model: DPT_DINOv2, device: torch.device) -> None:
|
||||
self.model = model
|
||||
self.device = device
|
||||
|
||||
@staticmethod
|
||||
def load_model(
|
||||
model_path: Path, device: torch.device, model_size: Literal["large", "base", "small"] = "small"
|
||||
) -> DPT_DINOv2:
|
||||
match model_size:
|
||||
case "small":
|
||||
model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
|
||||
case "base":
|
||||
model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
|
||||
case "large":
|
||||
model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
|
||||
|
||||
model.load_state_dict(torch.load(model_path.as_posix(), map_location="cpu"))
|
||||
model.eval()
|
||||
|
||||
model.to(device)
|
||||
return model
|
||||
|
||||
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
|
||||
if not self.model:
|
||||
logger.warn("DepthAnything model was not loaded. Returning original image")
|
||||
return image
|
||||
|
||||
np_image = np.array(image, dtype=np.uint8)
|
||||
np_image = np_image[:, :, ::-1] / 255.0
|
||||
|
||||
image_height, image_width = np_image.shape[:2]
|
||||
np_image = transform({"image": np_image})["image"]
|
||||
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
depth = self.model(tensor_image)
|
||||
depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
|
||||
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
|
||||
|
||||
depth_map = repeat(depth, "h w -> h w 3").cpu().numpy().astype(np.uint8)
|
||||
depth_map = Image.fromarray(depth_map)
|
||||
|
||||
new_height = int(image_height * (resolution / image_width))
|
||||
depth_map = depth_map.resize((resolution, new_height))
|
||||
|
||||
return depth_map
|
||||
@@ -0,0 +1,31 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers.pipelines import DepthEstimationPipeline
|
||||
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
class DepthAnythingPipeline(RawModel):
|
||||
"""Custom wrapper for the Depth Estimation pipeline from transformers adding compatibility
|
||||
for Invoke's Model Management System"""
|
||||
|
||||
def __init__(self, pipeline: DepthEstimationPipeline) -> None:
|
||||
self._pipeline = pipeline
|
||||
|
||||
def generate_depth(self, image: Image.Image) -> Image.Image:
|
||||
depth_map = self._pipeline(image)["depth"]
|
||||
assert isinstance(depth_map, Image.Image)
|
||||
return depth_map
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||
if device is not None and device.type not in {"cpu", "cuda"}:
|
||||
device = None
|
||||
self._pipeline.model.to(device=device, dtype=dtype)
|
||||
self._pipeline.device = self._pipeline.model.device
|
||||
|
||||
def calc_size(self) -> int:
|
||||
from invokeai.backend.model_manager.load.model_util import calc_module_size
|
||||
|
||||
return calc_module_size(self._pipeline.model)
|
||||
@@ -1,145 +0,0 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
||||
scratch = nn.Module()
|
||||
|
||||
out_shape1 = out_shape
|
||||
out_shape2 = out_shape
|
||||
out_shape3 = out_shape
|
||||
if len(in_shape) >= 4:
|
||||
out_shape4 = out_shape
|
||||
|
||||
if expand:
|
||||
out_shape1 = out_shape
|
||||
out_shape2 = out_shape * 2
|
||||
out_shape3 = out_shape * 4
|
||||
if len(in_shape) >= 4:
|
||||
out_shape4 = out_shape * 8
|
||||
|
||||
scratch.layer1_rn = nn.Conv2d(
|
||||
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
scratch.layer2_rn = nn.Conv2d(
|
||||
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
scratch.layer3_rn = nn.Conv2d(
|
||||
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
if len(in_shape) >= 4:
|
||||
scratch.layer4_rn = nn.Conv2d(
|
||||
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
|
||||
return scratch
|
||||
|
||||
|
||||
class ResidualConvUnit(nn.Module):
|
||||
"""Residual convolution module."""
|
||||
|
||||
def __init__(self, features, activation, bn):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
features (int): number of features
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.bn = bn
|
||||
|
||||
self.groups = 1
|
||||
|
||||
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
||||
|
||||
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
||||
|
||||
if self.bn:
|
||||
self.bn1 = nn.BatchNorm2d(features)
|
||||
self.bn2 = nn.BatchNorm2d(features)
|
||||
|
||||
self.activation = activation
|
||||
|
||||
self.skip_add = nn.quantized.FloatFunctional()
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x (tensor): input
|
||||
|
||||
Returns:
|
||||
tensor: output
|
||||
"""
|
||||
|
||||
out = self.activation(x)
|
||||
out = self.conv1(out)
|
||||
if self.bn:
|
||||
out = self.bn1(out)
|
||||
|
||||
out = self.activation(out)
|
||||
out = self.conv2(out)
|
||||
if self.bn:
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.groups > 1:
|
||||
out = self.conv_merge(out)
|
||||
|
||||
return self.skip_add.add(out, x)
|
||||
|
||||
|
||||
class FeatureFusionBlock(nn.Module):
|
||||
"""Feature fusion block."""
|
||||
|
||||
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
features (int): number of features
|
||||
"""
|
||||
super(FeatureFusionBlock, self).__init__()
|
||||
|
||||
self.deconv = deconv
|
||||
self.align_corners = align_corners
|
||||
|
||||
self.groups = 1
|
||||
|
||||
self.expand = expand
|
||||
out_features = features
|
||||
if self.expand:
|
||||
out_features = features // 2
|
||||
|
||||
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
|
||||
|
||||
self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
|
||||
self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
|
||||
|
||||
self.skip_add = nn.quantized.FloatFunctional()
|
||||
|
||||
self.size = size
|
||||
|
||||
def forward(self, *xs, size=None):
|
||||
"""Forward pass.
|
||||
|
||||
Returns:
|
||||
tensor: output
|
||||
"""
|
||||
output = xs[0]
|
||||
|
||||
if len(xs) == 2:
|
||||
res = self.resConfUnit1(xs[1])
|
||||
output = self.skip_add.add(output, res)
|
||||
|
||||
output = self.resConfUnit2(output)
|
||||
|
||||
if (size is None) and (self.size is None):
|
||||
modifier = {"scale_factor": 2}
|
||||
elif size is None:
|
||||
modifier = {"size": self.size}
|
||||
else:
|
||||
modifier = {"size": size}
|
||||
|
||||
output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
|
||||
|
||||
output = self.out_conv(output)
|
||||
|
||||
return output
|
||||
@@ -1,183 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from invokeai.backend.image_util.depth_anything.model.blocks import FeatureFusionBlock, _make_scratch
|
||||
|
||||
torchhub_path = Path(__file__).parent.parent / "torchhub"
|
||||
|
||||
|
||||
def _make_fusion_block(features, use_bn, size=None):
|
||||
return FeatureFusionBlock(
|
||||
features,
|
||||
nn.ReLU(False),
|
||||
deconv=False,
|
||||
bn=use_bn,
|
||||
expand=False,
|
||||
align_corners=True,
|
||||
size=size,
|
||||
)
|
||||
|
||||
|
||||
class DPTHead(nn.Module):
|
||||
def __init__(self, nclass, in_channels, features, out_channels, use_bn=False, use_clstoken=False):
|
||||
super(DPTHead, self).__init__()
|
||||
|
||||
self.nclass = nclass
|
||||
self.use_clstoken = use_clstoken
|
||||
|
||||
self.projects = nn.ModuleList(
|
||||
[
|
||||
nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channel,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
)
|
||||
for out_channel in out_channels
|
||||
]
|
||||
)
|
||||
|
||||
self.resize_layers = nn.ModuleList(
|
||||
[
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
|
||||
),
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
|
||||
),
|
||||
nn.Identity(),
|
||||
nn.Conv2d(
|
||||
in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
if use_clstoken:
|
||||
self.readout_projects = nn.ModuleList()
|
||||
for _ in range(len(self.projects)):
|
||||
self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))
|
||||
|
||||
self.scratch = _make_scratch(
|
||||
out_channels,
|
||||
features,
|
||||
groups=1,
|
||||
expand=False,
|
||||
)
|
||||
|
||||
self.scratch.stem_transpose = None
|
||||
|
||||
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
||||
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
||||
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
||||
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
||||
|
||||
head_features_1 = features
|
||||
head_features_2 = 32
|
||||
|
||||
if nclass > 1:
|
||||
self.scratch.output_conv = nn.Sequential(
|
||||
nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
|
||||
)
|
||||
else:
|
||||
self.scratch.output_conv1 = nn.Conv2d(
|
||||
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
self.scratch.output_conv2 = nn.Sequential(
|
||||
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
|
||||
nn.ReLU(True),
|
||||
nn.Identity(),
|
||||
)
|
||||
|
||||
def forward(self, out_features, patch_h, patch_w):
|
||||
out = []
|
||||
for i, x in enumerate(out_features):
|
||||
if self.use_clstoken:
|
||||
x, cls_token = x[0], x[1]
|
||||
readout = cls_token.unsqueeze(1).expand_as(x)
|
||||
x = self.readout_projects[i](torch.cat((x, readout), -1))
|
||||
else:
|
||||
x = x[0]
|
||||
|
||||
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
|
||||
|
||||
x = self.projects[i](x)
|
||||
x = self.resize_layers[i](x)
|
||||
|
||||
out.append(x)
|
||||
|
||||
layer_1, layer_2, layer_3, layer_4 = out
|
||||
|
||||
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
||||
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
||||
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
||||
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
||||
|
||||
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
||||
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
|
||||
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
|
||||
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
||||
|
||||
out = self.scratch.output_conv1(path_1)
|
||||
out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
|
||||
out = self.scratch.output_conv2(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class DPT_DINOv2(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
features,
|
||||
out_channels,
|
||||
encoder="vitl",
|
||||
use_bn=False,
|
||||
use_clstoken=False,
|
||||
):
|
||||
super(DPT_DINOv2, self).__init__()
|
||||
|
||||
assert encoder in ["vits", "vitb", "vitl"]
|
||||
|
||||
# # in case the Internet connection is not stable, please load the DINOv2 locally
|
||||
# if use_local:
|
||||
# self.pretrained = torch.hub.load(
|
||||
# torchhub_path / "facebookresearch_dinov2_main",
|
||||
# "dinov2_{:}14".format(encoder),
|
||||
# source="local",
|
||||
# pretrained=False,
|
||||
# )
|
||||
# else:
|
||||
# self.pretrained = torch.hub.load(
|
||||
# "facebookresearch/dinov2",
|
||||
# "dinov2_{:}14".format(encoder),
|
||||
# )
|
||||
|
||||
self.pretrained = torch.hub.load(
|
||||
"facebookresearch/dinov2",
|
||||
"dinov2_{:}14".format(encoder),
|
||||
)
|
||||
|
||||
dim = self.pretrained.blocks[0].attn.qkv.in_features
|
||||
|
||||
self.depth_head = DPTHead(1, dim, features, out_channels=out_channels, use_bn=use_bn, use_clstoken=use_clstoken)
|
||||
|
||||
def forward(self, x):
|
||||
h, w = x.shape[-2:]
|
||||
|
||||
features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)
|
||||
|
||||
patch_h, patch_w = h // 14, w // 14
|
||||
|
||||
depth = self.depth_head(features, patch_h, patch_w)
|
||||
depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
|
||||
depth = F.relu(depth)
|
||||
|
||||
return depth.squeeze(1)
|
||||
@@ -1,227 +0,0 @@
|
||||
import math
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
|
||||
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
|
||||
|
||||
Args:
|
||||
sample (dict): sample
|
||||
size (tuple): image size
|
||||
|
||||
Returns:
|
||||
tuple: new size
|
||||
"""
|
||||
shape = list(sample["disparity"].shape)
|
||||
|
||||
if shape[0] >= size[0] and shape[1] >= size[1]:
|
||||
return sample
|
||||
|
||||
scale = [0, 0]
|
||||
scale[0] = size[0] / shape[0]
|
||||
scale[1] = size[1] / shape[1]
|
||||
|
||||
scale = max(scale)
|
||||
|
||||
shape[0] = math.ceil(scale * shape[0])
|
||||
shape[1] = math.ceil(scale * shape[1])
|
||||
|
||||
# resize
|
||||
sample["image"] = cv2.resize(sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method)
|
||||
|
||||
sample["disparity"] = cv2.resize(sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST)
|
||||
sample["mask"] = cv2.resize(
|
||||
sample["mask"].astype(np.float32),
|
||||
tuple(shape[::-1]),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
sample["mask"] = sample["mask"].astype(bool)
|
||||
|
||||
return tuple(shape)
|
||||
|
||||
|
||||
class Resize(object):
|
||||
"""Resize sample to given size (width, height)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width,
|
||||
height,
|
||||
resize_target=True,
|
||||
keep_aspect_ratio=False,
|
||||
ensure_multiple_of=1,
|
||||
resize_method="lower_bound",
|
||||
image_interpolation_method=cv2.INTER_AREA,
|
||||
):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
width (int): desired output width
|
||||
height (int): desired output height
|
||||
resize_target (bool, optional):
|
||||
True: Resize the full sample (image, mask, target).
|
||||
False: Resize image only.
|
||||
Defaults to True.
|
||||
keep_aspect_ratio (bool, optional):
|
||||
True: Keep the aspect ratio of the input sample.
|
||||
Output sample might not have the given width and height, and
|
||||
resize behaviour depends on the parameter 'resize_method'.
|
||||
Defaults to False.
|
||||
ensure_multiple_of (int, optional):
|
||||
Output width and height is constrained to be multiple of this parameter.
|
||||
Defaults to 1.
|
||||
resize_method (str, optional):
|
||||
"lower_bound": Output will be at least as large as the given size.
|
||||
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller
|
||||
than given size.)
|
||||
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
||||
Defaults to "lower_bound".
|
||||
"""
|
||||
self.__width = width
|
||||
self.__height = height
|
||||
|
||||
self.__resize_target = resize_target
|
||||
self.__keep_aspect_ratio = keep_aspect_ratio
|
||||
self.__multiple_of = ensure_multiple_of
|
||||
self.__resize_method = resize_method
|
||||
self.__image_interpolation_method = image_interpolation_method
|
||||
|
||||
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
||||
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||
|
||||
if max_val is not None and y > max_val:
|
||||
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||
|
||||
if y < min_val:
|
||||
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||
|
||||
return y
|
||||
|
||||
def get_size(self, width, height):
|
||||
# determine new height and width
|
||||
scale_height = self.__height / height
|
||||
scale_width = self.__width / width
|
||||
|
||||
if self.__keep_aspect_ratio:
|
||||
if self.__resize_method == "lower_bound":
|
||||
# scale such that output size is lower bound
|
||||
if scale_width > scale_height:
|
||||
# fit width
|
||||
scale_height = scale_width
|
||||
else:
|
||||
# fit height
|
||||
scale_width = scale_height
|
||||
elif self.__resize_method == "upper_bound":
|
||||
# scale such that output size is upper bound
|
||||
if scale_width < scale_height:
|
||||
# fit width
|
||||
scale_height = scale_width
|
||||
else:
|
||||
# fit height
|
||||
scale_width = scale_height
|
||||
elif self.__resize_method == "minimal":
|
||||
# scale as least as possbile
|
||||
if abs(1 - scale_width) < abs(1 - scale_height):
|
||||
# fit width
|
||||
scale_height = scale_width
|
||||
else:
|
||||
# fit height
|
||||
scale_width = scale_height
|
||||
else:
|
||||
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
||||
|
||||
if self.__resize_method == "lower_bound":
|
||||
new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
|
||||
new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
|
||||
elif self.__resize_method == "upper_bound":
|
||||
new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
|
||||
new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
|
||||
elif self.__resize_method == "minimal":
|
||||
new_height = self.constrain_to_multiple_of(scale_height * height)
|
||||
new_width = self.constrain_to_multiple_of(scale_width * width)
|
||||
else:
|
||||
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
||||
|
||||
return (new_width, new_height)
|
||||
|
||||
def __call__(self, sample):
|
||||
width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
|
||||
|
||||
# resize sample
|
||||
sample["image"] = cv2.resize(
|
||||
sample["image"],
|
||||
(width, height),
|
||||
interpolation=self.__image_interpolation_method,
|
||||
)
|
||||
|
||||
if self.__resize_target:
|
||||
if "disparity" in sample:
|
||||
sample["disparity"] = cv2.resize(
|
||||
sample["disparity"],
|
||||
(width, height),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
|
||||
if "depth" in sample:
|
||||
sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
|
||||
|
||||
if "semseg_mask" in sample:
|
||||
# sample["semseg_mask"] = cv2.resize(
|
||||
# sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
|
||||
# )
|
||||
sample["semseg_mask"] = F.interpolate(
|
||||
torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode="nearest"
|
||||
).numpy()[0, 0]
|
||||
|
||||
if "mask" in sample:
|
||||
sample["mask"] = cv2.resize(
|
||||
sample["mask"].astype(np.float32),
|
||||
(width, height),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
# sample["mask"] = sample["mask"].astype(bool)
|
||||
|
||||
# print(sample['image'].shape, sample['depth'].shape)
|
||||
return sample
|
||||
|
||||
|
||||
class NormalizeImage(object):
|
||||
"""Normlize image by given mean and std."""
|
||||
|
||||
def __init__(self, mean, std):
|
||||
self.__mean = mean
|
||||
self.__std = std
|
||||
|
||||
def __call__(self, sample):
|
||||
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class PrepareForNet(object):
|
||||
"""Prepare sample for usage as network input."""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, sample):
|
||||
image = np.transpose(sample["image"], (2, 0, 1))
|
||||
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
||||
|
||||
if "mask" in sample:
|
||||
sample["mask"] = sample["mask"].astype(np.float32)
|
||||
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
||||
|
||||
if "depth" in sample:
|
||||
depth = sample["depth"].astype(np.float32)
|
||||
sample["depth"] = np.ascontiguousarray(depth)
|
||||
|
||||
if "semseg_mask" in sample:
|
||||
sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
|
||||
sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])
|
||||
|
||||
return sample
|
||||
@@ -1,630 +0,0 @@
|
||||
# Copyright (c) 2024 The InvokeAI Development team
|
||||
"""LoRA model support."""
|
||||
|
||||
import bisect
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from typing_extensions import Self
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
class LoRALayerBase:
|
||||
# rank: Optional[int]
|
||||
# alpha: Optional[float]
|
||||
# bias: Optional[torch.Tensor]
|
||||
# layer_key: str
|
||||
|
||||
# @property
|
||||
# def scale(self):
|
||||
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
if "alpha" in values:
|
||||
self.alpha = values["alpha"].item()
|
||||
else:
|
||||
self.alpha = None
|
||||
|
||||
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
|
||||
self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
|
||||
values["bias_indices"],
|
||||
values["bias_values"],
|
||||
tuple(values["bias_size"]),
|
||||
)
|
||||
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.rank = None # set in layer implementation
|
||||
self.layer_key = layer_key
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
return self.bias
|
||||
|
||||
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
|
||||
params = {"weight": self.get_weight(orig_module.weight)}
|
||||
bias = self.get_bias(orig_module.bias)
|
||||
if bias is not None:
|
||||
params["bias"] = bias
|
||||
return params
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for val in [self.bias]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
if self.bias is not None:
|
||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||
|
||||
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
|
||||
"""Log a warning if values contains unhandled keys."""
|
||||
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
|
||||
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
|
||||
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
|
||||
unknown_keys = set(values.keys()) - all_known_keys
|
||||
if unknown_keys:
|
||||
logger.warning(
|
||||
f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
|
||||
)
|
||||
|
||||
|
||||
# TODO: find and debug lora/locon with bias
|
||||
class LoRALayer(LoRALayerBase):
|
||||
# up: torch.Tensor
|
||||
# mid: Optional[torch.Tensor]
|
||||
# down: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.up = values["lora_up.weight"]
|
||||
self.down = values["lora_down.weight"]
|
||||
self.mid = values.get("lora_mid.weight", None)
|
||||
|
||||
self.rank = self.down.shape[0]
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"lora_up.weight",
|
||||
"lora_down.weight",
|
||||
"lora_mid.weight",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
if self.mid is not None:
|
||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
||||
else:
|
||||
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.up, self.mid, self.down]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.up = self.up.to(device=device, dtype=dtype)
|
||||
self.down = self.down.to(device=device, dtype=dtype)
|
||||
|
||||
if self.mid is not None:
|
||||
self.mid = self.mid.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class LoHALayer(LoRALayerBase):
|
||||
# w1_a: torch.Tensor
|
||||
# w1_b: torch.Tensor
|
||||
# w2_a: torch.Tensor
|
||||
# w2_b: torch.Tensor
|
||||
# t1: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.w1_a = values["hada_w1_a"]
|
||||
self.w1_b = values["hada_w1_b"]
|
||||
self.w2_a = values["hada_w2_a"]
|
||||
self.w2_b = values["hada_w2_b"]
|
||||
self.t1 = values.get("hada_t1", None)
|
||||
self.t2 = values.get("hada_t2", None)
|
||||
|
||||
self.rank = self.w1_b.shape[0]
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"hada_w1_a",
|
||||
"hada_w1_b",
|
||||
"hada_w2_a",
|
||||
"hada_w2_b",
|
||||
"hada_t1",
|
||||
"hada_t2",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
if self.t1 is None:
|
||||
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||
|
||||
else:
|
||||
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
|
||||
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
|
||||
weight = rebuild1 * rebuild2
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
if self.t1 is not None:
|
||||
self.t1 = self.t1.to(device=device, dtype=dtype)
|
||||
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class LoKRLayer(LoRALayerBase):
|
||||
# w1: Optional[torch.Tensor] = None
|
||||
# w1_a: Optional[torch.Tensor] = None
|
||||
# w1_b: Optional[torch.Tensor] = None
|
||||
# w2: Optional[torch.Tensor] = None
|
||||
# w2_a: Optional[torch.Tensor] = None
|
||||
# w2_b: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.w1 = values.get("lokr_w1", None)
|
||||
if self.w1 is None:
|
||||
self.w1_a = values["lokr_w1_a"]
|
||||
self.w1_b = values["lokr_w1_b"]
|
||||
|
||||
self.w2 = values.get("lokr_w2", None)
|
||||
if self.w2 is None:
|
||||
self.w2_a = values["lokr_w2_a"]
|
||||
self.w2_b = values["lokr_w2_b"]
|
||||
|
||||
self.t2 = values.get("lokr_t2", None)
|
||||
|
||||
if self.w1_b is not None:
|
||||
self.rank = self.w1_b.shape[0]
|
||||
elif self.w2_b is not None:
|
||||
self.rank = self.w2_b.shape[0]
|
||||
else:
|
||||
self.rank = None # unscaled
|
||||
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"lokr_w1",
|
||||
"lokr_w1_a",
|
||||
"lokr_w1_b",
|
||||
"lokr_w2",
|
||||
"lokr_w2_a",
|
||||
"lokr_w2_b",
|
||||
"lokr_t2",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
w1: Optional[torch.Tensor] = self.w1
|
||||
if w1 is None:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
w1 = self.w1_a @ self.w1_b
|
||||
|
||||
w2 = self.w2
|
||||
if w2 is None:
|
||||
if self.t2 is None:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
w2 = self.w2_a @ self.w2_b
|
||||
else:
|
||||
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
w2 = w2.contiguous()
|
||||
assert w1 is not None
|
||||
assert w2 is not None
|
||||
weight = torch.kron(w1, w2)
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
if self.w1 is not None:
|
||||
self.w1 = self.w1.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.w2 is not None:
|
||||
self.w2 = self.w2.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class FullLayer(LoRALayerBase):
|
||||
# bias handled in LoRALayerBase(calc_size, to)
|
||||
# weight: torch.Tensor
|
||||
# bias: Optional[torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["diff"]
|
||||
self.bias = values.get("diff_b", None)
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"diff", "diff_b"})
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
return self.weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class IA3Layer(LoRALayerBase):
|
||||
# weight: torch.Tensor
|
||||
# on_input: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["weight"]
|
||||
self.on_input = values["on_input"]
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"weight", "on_input"})
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
weight = self.weight
|
||||
if not self.on_input:
|
||||
weight = weight.reshape(-1, 1)
|
||||
assert orig_weight is not None
|
||||
return orig_weight * weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
model_size += self.on_input.nelement() * self.on_input.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
|
||||
|
||||
|
||||
class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
_name: str
|
||||
layers: Dict[str, AnyLoRALayer]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
layers: Dict[str, AnyLoRALayer],
|
||||
):
|
||||
self._name = name
|
||||
self.layers = layers
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
# TODO: try revert if exception?
|
||||
for _key, layer in self.layers.items():
|
||||
layer.to(device=device, dtype=dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for _, layer in self.layers.items():
|
||||
model_size += layer.calc_size()
|
||||
return model_size
|
||||
|
||||
@classmethod
|
||||
def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
|
||||
|
||||
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
|
||||
diffusers format, then this function will have no effect.
|
||||
|
||||
This function is adapted from:
|
||||
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
|
||||
|
||||
Args:
|
||||
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
|
||||
|
||||
Raises:
|
||||
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
|
||||
|
||||
Returns:
|
||||
Dict[str, Tensor]: The diffusers-format state_dict.
|
||||
"""
|
||||
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
|
||||
not_converted_count = 0 # The number of keys that were not converted.
|
||||
|
||||
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
|
||||
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
|
||||
# `input_blocks_4_1_proj_in`.
|
||||
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
|
||||
stability_unet_keys.sort()
|
||||
|
||||
new_state_dict = {}
|
||||
for full_key, value in state_dict.items():
|
||||
if full_key.startswith("lora_unet_"):
|
||||
search_key = full_key.replace("lora_unet_", "")
|
||||
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
|
||||
position = bisect.bisect_right(stability_unet_keys, search_key)
|
||||
map_key = stability_unet_keys[position - 1]
|
||||
# Now, check if the map_key *actually* matches the search_key.
|
||||
if search_key.startswith(map_key):
|
||||
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
|
||||
new_state_dict[new_key] = value
|
||||
converted_count += 1
|
||||
else:
|
||||
new_state_dict[full_key] = value
|
||||
not_converted_count += 1
|
||||
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
|
||||
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
|
||||
new_state_dict[full_key] = value
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
|
||||
|
||||
if converted_count > 0 and not_converted_count > 0:
|
||||
raise ValueError(
|
||||
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
|
||||
f" not_converted={not_converted_count}"
|
||||
)
|
||||
|
||||
return new_state_dict
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
cls,
|
||||
file_path: Union[str, Path],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
) -> Self:
|
||||
device = device or torch.device("cpu")
|
||||
dtype = dtype or torch.float32
|
||||
|
||||
if isinstance(file_path, str):
|
||||
file_path = Path(file_path)
|
||||
|
||||
model = cls(
|
||||
name=file_path.stem,
|
||||
layers={},
|
||||
)
|
||||
|
||||
if file_path.suffix == ".safetensors":
|
||||
sd = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||
else:
|
||||
sd = torch.load(file_path, map_location="cpu")
|
||||
|
||||
state_dict = cls._group_state(sd)
|
||||
|
||||
if base_model == BaseModelType.StableDiffusionXL:
|
||||
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
|
||||
for layer_key, values in state_dict.items():
|
||||
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
|
||||
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
|
||||
|
||||
# lora and locon
|
||||
if "lora_up.weight" in values:
|
||||
layer: AnyLoRALayer = LoRALayer(layer_key, values)
|
||||
|
||||
# loha
|
||||
elif "hada_w1_a" in values:
|
||||
layer = LoHALayer(layer_key, values)
|
||||
|
||||
# lokr
|
||||
elif "lokr_w1" in values or "lokr_w1_a" in values:
|
||||
layer = LoKRLayer(layer_key, values)
|
||||
|
||||
# diff
|
||||
elif "diff" in values:
|
||||
layer = FullLayer(layer_key, values)
|
||||
|
||||
# ia3
|
||||
elif "on_input" in values:
|
||||
layer = IA3Layer(layer_key, values)
|
||||
|
||||
else:
|
||||
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
|
||||
raise Exception("Unknown lora format!")
|
||||
|
||||
# lower memory consumption by removing already parsed layer values
|
||||
state_dict[layer_key].clear()
|
||||
|
||||
layer.to(device=device, dtype=dtype)
|
||||
model.layers[layer_key] = layer
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
|
||||
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
|
||||
|
||||
for key, value in state_dict.items():
|
||||
stem, leaf = key.split(".", 1)
|
||||
if stem not in state_dict_groupped:
|
||||
state_dict_groupped[stem] = {}
|
||||
state_dict_groupped[stem][leaf] = value
|
||||
|
||||
return state_dict_groupped
|
||||
|
||||
|
||||
# code from
|
||||
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
|
||||
def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
|
||||
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
|
||||
unet_conversion_map_layer = []
|
||||
|
||||
for i in range(3): # num_blocks is 3 in sdxl
|
||||
# loop over downblocks/upblocks
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no attention layers in down_blocks.3
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(3):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
# if i > 0: commentout for sdxl
|
||||
# no attention layers in up_blocks.0
|
||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
sd_mid_atn_prefix = "middle_block.1."
|
||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
unet_conversion_map_resnet = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("in_layers.0.", "norm1."),
|
||||
("in_layers.2.", "conv1."),
|
||||
("out_layers.0.", "norm2."),
|
||||
("out_layers.3.", "conv2."),
|
||||
("emb_layers.1.", "time_emb_proj."),
|
||||
("skip_connection.", "conv_shortcut."),
|
||||
]
|
||||
|
||||
unet_conversion_map = []
|
||||
for sd, hf in unet_conversion_map_layer:
|
||||
if "resnets" in hf:
|
||||
for sd_res, hf_res in unet_conversion_map_resnet:
|
||||
unet_conversion_map.append((sd + sd_res, hf + hf_res))
|
||||
else:
|
||||
unet_conversion_map.append((sd, hf))
|
||||
|
||||
for j in range(2):
|
||||
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
|
||||
sd_time_embed_prefix = f"time_embed.{j*2}."
|
||||
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
|
||||
sd_label_embed_prefix = f"label_emb.0.{j*2}."
|
||||
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
|
||||
|
||||
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
|
||||
unet_conversion_map.append(("out.0.", "conv_norm_out."))
|
||||
unet_conversion_map.append(("out.2.", "conv_out."))
|
||||
|
||||
return unet_conversion_map
|
||||
|
||||
|
||||
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
|
||||
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
|
||||
}
|
||||
@@ -52,6 +52,7 @@ class BaseModelType(str, Enum):
|
||||
StableDiffusion2 = "sd-2"
|
||||
StableDiffusionXL = "sdxl"
|
||||
StableDiffusionXLRefiner = "sdxl-refiner"
|
||||
Flux = "flux"
|
||||
# Kandinsky2_1 = "kandinsky-2.1"
|
||||
|
||||
|
||||
@@ -66,7 +67,9 @@ class ModelType(str, Enum):
|
||||
TextualInversion = "embedding"
|
||||
IPAdapter = "ip_adapter"
|
||||
CLIPVision = "clip_vision"
|
||||
CLIPEmbed = "clip_embed"
|
||||
T2IAdapter = "t2i_adapter"
|
||||
T5Encoder = "t5_encoder"
|
||||
SpandrelImageToImage = "spandrel_image_to_image"
|
||||
|
||||
|
||||
@@ -74,6 +77,7 @@ class SubModelType(str, Enum):
|
||||
"""Submodel type."""
|
||||
|
||||
UNet = "unet"
|
||||
Transformer = "transformer"
|
||||
TextEncoder = "text_encoder"
|
||||
TextEncoder2 = "text_encoder_2"
|
||||
Tokenizer = "tokenizer"
|
||||
@@ -104,6 +108,9 @@ class ModelFormat(str, Enum):
|
||||
EmbeddingFile = "embedding_file"
|
||||
EmbeddingFolder = "embedding_folder"
|
||||
InvokeAI = "invokeai"
|
||||
T5Encoder = "t5_encoder"
|
||||
BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
|
||||
BnbQuantizednf4b = "bnb_quantized_nf4b"
|
||||
|
||||
|
||||
class SchedulerPredictionType(str, Enum):
|
||||
@@ -186,7 +193,9 @@ class ModelConfigBase(BaseModel):
|
||||
class CheckpointConfigBase(ModelConfigBase):
|
||||
"""Model config for checkpoint-style models."""
|
||||
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field(
|
||||
description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint
|
||||
)
|
||||
config_path: str = Field(description="path to the checkpoint model config file")
|
||||
converted_at: Optional[float] = Field(
|
||||
description="When this model was last converted to diffusers", default_factory=time.time
|
||||
@@ -205,6 +214,26 @@ class LoRAConfigBase(ModelConfigBase):
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
|
||||
|
||||
class T5EncoderConfigBase(ModelConfigBase):
|
||||
type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder
|
||||
|
||||
|
||||
class T5EncoderConfig(T5EncoderConfigBase):
|
||||
format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.T5Encoder.value}")
|
||||
|
||||
|
||||
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase):
|
||||
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.BnbQuantizedLlmInt8b.value}")
|
||||
|
||||
|
||||
class LoRALyCORISConfig(LoRAConfigBase):
|
||||
"""Model config for LoRA/Lycoris models."""
|
||||
|
||||
@@ -229,7 +258,6 @@ class VAECheckpointConfig(CheckpointConfigBase):
|
||||
"""Model config for standalone VAE models."""
|
||||
|
||||
type: Literal[ModelType.VAE] = ModelType.VAE
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
@@ -268,7 +296,6 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase)
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
@@ -317,6 +344,21 @@ class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
|
||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
|
||||
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
|
||||
"""Model config for main checkpoint models."""
|
||||
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.format = ModelFormat.BnbQuantizednf4b
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")
|
||||
|
||||
|
||||
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
|
||||
"""Model config for main diffusers models."""
|
||||
|
||||
@@ -350,6 +392,17 @@ class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
|
||||
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
|
||||
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
|
||||
"""Model config for Clip Embeddings."""
|
||||
|
||||
type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class CLIPVisionDiffusersConfig(DiffusersConfigBase):
|
||||
"""Model config for CLIPVision."""
|
||||
|
||||
@@ -408,12 +461,15 @@ AnyModelConfig = Annotated[
|
||||
Union[
|
||||
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
|
||||
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
|
||||
Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
|
||||
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
|
||||
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
|
||||
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
|
||||
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
|
||||
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
|
||||
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
|
||||
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
|
||||
Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],
|
||||
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
|
||||
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
|
||||
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
|
||||
@@ -421,6 +477,7 @@ AnyModelConfig = Annotated[
|
||||
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
|
||||
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
|
||||
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
|
||||
Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
|
||||
],
|
||||
Discriminator(get_model_discriminator_value),
|
||||
]
|
||||
|
||||
@@ -66,12 +66,14 @@ class ModelLoader(ModelLoaderBase):
|
||||
return (model_base / config.path).resolve()
|
||||
|
||||
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLockerBase:
|
||||
stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")])
|
||||
try:
|
||||
return self._ram_cache.get(config.key, submodel_type)
|
||||
return self._ram_cache.get(config.key, submodel_type, stats_name=stats_name)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
config.path = str(self._get_model_path(config))
|
||||
self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type))
|
||||
loaded_model = self._load_model(config, submodel_type)
|
||||
|
||||
self._ram_cache.put(
|
||||
@@ -83,7 +85,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
return self._ram_cache.get(
|
||||
key=config.key,
|
||||
submodel_type=submodel_type,
|
||||
stats_name=":".join([config.base, config.type, config.name, (submodel_type or "")]),
|
||||
stats_name=stats_name,
|
||||
)
|
||||
|
||||
def get_size_fs(
|
||||
|
||||
@@ -128,7 +128,24 @@ class ModelCacheBase(ABC, Generic[T]):
|
||||
@property
|
||||
@abstractmethod
|
||||
def max_cache_size(self) -> float:
|
||||
"""Return true if the cache is configured to lazily offload models in VRAM."""
|
||||
"""Return the maximum size the RAM cache can grow to."""
|
||||
pass
|
||||
|
||||
@max_cache_size.setter
|
||||
@abstractmethod
|
||||
def max_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on vram cache size."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def max_vram_cache_size(self) -> float:
|
||||
"""Return the maximum size the VRAM cache can grow to."""
|
||||
pass
|
||||
|
||||
@max_vram_cache_size.setter
|
||||
@abstractmethod
|
||||
def max_vram_cache_size(self, value: float) -> float:
|
||||
"""Set the maximum size the VRAM cache can grow to."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -193,15 +210,6 @@ class ModelCacheBase(ABC, Generic[T]):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def exists(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> bool:
|
||||
"""Return true if the model identified by key and submodel_type is in the cache."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cache_size(self) -> int:
|
||||
"""Get the total size of the models currently cached."""
|
||||
|
||||
@@ -1,22 +1,6 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
|
||||
# TODO: Add Stalker's proper name to copyright
|
||||
"""
|
||||
Manage a RAM cache of diffusion/transformer models for fast switching.
|
||||
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
|
||||
grows larger than a preset maximum, then the least recently used
|
||||
model will be cleared and (re)loaded from disk when next needed.
|
||||
|
||||
The cache returns context manager generators designed to load the
|
||||
model into the GPU within the context, and unload outside the
|
||||
context. Use like this:
|
||||
|
||||
cache = ModelCache(max_cache_size=7.5)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
|
||||
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
|
||||
do_something_in_GPU(SD1,SD2)
|
||||
|
||||
|
||||
"""
|
||||
""" """
|
||||
|
||||
import gc
|
||||
import math
|
||||
@@ -40,53 +24,74 @@ from invokeai.backend.model_manager.load.model_util import calc_model_size_by_da
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
# Maximum size of the cache, in gigs
|
||||
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
|
||||
DEFAULT_MAX_CACHE_SIZE = 6.0
|
||||
|
||||
# amount of GPU memory to hold in reserve for use by generations (GB)
|
||||
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
|
||||
|
||||
# actual size of a gig
|
||||
GIG = 1073741824
|
||||
# Size of a GB in bytes.
|
||||
GB = 2**30
|
||||
|
||||
# Size of a MB in bytes.
|
||||
MB = 2**20
|
||||
|
||||
|
||||
class ModelCache(ModelCacheBase[AnyModel]):
|
||||
"""Implementation of ModelCacheBase."""
|
||||
"""A cache for managing models in memory.
|
||||
|
||||
The cache is based on two levels of model storage:
|
||||
- execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
|
||||
- storage_device: The device where models are offloaded when not in active use (typically "cpu").
|
||||
|
||||
The model cache is based on the following assumptions:
|
||||
- storage_device_mem_size > execution_device_mem_size
|
||||
- disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time
|
||||
|
||||
A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
|
||||
the execution_device.
|
||||
|
||||
Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
|
||||
on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
|
||||
policy. The storage_device cache uses a least-recently-used (LRU) offload policy.
|
||||
|
||||
Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
|
||||
policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
|
||||
configuration.
|
||||
|
||||
The cache returns context manager generators designed to load the model into the execution device (often GPU) within
|
||||
the context, and unload outside the context.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
|
||||
do_something_on_gpu(SD1)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
|
||||
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
|
||||
max_cache_size: float,
|
||||
max_vram_cache_size: float,
|
||||
execution_device: torch.device = torch.device("cuda"),
|
||||
storage_device: torch.device = torch.device("cpu"),
|
||||
precision: torch.dtype = torch.float16,
|
||||
sequential_offload: bool = False,
|
||||
lazy_offloading: bool = True,
|
||||
sha_chunksize: int = 16777216,
|
||||
log_memory_usage: bool = False,
|
||||
logger: Optional[Logger] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the model RAM cache.
|
||||
|
||||
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
|
||||
:param max_cache_size: Maximum size of the storage_device cache in GBs.
|
||||
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
|
||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param precision: Precision for loaded models [torch.float16]
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
||||
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
|
||||
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
|
||||
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
|
||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||
behaviour.
|
||||
:param logger: InvokeAILogger to use (otherwise creates one)
|
||||
"""
|
||||
# allow lazy offloading only when vram cache enabled
|
||||
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||
self._precision: torch.dtype = precision
|
||||
self._max_cache_size: float = max_cache_size
|
||||
self._max_vram_cache_size: float = max_vram_cache_size
|
||||
self._execution_device: torch.device = execution_device
|
||||
@@ -128,6 +133,16 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
"""Set the cap on cache size."""
|
||||
self._max_cache_size = value
|
||||
|
||||
@property
|
||||
def max_vram_cache_size(self) -> float:
|
||||
"""Return the cap on vram cache size."""
|
||||
return self._max_vram_cache_size
|
||||
|
||||
@max_vram_cache_size.setter
|
||||
def max_vram_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on vram cache size."""
|
||||
self._max_vram_cache_size = value
|
||||
|
||||
@property
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
"""Return collected CacheStats object."""
|
||||
@@ -145,15 +160,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
total += cache_record.size
|
||||
return total
|
||||
|
||||
def exists(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> bool:
|
||||
"""Return true if the model identified by key and submodel_type is in the cache."""
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
return key in self._cached_models
|
||||
|
||||
def put(
|
||||
self,
|
||||
key: str,
|
||||
@@ -203,7 +209,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
# more stats
|
||||
if self.stats:
|
||||
stats_name = stats_name or key
|
||||
self.stats.cache_size = int(self._max_cache_size * GIG)
|
||||
self.stats.cache_size = int(self._max_cache_size * GB)
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[stats_name] = max(
|
||||
@@ -231,10 +237,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
return model_key
|
||||
|
||||
def offload_unlocked_models(self, size_required: int) -> None:
|
||||
"""Move any unused models from VRAM."""
|
||||
reserved = self._max_vram_cache_size * GIG
|
||||
"""Offload models from the execution_device to make room for size_required.
|
||||
|
||||
:param size_required: The amount of space to clear in the execution_device cache, in bytes.
|
||||
"""
|
||||
reserved = self._max_vram_cache_size * GB
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
|
||||
self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
|
||||
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
||||
if vram_in_use <= reserved:
|
||||
break
|
||||
@@ -245,7 +254,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
cache_entry.loaded = False
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
|
||||
)
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
@@ -303,7 +312,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
self.logger.debug(
|
||||
f"Moved model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
|
||||
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
|
||||
f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
@@ -326,14 +335,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
f"Moving model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} caused an unexpected change in VRAM usage. The model's"
|
||||
" estimated size may be incorrect. Estimated model size:"
|
||||
f" {(cache_entry.size/GIG):.3f} GB.\n"
|
||||
f" {(cache_entry.size/GB):.3f} GB.\n"
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
def print_cuda_stats(self) -> None:
|
||||
"""Log CUDA diagnostics."""
|
||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
|
||||
ram = "%4.2fG" % (self.cache_size() / GIG)
|
||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
|
||||
ram = "%4.2fG" % (self.cache_size() / GB)
|
||||
|
||||
in_ram_models = 0
|
||||
in_vram_models = 0
|
||||
@@ -353,17 +362,20 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
)
|
||||
|
||||
def make_room(self, size: int) -> None:
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size."""
|
||||
# calculate how much memory this model will require
|
||||
# multiplier = 2 if self.precision==torch.float32 else 1
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size.
|
||||
|
||||
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
|
||||
external references to the model, there's nothing that the cache can do about it, and those models will not be
|
||||
garbage-collected.
|
||||
"""
|
||||
bytes_needed = size
|
||||
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
|
||||
maximum_size = self.max_cache_size * GB # stored in GB, convert to bytes
|
||||
current_size = self.cache_size()
|
||||
|
||||
if current_size + bytes_needed > maximum_size:
|
||||
self.logger.debug(
|
||||
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
|
||||
f" {(bytes_needed/GIG):.2f} GB"
|
||||
f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
|
||||
f" {(bytes_needed/GB):.2f} GB"
|
||||
)
|
||||
|
||||
self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
|
||||
@@ -380,7 +392,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
|
||||
if not cache_entry.locked:
|
||||
self.logger.debug(
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
|
||||
)
|
||||
current_size -= cache_entry.size
|
||||
models_cleared += 1
|
||||
|
||||
234
invokeai/backend/model_manager/load/model_loaders/flux.py
Normal file
@@ -0,0 +1,234 @@
|
||||
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
|
||||
"""Class for Flux model loading in InvokeAI."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.flux.util import ae_params, params
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
CLIPEmbedDiffusersConfig,
|
||||
MainBnbQuantized4bCheckpointConfig,
|
||||
MainCheckpointConfig,
|
||||
T5EncoderBnbQuantizedLlmInt8bConfig,
|
||||
T5EncoderConfig,
|
||||
VAECheckpointConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
try:
|
||||
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
|
||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||
|
||||
bnb_available = True
|
||||
except ImportError:
|
||||
bnb_available = False
|
||||
|
||||
app_config = get_config()
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||
class FluxVAELoader(ModelLoader):
|
||||
"""Class to load VAE models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, VAECheckpointConfig):
|
||||
raise ValueError("Only VAECheckpointConfig models are currently supported here.")
|
||||
model_path = Path(config.path)
|
||||
|
||||
with SilenceWarnings():
|
||||
model = AutoEncoder(ae_params[config.config_path])
|
||||
sd = load_file(model_path)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
model.to(dtype=self._torch_dtype)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
|
||||
class ClipCheckpointModel(ModelLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, CLIPEmbedDiffusersConfig):
|
||||
raise ValueError("Only CLIPEmbedDiffusersConfig models are currently supported here.")
|
||||
|
||||
match submodel_type:
|
||||
case SubModelType.Tokenizer:
|
||||
return CLIPTokenizer.from_pretrained(Path(config.path) / "tokenizer")
|
||||
case SubModelType.TextEncoder:
|
||||
return CLIPTextModel.from_pretrained(Path(config.path) / "text_encoder")
|
||||
|
||||
raise ValueError(
|
||||
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
|
||||
)
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.BnbQuantizedLlmInt8b)
|
||||
class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, T5EncoderBnbQuantizedLlmInt8bConfig):
|
||||
raise ValueError("Only T5EncoderBnbQuantizedLlmInt8bConfig models are currently supported here.")
|
||||
if not bnb_available:
|
||||
raise ImportError(
|
||||
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
|
||||
)
|
||||
match submodel_type:
|
||||
case SubModelType.Tokenizer2:
|
||||
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
|
||||
case SubModelType.TextEncoder2:
|
||||
te2_model_path = Path(config.path) / "text_encoder_2"
|
||||
model_config = AutoConfig.from_pretrained(te2_model_path)
|
||||
with accelerate.init_empty_weights():
|
||||
model = AutoModelForTextEncoding.from_config(model_config)
|
||||
model = quantize_model_llm_int8(model, modules_to_not_convert=set())
|
||||
|
||||
state_dict_path = te2_model_path / "bnb_llm_int8_model.safetensors"
|
||||
state_dict = load_file(state_dict_path)
|
||||
self._load_state_dict_into_t5(model, state_dict)
|
||||
|
||||
return model
|
||||
|
||||
raise ValueError(
|
||||
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _load_state_dict_into_t5(cls, model: T5EncoderModel, state_dict: dict[str, torch.Tensor]):
|
||||
# There is a shared reference to a single weight tensor in the model.
|
||||
# Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should
|
||||
# be present in the state_dict.
|
||||
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True)
|
||||
assert len(unexpected_keys) == 0
|
||||
assert set(missing_keys) == {"encoder.embed_tokens.weight"}
|
||||
# Assert that the layers we expect to be shared are actually shared.
|
||||
assert model.encoder.embed_tokens.weight is model.shared.weight
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
|
||||
class T5EncoderCheckpointModel(ModelLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, T5EncoderConfig):
|
||||
raise ValueError("Only T5EncoderConfig models are currently supported here.")
|
||||
|
||||
match submodel_type:
|
||||
case SubModelType.Tokenizer2:
|
||||
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
|
||||
case SubModelType.TextEncoder2:
|
||||
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2")
|
||||
|
||||
raise ValueError(
|
||||
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
|
||||
)
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
class FluxCheckpointModel(ModelLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
|
||||
|
||||
match submodel_type:
|
||||
case SubModelType.Transformer:
|
||||
return self._load_from_singlefile(config)
|
||||
|
||||
raise ValueError(
|
||||
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
|
||||
)
|
||||
|
||||
def _load_from_singlefile(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
) -> AnyModel:
|
||||
assert isinstance(config, MainCheckpointConfig)
|
||||
model_path = Path(config.path)
|
||||
|
||||
with SilenceWarnings():
|
||||
model = Flux(params[config.config_path])
|
||||
sd = load_file(model_path)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
return model
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
|
||||
class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
|
||||
|
||||
match submodel_type:
|
||||
case SubModelType.Transformer:
|
||||
return self._load_from_singlefile(config)
|
||||
|
||||
raise ValueError(
|
||||
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
|
||||
)
|
||||
|
||||
def _load_from_singlefile(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
) -> AnyModel:
|
||||
assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
|
||||
if not bnb_available:
|
||||
raise ImportError(
|
||||
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
|
||||
)
|
||||
model_path = Path(config.path)
|
||||
|
||||
with SilenceWarnings():
|
||||
with accelerate.init_empty_weights():
|
||||
model = Flux(params[config.config_path])
|
||||
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
|
||||
sd = load_file(model_path)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
return model
|
||||
@@ -78,7 +78,12 @@ class GenericDiffusersLoader(ModelLoader):
|
||||
|
||||
# TO DO: Add exception handling
|
||||
def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type
|
||||
if module in ["diffusers", "transformers"]:
|
||||
if module in [
|
||||
"diffusers",
|
||||
"transformers",
|
||||
"invokeai.backend.quantization.fast_quantized_transformers_model",
|
||||
"invokeai.backend.quantization.fast_quantized_diffusion_model",
|
||||
]:
|
||||
res_type = sys.modules[module]
|
||||
else:
|
||||
res_type = sys.modules["diffusers"].pipelines
|
||||
|
||||
@@ -5,8 +5,10 @@ from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
@@ -18,6 +20,11 @@ from invokeai.backend.model_manager import (
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.peft.conversions.flux_kohya_lora_conversion_utils import (
|
||||
lora_model_from_flux_kohya_state_dict,
|
||||
)
|
||||
from invokeai.backend.peft.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
|
||||
from invokeai.backend.peft.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
|
||||
@@ -45,14 +52,28 @@ class LoRALoader(ModelLoader):
|
||||
raise ValueError("There are no submodels in a LoRA model.")
|
||||
model_path = Path(config.path)
|
||||
assert self._model_base is not None
|
||||
model = LoRAModelRaw.from_checkpoint(
|
||||
file_path=model_path,
|
||||
dtype=self._torch_dtype,
|
||||
base_model=self._model_base,
|
||||
)
|
||||
|
||||
# Load the state dict from the model file.
|
||||
if model_path.suffix == ".safetensors":
|
||||
state_dict = load_file(model_path.absolute().as_posix(), device="cpu")
|
||||
else:
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
|
||||
# Apply state_dict key conversions, if necessary.
|
||||
if self._model_base == BaseModelType.StableDiffusionXL:
|
||||
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
model = lora_model_from_sd_state_dict(state_dict=state_dict)
|
||||
elif self._model_base == BaseModelType.Flux:
|
||||
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
|
||||
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
|
||||
# Currently, we don't apply any conversions for SD1 and SD2 LoRA models.
|
||||
model = lora_model_from_sd_state_dict(state_dict=state_dict)
|
||||
else:
|
||||
raise ValueError(f"Unsupported LoRA base model: {self._model_base}")
|
||||
|
||||
model.to(dtype=self._torch_dtype)
|
||||
return model
|
||||
|
||||
# override
|
||||
def _get_model_path(self, config: AnyModelConfig) -> Path:
|
||||
# cheating a little - we remember this variable for using in the subsequent call to _load_model()
|
||||
self._model_base = config.base
|
||||
|
||||
@@ -36,8 +36,18 @@ VARIANT_TO_IN_CHANNEL_MAP = {
|
||||
}
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(
|
||||
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
|
||||
)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(
|
||||
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint
|
||||
)
|
||||
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
|
||||
@@ -9,14 +9,15 @@ from typing import Optional
|
||||
import torch
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from transformers import CLIPTokenizer
|
||||
from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast
|
||||
|
||||
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
||||
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.peft.lora import LoRAModelRaw
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
|
||||
@@ -45,9 +46,21 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
|
||||
SpandrelImageToImageModel,
|
||||
GroundingDinoPipeline,
|
||||
SegmentAnythingPipeline,
|
||||
DepthAnythingPipeline,
|
||||
),
|
||||
):
|
||||
return model.calc_size()
|
||||
elif isinstance(
|
||||
model,
|
||||
(
|
||||
T5TokenizerFast,
|
||||
T5Tokenizer,
|
||||
),
|
||||
):
|
||||
# HACK(ryand): len(model) just returns the vocabulary size, so this is blatantly wrong. It should be small
|
||||
# relative to the text encoder that it's used with, so shouldn't matter too much, but we should fix this at some
|
||||
# point.
|
||||
return len(model)
|
||||
else:
|
||||
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
|
||||
# supported model types.
|
||||
|
||||
@@ -26,6 +26,7 @@ from invokeai.backend.model_manager.config import (
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
|
||||
from invokeai.backend.peft.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
@@ -95,6 +96,7 @@ class ModelProbe(object):
|
||||
}
|
||||
|
||||
CLASS2TYPE = {
|
||||
"FluxPipeline": ModelType.Main,
|
||||
"StableDiffusionPipeline": ModelType.Main,
|
||||
"StableDiffusionInpaintPipeline": ModelType.Main,
|
||||
"StableDiffusionXLPipeline": ModelType.Main,
|
||||
@@ -106,6 +108,7 @@ class ModelProbe(object):
|
||||
"ControlNetModel": ModelType.ControlNet,
|
||||
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
|
||||
"T2IAdapter": ModelType.T2IAdapter,
|
||||
"CLIPModel": ModelType.CLIPEmbed,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -161,7 +164,7 @@ class ModelProbe(object):
|
||||
fields["description"] = (
|
||||
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
|
||||
)
|
||||
fields["format"] = fields.get("format") or probe.get_format()
|
||||
fields["format"] = ModelFormat(fields.get("format")) if "format" in fields else probe.get_format()
|
||||
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
|
||||
|
||||
fields["default_settings"] = fields.get("default_settings")
|
||||
@@ -176,10 +179,10 @@ class ModelProbe(object):
|
||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||
|
||||
# additional fields needed for main and controlnet models
|
||||
if (
|
||||
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
|
||||
and fields["format"] is ModelFormat.Checkpoint
|
||||
):
|
||||
if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [
|
||||
ModelFormat.Checkpoint,
|
||||
ModelFormat.BnbQuantizednf4b,
|
||||
]:
|
||||
ckpt_config_path = cls._get_checkpoint_config_path(
|
||||
model_path,
|
||||
model_type=fields["type"],
|
||||
@@ -222,7 +225,8 @@ class ModelProbe(object):
|
||||
ckpt = ckpt.get("state_dict", ckpt)
|
||||
|
||||
for key in [str(k) for k in ckpt.keys()]:
|
||||
if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")):
|
||||
if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.", "double_blocks.")):
|
||||
# Keys starting with double_blocks are associated with Flux models
|
||||
return ModelType.Main
|
||||
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
|
||||
return ModelType.VAE
|
||||
@@ -321,10 +325,27 @@ class ModelProbe(object):
|
||||
return possible_conf.absolute()
|
||||
|
||||
if model_type is ModelType.Main:
|
||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
config_file = config_file[prediction_type]
|
||||
config_file = f"stable-diffusion/{config_file}"
|
||||
if base_type == BaseModelType.Flux:
|
||||
# TODO: Decide between dev/schnell
|
||||
checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
|
||||
state_dict = checkpoint.get("state_dict") or checkpoint
|
||||
if "guidance_in.out_layer.weight" in state_dict:
|
||||
# For flux, this is a key in invokeai.backend.flux.util.params
|
||||
# Due to model type and format being the descriminator for model configs this
|
||||
# is used rather than attempting to support flux with separate model types and format
|
||||
# If changed in the future, please fix me
|
||||
config_file = "flux-dev"
|
||||
else:
|
||||
# For flux, this is a key in invokeai.backend.flux.util.params
|
||||
# Due to model type and format being the descriminator for model configs this
|
||||
# is used rather than attempting to support flux with separate model types and format
|
||||
# If changed in the future, please fix me
|
||||
config_file = "flux-schnell"
|
||||
else:
|
||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
config_file = config_file[prediction_type]
|
||||
config_file = f"stable-diffusion/{config_file}"
|
||||
elif model_type is ModelType.ControlNet:
|
||||
config_file = (
|
||||
"controlnet/cldm_v15.yaml"
|
||||
@@ -333,7 +354,13 @@ class ModelProbe(object):
|
||||
)
|
||||
elif model_type is ModelType.VAE:
|
||||
config_file = (
|
||||
"stable-diffusion/v1-inference.yaml"
|
||||
# For flux, this is a key in invokeai.backend.flux.util.ae_params
|
||||
# Due to model type and format being the descriminator for model configs this
|
||||
# is used rather than attempting to support flux with separate model types and format
|
||||
# If changed in the future, please fix me
|
||||
"flux"
|
||||
if base_type is BaseModelType.Flux
|
||||
else "stable-diffusion/v1-inference.yaml"
|
||||
if base_type is BaseModelType.StableDiffusion1
|
||||
else "stable-diffusion/sd_xl_base.yaml"
|
||||
if base_type is BaseModelType.StableDiffusionXL
|
||||
@@ -416,11 +443,15 @@ class CheckpointProbeBase(ProbeBase):
|
||||
self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
|
||||
|
||||
def get_format(self) -> ModelFormat:
|
||||
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
||||
if "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict:
|
||||
return ModelFormat.BnbQuantizednf4b
|
||||
return ModelFormat("checkpoint")
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
|
||||
if model_type != ModelType.Main:
|
||||
base_type = self.get_base_type()
|
||||
if model_type != ModelType.Main or base_type == BaseModelType.Flux:
|
||||
return ModelVariantType.Normal
|
||||
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
||||
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
@@ -440,6 +471,8 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
state_dict = self.checkpoint.get("state_dict") or checkpoint
|
||||
if "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
|
||||
return BaseModelType.Flux
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
@@ -482,6 +515,7 @@ class VaeCheckpointProbe(CheckpointProbeBase):
|
||||
(r"xl", BaseModelType.StableDiffusionXL),
|
||||
(r"sd2", BaseModelType.StableDiffusion2),
|
||||
(r"vae", BaseModelType.StableDiffusion1),
|
||||
(r"FLUX.1-schnell_ae", BaseModelType.Flux),
|
||||
]:
|
||||
if re.search(regexp, self.model_path.name, re.IGNORECASE):
|
||||
return basetype
|
||||
@@ -495,9 +529,11 @@ class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
return ModelFormat("lycoris")
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
token_vector_length = lora_token_vector_length(checkpoint)
|
||||
if is_state_dict_likely_in_flux_kohya_format(self.checkpoint):
|
||||
return BaseModelType.Flux
|
||||
|
||||
# If we've gotten here, we assume that the model is a Stable Diffusion model.
|
||||
token_vector_length = lora_token_vector_length(self.checkpoint)
|
||||
if token_vector_length == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif token_vector_length == 1024:
|
||||
@@ -713,6 +749,11 @@ class TextualInversionFolderProbe(FolderProbeBase):
|
||||
return TextualInversionCheckpointProbe(path).get_base_type()
|
||||
|
||||
|
||||
class T5EncoderFolderProbe(FolderProbeBase):
|
||||
def get_format(self) -> ModelFormat:
|
||||
return ModelFormat.T5Encoder
|
||||
|
||||
|
||||
class ONNXFolderProbe(PipelineFolderProbe):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
# Due to the way the installer is set up, the configuration file for safetensors
|
||||
@@ -805,6 +846,11 @@ class CLIPVisionFolderProbe(FolderProbeBase):
|
||||
return BaseModelType.Any
|
||||
|
||||
|
||||
class CLIPEmbedFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
return BaseModelType.Any
|
||||
|
||||
|
||||
class SpandrelImageToImageFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
@@ -835,8 +881,10 @@ ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.T5Encoder, T5EncoderFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.CLIPEmbed, CLIPEmbedFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
|
||||
|
||||
|
||||
class StarterModelWithoutDependencies(BaseModel):
|
||||
@@ -11,6 +11,7 @@ class StarterModelWithoutDependencies(BaseModel):
|
||||
name: str
|
||||
base: BaseModelType
|
||||
type: ModelType
|
||||
format: Optional[ModelFormat] = None
|
||||
is_installed: bool = False
|
||||
|
||||
|
||||
@@ -51,10 +52,76 @@ cyberrealistic_negative = StarterModel(
|
||||
type=ModelType.TextualInversion,
|
||||
)
|
||||
|
||||
t5_base_encoder = StarterModel(
|
||||
name="t5_base_encoder",
|
||||
base=BaseModelType.Any,
|
||||
source="InvokeAI/t5-v1_1-xxl::bfloat16",
|
||||
description="T5-XXL text encoder (used in FLUX pipelines). ~8GB",
|
||||
type=ModelType.T5Encoder,
|
||||
)
|
||||
|
||||
t5_8b_quantized_encoder = StarterModel(
|
||||
name="t5_bnb_int8_quantized_encoder",
|
||||
base=BaseModelType.Any,
|
||||
source="InvokeAI/t5-v1_1-xxl::bnb_llm_int8",
|
||||
description="T5-XXL text encoder with bitsandbytes LLM.int8() quantization (used in FLUX pipelines). ~5GB",
|
||||
type=ModelType.T5Encoder,
|
||||
format=ModelFormat.BnbQuantizedLlmInt8b,
|
||||
)
|
||||
|
||||
clip_l_encoder = StarterModel(
|
||||
name="clip-vit-large-patch14",
|
||||
base=BaseModelType.Any,
|
||||
source="InvokeAI/clip-vit-large-patch14-text-encoder::bfloat16",
|
||||
description="CLIP-L text encoder (used in FLUX pipelines). ~250MB",
|
||||
type=ModelType.CLIPEmbed,
|
||||
)
|
||||
|
||||
flux_vae = StarterModel(
|
||||
name="FLUX.1-schnell_ae",
|
||||
base=BaseModelType.Flux,
|
||||
source="black-forest-labs/FLUX.1-schnell::ae.safetensors",
|
||||
description="FLUX VAE compatible with both schnell and dev variants.",
|
||||
type=ModelType.VAE,
|
||||
)
|
||||
|
||||
|
||||
# List of starter models, displayed on the frontend.
|
||||
# The order/sort of this list is not changed by the frontend - set it how you want it here.
|
||||
STARTER_MODELS: list[StarterModel] = [
|
||||
# region: Main
|
||||
StarterModel(
|
||||
name="FLUX Schnell (Quantized)",
|
||||
base=BaseModelType.Flux,
|
||||
source="InvokeAI/flux_schnell::transformer/bnb_nf4/flux1-schnell-bnb_nf4.safetensors",
|
||||
description="FLUX schnell transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
|
||||
),
|
||||
StarterModel(
|
||||
name="FLUX Dev (Quantized)",
|
||||
base=BaseModelType.Flux,
|
||||
source="InvokeAI/flux_dev::transformer/bnb_nf4/flux1-dev-bnb_nf4.safetensors",
|
||||
description="FLUX dev transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
|
||||
),
|
||||
StarterModel(
|
||||
name="FLUX Schnell",
|
||||
base=BaseModelType.Flux,
|
||||
source="InvokeAI/flux_schnell::transformer/base/flux1-schnell.safetensors",
|
||||
description="FLUX schnell transformer in bfloat16. Total size with dependencies: ~33GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
|
||||
),
|
||||
StarterModel(
|
||||
name="FLUX Dev",
|
||||
base=BaseModelType.Flux,
|
||||
source="InvokeAI/flux_dev::transformer/base/flux1-dev.safetensors",
|
||||
description="FLUX dev transformer in bfloat16. Total size with dependencies: ~33GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
|
||||
),
|
||||
StarterModel(
|
||||
name="CyberRealistic v4.1",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
@@ -125,6 +192,7 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
# endregion
|
||||
# region VAE
|
||||
sdxl_fp16_vae_fix,
|
||||
flux_vae,
|
||||
# endregion
|
||||
# region LoRA
|
||||
StarterModel(
|
||||
@@ -450,6 +518,11 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
type=ModelType.SpandrelImageToImage,
|
||||
),
|
||||
# endregion
|
||||
# region TextEncoders
|
||||
t5_base_encoder,
|
||||
t5_8b_quantized_encoder,
|
||||
clip_l_encoder,
|
||||
# endregion
|
||||
]
|
||||
|
||||
assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models"
|
||||
|
||||
@@ -54,6 +54,7 @@ def filter_files(
|
||||
"lora_weights.safetensors",
|
||||
"weights.pb",
|
||||
"onnx_data",
|
||||
"spiece.model", # Added for `black-forest-labs/FLUX.1-schnell`.
|
||||
)
|
||||
):
|
||||
paths.append(file)
|
||||
@@ -62,13 +63,13 @@ def filter_files(
|
||||
# downloading random checkpoints that might also be in the repo. However there is no guarantee
|
||||
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
|
||||
# will adhere to this naming convention, so this is an area to be careful of.
|
||||
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
|
||||
elif re.search(r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
|
||||
paths.append(file)
|
||||
|
||||
# limit search to subfolder if requested
|
||||
if subfolder:
|
||||
subfolder = root / subfolder
|
||||
paths = [x for x in paths if x.parent == Path(subfolder)]
|
||||
paths = [x for x in paths if Path(subfolder) in x.parents]
|
||||
|
||||
# _filter_by_variant uniquifies the paths and returns a set
|
||||
return sorted(_filter_by_variant(paths, variant))
|
||||
@@ -97,7 +98,9 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
if variant == ModelRepoVariant.Flax:
|
||||
result.add(path)
|
||||
|
||||
elif path.suffix in [".json", ".txt"]:
|
||||
# Note: '.model' was added to support:
|
||||
# https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/768d12a373ed5cc9ef9a9dea7504dc09fcc14842/tokenizer_2/spiece.model
|
||||
elif path.suffix in [".json", ".txt", ".model"]:
|
||||
result.add(path)
|
||||
|
||||
elif variant in [
|
||||
@@ -140,6 +143,23 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
continue
|
||||
|
||||
for candidate_list in subfolder_weights.values():
|
||||
# Check if at least one of the files has the explicit fp16 variant.
|
||||
at_least_one_fp16 = False
|
||||
for candidate in candidate_list:
|
||||
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16":
|
||||
at_least_one_fp16 = True
|
||||
break
|
||||
|
||||
if not at_least_one_fp16:
|
||||
# If none of the candidates in this candidate_list have the explicit fp16 variant label, then this
|
||||
# candidate_list probably doesn't adhere to the variant naming convention that we expected. In this case,
|
||||
# we'll simply keep all the candidates. An example of a model that hits this case is
|
||||
# `black-forest-labs/FLUX.1-schnell` (as of commit 012d2fd).
|
||||
for candidate in candidate_list:
|
||||
result.add(candidate.path)
|
||||
|
||||
# The candidate_list seems to have the expected variant naming convention. We'll select the highest scoring
|
||||
# candidate.
|
||||
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
|
||||
if highest_score_candidate:
|
||||
result.add(highest_score_candidate.path)
|
||||
|
||||
@@ -13,10 +13,10 @@ from diffusers import OnnxRuntimeModel, UNet2DConditionModel
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import AnyModel
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.peft.lora import LoRAModelRaw
|
||||
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
|
||||
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
|
||||
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
||||
|
||||
0
invokeai/backend/peft/__init__.py
Normal file
0
invokeai/backend/peft/conversions/__init__.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import re
|
||||
from typing import Any, Dict, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.peft.layers.any_lora_layer import AnyLoRALayer
|
||||
from invokeai.backend.peft.layers.utils import peft_layer_from_state_dict
|
||||
from invokeai.backend.peft.lora import LoRAModelRaw
|
||||
|
||||
# A regex pattern that matches all of the keys in the Kohya FLUX LoRA format.
|
||||
# Example keys:
|
||||
# lora_unet_double_blocks_0_img_attn_proj.alpha
|
||||
# lora_unet_double_blocks_0_img_attn_proj.lora_down.weight
|
||||
# lora_unet_double_blocks_0_img_attn_proj.lora_up.weight
|
||||
FLUX_KOHYA_KEY_REGEX = (
|
||||
r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
|
||||
)
|
||||
|
||||
|
||||
def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
|
||||
"""Checks if the provided state dict is likely in the Kohya FLUX LoRA format.
|
||||
|
||||
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
|
||||
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
|
||||
"""
|
||||
for k in state_dict.keys():
|
||||
if not re.match(FLUX_KOHYA_KEY_REGEX, k):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
|
||||
# Group keys by layer.
|
||||
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
|
||||
for key, value in state_dict.items():
|
||||
layer_name, param_name = key.split(".", 1)
|
||||
if layer_name not in grouped_state_dict:
|
||||
grouped_state_dict[layer_name] = {}
|
||||
grouped_state_dict[layer_name][param_name] = value
|
||||
|
||||
# Convert the state dict to the InvokeAI format.
|
||||
grouped_state_dict = convert_flux_kohya_state_dict_to_invoke_format(grouped_state_dict)
|
||||
|
||||
# Create LoRA layers.
|
||||
layers: dict[str, AnyLoRALayer] = {}
|
||||
for layer_key, layer_state_dict in grouped_state_dict.items():
|
||||
layer = peft_layer_from_state_dict(layer_key, layer_state_dict)
|
||||
layers[layer_key] = layer
|
||||
|
||||
# Create and return the LoRAModelRaw.
|
||||
return LoRAModelRaw(layers=layers)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
|
||||
"""Converts a state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by InvokeAI.
|
||||
|
||||
Example key conversions:
|
||||
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
|
||||
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
|
||||
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
|
||||
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img_attn.qkv"
|
||||
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
|
||||
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
|
||||
"""
|
||||
|
||||
def replace_func(match: re.Match[str]) -> str:
|
||||
s = f"{match.group(1)}.{match.group(2)}.{match.group(3)}"
|
||||
if match.group(4):
|
||||
s += f".{match.group(4)}"
|
||||
return s
|
||||
|
||||
converted_dict: dict[str, T] = {}
|
||||
for k, v in state_dict.items():
|
||||
match = re.match(FLUX_KOHYA_KEY_REGEX, k)
|
||||
if match:
|
||||
new_key = re.sub(FLUX_KOHYA_KEY_REGEX, replace_func, k)
|
||||
converted_dict[new_key] = v
|
||||
else:
|
||||
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
|
||||
|
||||
return converted_dict
|
||||
@@ -0,0 +1,30 @@
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.peft.layers.any_lora_layer import AnyLoRALayer
|
||||
from invokeai.backend.peft.layers.utils import peft_layer_from_state_dict
|
||||
from invokeai.backend.peft.lora import LoRAModelRaw
|
||||
|
||||
|
||||
def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
|
||||
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_state(state_dict)
|
||||
|
||||
layers: dict[str, AnyLoRALayer] = {}
|
||||
for layer_key, values in grouped_state_dict.items():
|
||||
layer = peft_layer_from_state_dict(layer_key, values)
|
||||
layers[layer_key] = layer
|
||||
|
||||
return LoRAModelRaw(layers=layers)
|
||||
|
||||
|
||||
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
|
||||
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
|
||||
|
||||
for key, value in state_dict.items():
|
||||
stem, leaf = key.split(".", 1)
|
||||
if stem not in state_dict_groupped:
|
||||
state_dict_groupped[stem] = {}
|
||||
state_dict_groupped[stem][leaf] = value
|
||||
|
||||
return state_dict_groupped
|
||||
154
invokeai/backend/peft/conversions/sdxl_lora_conversion_utils.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import bisect
|
||||
from typing import Dict, List, Tuple, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, T]) -> dict[str, T]:
|
||||
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
|
||||
|
||||
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
|
||||
diffusers format, then this function will have no effect.
|
||||
|
||||
This function is adapted from:
|
||||
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
|
||||
|
||||
Args:
|
||||
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
|
||||
|
||||
Raises:
|
||||
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
|
||||
|
||||
Returns:
|
||||
Dict[str, Tensor]: The diffusers-format state_dict.
|
||||
"""
|
||||
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
|
||||
not_converted_count = 0 # The number of keys that were not converted.
|
||||
|
||||
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
|
||||
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
|
||||
# `input_blocks_4_1_proj_in`.
|
||||
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
|
||||
stability_unet_keys.sort()
|
||||
|
||||
new_state_dict: dict[str, T] = {}
|
||||
for full_key, value in state_dict.items():
|
||||
if full_key.startswith("lora_unet_"):
|
||||
search_key = full_key.replace("lora_unet_", "")
|
||||
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
|
||||
position = bisect.bisect_right(stability_unet_keys, search_key)
|
||||
map_key = stability_unet_keys[position - 1]
|
||||
# Now, check if the map_key *actually* matches the search_key.
|
||||
if search_key.startswith(map_key):
|
||||
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
|
||||
new_state_dict[new_key] = value
|
||||
converted_count += 1
|
||||
else:
|
||||
new_state_dict[full_key] = value
|
||||
not_converted_count += 1
|
||||
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
|
||||
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
|
||||
new_state_dict[full_key] = value
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
|
||||
|
||||
if converted_count > 0 and not_converted_count > 0:
|
||||
raise ValueError(
|
||||
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
|
||||
f" not_converted={not_converted_count}"
|
||||
)
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
# code from
|
||||
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
|
||||
def _make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
|
||||
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
|
||||
unet_conversion_map_layer: list[tuple[str, str]] = []
|
||||
|
||||
for i in range(3): # num_blocks is 3 in sdxl
|
||||
# loop over downblocks/upblocks
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no attention layers in down_blocks.3
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(3):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
# if i > 0: commentout for sdxl
|
||||
# no attention layers in up_blocks.0
|
||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
sd_mid_atn_prefix = "middle_block.1."
|
||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
unet_conversion_map_resnet = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("in_layers.0.", "norm1."),
|
||||
("in_layers.2.", "conv1."),
|
||||
("out_layers.0.", "norm2."),
|
||||
("out_layers.3.", "conv2."),
|
||||
("emb_layers.1.", "time_emb_proj."),
|
||||
("skip_connection.", "conv_shortcut."),
|
||||
]
|
||||
|
||||
unet_conversion_map: list[tuple[str, str]] = []
|
||||
for sd, hf in unet_conversion_map_layer:
|
||||
if "resnets" in hf:
|
||||
for sd_res, hf_res in unet_conversion_map_resnet:
|
||||
unet_conversion_map.append((sd + sd_res, hf + hf_res))
|
||||
else:
|
||||
unet_conversion_map.append((sd, hf))
|
||||
|
||||
for j in range(2):
|
||||
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
|
||||
sd_time_embed_prefix = f"time_embed.{j*2}."
|
||||
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
|
||||
sd_label_embed_prefix = f"label_emb.0.{j*2}."
|
||||
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
|
||||
|
||||
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
|
||||
unet_conversion_map.append(("out.0.", "conv_norm_out."))
|
||||
unet_conversion_map.append(("out.2.", "conv_out."))
|
||||
|
||||
return unet_conversion_map
|
||||
|
||||
|
||||
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
|
||||
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in _make_sdxl_unet_conversion_map()
|
||||
}
|
||||
0
invokeai/backend/peft/layers/__init__.py
Normal file
10
invokeai/backend/peft/layers/any_lora_layer.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from typing import Union
|
||||
|
||||
from invokeai.backend.peft.layers.full_layer import FullLayer
|
||||
from invokeai.backend.peft.layers.ia3_layer import IA3Layer
|
||||
from invokeai.backend.peft.layers.loha_layer import LoHALayer
|
||||
from invokeai.backend.peft.layers.lokr_layer import LoKRLayer
|
||||
from invokeai.backend.peft.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.peft.layers.norm_layer import NormLayer
|
||||
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
|
||||
37
invokeai/backend/peft/layers/full_layer.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.peft.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class FullLayer(LoRALayerBase):
|
||||
# bias handled in LoRALayerBase(calc_size, to)
|
||||
# weight: torch.Tensor
|
||||
# bias: Optional[torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["diff"]
|
||||
self.bias = values.get("diff_b", None)
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"diff", "diff_b"})
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
return self.weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
42
invokeai/backend/peft/layers/ia3_layer.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.peft.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class IA3Layer(LoRALayerBase):
|
||||
# weight: torch.Tensor
|
||||
# on_input: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["weight"]
|
||||
self.on_input = values["on_input"]
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"weight", "on_input"})
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
weight = self.weight
|
||||
if not self.on_input:
|
||||
weight = weight.reshape(-1, 1)
|
||||
assert orig_weight is not None
|
||||
return orig_weight * weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
model_size += self.on_input.nelement() * self.on_input.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
||||