Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-04-23 03:00:31 -04:00)

Compare commits: `ryan/lora-...psyche/fix` (168 commits)
Commits in this comparison (SHA1 only; the author and date columns did not survive extraction):

543c152e10, 868e06eb8b, 40e4dbe1fb, 4815b4ea80, d77a6ccd76, 3e860c8338, 4f2ef7ce76, d7e9ad52f9,
b6d7a44004, e18100ae7e, ad0aa0e6b2, 157b92e0fd, fd838ad9d4, 5e9227c052, 94785231ce, b46d7abfb0,
9a0a226ce1, 477d87ec31, 8b4b0ff0cf, 6fd9b0a274, 52fc5a64d4, a8bef59699, 6d49ee839c, 0525f967c2,
2855bb6b41, 20acfc9a00, 918f541af8, 93e76b61d6, f692e217ea, f2981979f9, ef970a1cdc, 5ee7405f97,
e24e386a27, b06d61e3c0, 6bf5b747ce, 7d6ab0ceb2, 9692a36dd6, b0b699a01f, a8b2c4c3d2, 03944191db,
987c9ae076, 6d7314ac0a, 80db9537ff, 6f926f05b0, 61253b91f1, 0148512038, d0f35fceed, cefcb340d9,
0fc538734b, 7214d4969b, a83a999b79, f8a6accf8a, f8ab414f99, c6795a1b47, 0a8fc74ae9, dc54e8763b,
1b56020876, 3f990393a1, 97d56f7dc9, fe0ef2c27c, 65fcbf5f60, d3916dbdb6, 55b13c1da3, 7dc3e0fdbe,
a39bcf7e85, a7c72992a6, d30a9ced38, e0bfa6157b, 83ea6420e2, ce11a1952e, e48dee4c4a, 712674b6dd,
de0043f443, d21506da6f, a49894901a, e7e26c8a93, 9adcd2cc31, f9edd009f5, 91a4160e36, 9c9cec1b43,
948ecf9333, 1038f7bcab, c7d9e2d62a, 11c3a2e15d, 9e3ca383ec, bda83c2634, 525cb38c71, a9a6720bad,
858bf9cf8c, 74a29c3735, 6fc6be3aa0, 174ea021a6, 50b804e087, 23270d7dfe, 39e6f6d53f, c154d833b9,
899a00af62, 7c9ecdb362, 4a5255611b, b5b39db304, 2cb5743cc5, 64ee8d491e, d70d48de45, 3f8636330f,
0c2f96daf1, c9b2cce627, 401fb392b8, 594511cf4a, d764aa4a2a, ea34726329, 9b615e0de7, a463e97269,
b272d46056, 4d5f74c05b, dd09509dbd, 7fad4c9491, b820862eab, c604a0956e, 9369b39a12, 80f64abd1e,
37e3089457, fe09f2d27a, e7e3f7e144, 606d58d7db, c76a448846, 46133b5656, ac28370fd2, 1e0552c813,
e2451ef5ca, 443d838fd0, 3a8a5442ea, 808e3770d3, 2b441d6a2d, 58de93a89e, 1eede4315e, 8ea697d733,
693d42661c, 41664f88db, 42f8d6aa11, 5f41a69665, 7da90a9b6b, 440185cc40, 26edc71268, a4bed7aee3,
5fcd76a712, 516ffa641c, d84adfd39f, ac82f73dbe, 70811d0bd0, e0344a302c, 92b0d89b70, da213e4638,
246b59f148, 046d19446c, 040551d4fb, f53da60b84, 5a035dd19f, f3b253987f, 25ff7918e8, 09fc60acb0,
6f55f2c723, 03b815c884, 9cecdd17eb, 6b0f7ab57c, c805e38da2, 2c1de0f07d, 261d5ab488, ca571cd7a9
85  .github/workflows/typegen-checks.yml (vendored, new file)
@@ -0,0 +1,85 @@
```yaml
# Runs typegen schema quality checks.
# Frontend types should match the server.
#
# Checks for changes to files before running the checks.
# If always_run is true, always runs the checks.

name: 'typegen checks'

on:
  push:
    branches:
      - 'main'
  pull_request:
    types:
      - 'ready_for_review'
      - 'opened'
      - 'synchronize'
  merge_group:
  workflow_dispatch:
    inputs:
      always_run:
        description: 'Always run the checks'
        required: true
        type: boolean
        default: true
  workflow_call:
    inputs:
      always_run:
        description: 'Always run the checks'
        required: true
        type: boolean
        default: true

jobs:
  typegen-checks:
    runs-on: ubuntu-22.04
    timeout-minutes: 15 # expected run time: <5 min
    steps:
      - name: checkout
        uses: actions/checkout@v4

      - name: check for changed files
        if: ${{ inputs.always_run != true }}
        id: changed-files
        uses: tj-actions/changed-files@v42
        with:
          files_yaml: |
            src:
              - 'pyproject.toml'
              - 'invokeai/**'

      - name: setup python
        if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
          cache: pip
          cache-dependency-path: pyproject.toml

      - name: install python dependencies
        if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
        run: pip3 install --use-pep517 --editable="."

      - name: install frontend dependencies
        if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
        uses: ./.github/actions/install-frontend-deps

      - name: copy schema
        if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
        run: cp invokeai/frontend/web/src/services/api/schema.ts invokeai/frontend/web/src/services/api/schema_orig.ts
        shell: bash

      - name: generate schema
        if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
        run: make frontend-typegen
        shell: bash

      - name: compare files
        if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
        run: |
          if ! diff invokeai/frontend/web/src/services/api/schema.ts invokeai/frontend/web/src/services/api/schema_orig.ts; then
            echo "Files are different!";
            exit 1;
          fi
        shell: bash
```
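The same comparison can be reproduced locally before pushing — a minimal sketch, assuming the repo root as working directory with the python and frontend dependencies installed; the `/tmp` path is a local convenience, not what CI uses:

```bash
# Snapshot the committed schema, regenerate it, and fail on any drift.
cp invokeai/frontend/web/src/services/api/schema.ts /tmp/schema_orig.ts
make frontend-typegen
if ! diff invokeai/frontend/web/src/services/api/schema.ts /tmp/schema_orig.ts; then
  echo "Files are different!"
  exit 1
fi
```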
45  README.md

@@ -30,51 +30,12 @@ Invoke is available in two editions:
|----------------------------------------------------------------------------------------------------------------------------|
| [Installation and Updates][installation docs] - [Documentation and Tutorials][docs home] - [Bug Reports][github issues] - [Contributing][contributing docs] |

</div>

# Installation

## Quick Start

To get started with Invoke, [Download the Installer](https://www.invoke.com/downloads).

1. Download and unzip the installer from the bottom of the [latest release][latest release link].
2. Run the installer script.

For detailed step-by-step instructions, or for instructions on manual/docker installations, visit our documentation on [Installation and Updates][installation docs].

    - **Windows**: Double-click on the `install.bat` script.
    - **macOS**: Open a Terminal window, drag the file `install.sh` from Finder into the Terminal, and press enter.
    - **Linux**: Run `install.sh`.

3. When prompted, enter a location for the install and select your GPU type.
4. Once the install finishes, find the directory you selected during install. The default location is `C:\Users\Username\invokeai` for Windows or `~/invokeai` for Linux/macOS.
5. Run the launcher script (`invoke.bat` for Windows, `invoke.sh` for macOS and Linux) the same way you ran the installer script in step 2.
6. Select option 1 to start the application. Once it starts up, open your browser and go to <http://localhost:9090>.
7. Open the model manager tab to install a starter model and then you'll be ready to generate.

More detail, including hardware requirements and manual install instructions, is available in the [installation documentation][installation docs].

## Docker Container

We publish official container images in GitHub Container Registry: https://github.com/invoke-ai/InvokeAI/pkgs/container/invokeai. Both CUDA and ROCm images are available. Check the above link for relevant tags.

> [!IMPORTANT]
> Ensure that Docker is set up to use the GPU. Refer to [NVIDIA][nvidia docker docs] or [AMD][amd docker docs] documentation.

### Generate!

Run the container, modifying the command as necessary:

```bash
docker run --runtime=nvidia --gpus=all --publish 9090:9090 ghcr.io/invoke-ai/invokeai
```

Then open `http://localhost:9090` and install some models using the Model Manager tab to begin generating.

For ROCm, add `--device /dev/kfd --device /dev/dri` to the `docker run` command.

### Persist your data

You will likely want to persist your workspace outside of the container. Use the `--volume /home/myuser/invokeai:/invokeai` flag to mount some local directory (using its **absolute** path) to the `/invokeai` path inside the container. Your generated images and models will reside there. You can use this directory with other InvokeAI installations, or switch between runtime directories as needed.
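For example, a single command combining the flags above (the host path is illustrative; for ROCm, swap in the `--device` flags noted earlier in place of the Nvidia runtime flags):

```bash
docker run --runtime=nvidia --gpus=all \
  --publish 9090:9090 \
  --volume /home/myuser/invokeai:/invokeai \
  ghcr.io/invoke-ai/invokeai
```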
### DIY

Build your own image and customize the environment to match your needs using our `docker-compose` stack. See [README.md](./docker/README.md) in the [docker](./docker) directory.

## Troubleshooting, FAQ and Support
@@ -114,6 +114,10 @@ remote_api_tokens:

The provided token will be added as a `Bearer` token to the network requests to download the model files. As far as we know, this works for all model marketplaces that require authorization.

!!! tip "HuggingFace Models"

    If you get an error when installing a HF model using a URL instead of a repo id, you may need to [set up a HF API token](https://huggingface.co/settings/tokens) and add an entry for it under `remote_api_tokens`. Use `huggingface.co` for `url_regex`.
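For example, a minimal sketch of the corresponding `remote_api_tokens` config entry — the token value is a placeholder:

```yaml
remote_api_tokens:
  # Hypothetical entry - substitute your own HF API token.
  - url_regex: huggingface.co
    token: hf_xxxxxxxxxxxxxxxx
```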
#### Model Hashing

Models are hashed during installation, providing a stable identifier for models across all platforms. Hashing is a one-time operation.
@@ -1364,7 +1364,6 @@ the in-memory loaded model:

|----------------|-----------------|------------------|
| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
| `model` | AnyModel | The instantiated model (details below) |
| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |

### get_model_by_key(key, [submodel]) -> LoadedModel
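A hedged sketch of how these attributes fit together — `loaded_model` stands in for any `LoadedModel` returned by APIs such as `get_model_by_key()`, and the call site is illustrative rather than a verbatim excerpt:

```python
# Illustrative only: a LoadedModel as described by the table above.
loaded_model = ...  # e.g. the return value of get_model_by_key()

config = loaded_model.config  # AnyModelConfig: base type, format, etc.

with loaded_model.locker:
    # Inside the context, the locker has moved the model into VRAM.
    output = loaded_model.model(some_input)
# On exit, the locker is free to return the model to the RAM cache.
```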
@@ -1,12 +1,10 @@

# Dev Environment

To make changes to Invoke's backend, frontend, or documentation, you'll need to set up a dev environment.
To make changes to Invoke's backend, frontend or documentation, you'll need to set up a dev environment.

If you just want to use Invoke, you should use the [installer][installer link].
If you only want to make changes to the docs site, you can skip the frontend dev environment setup as described in the below guide.

!!! info "Why do I need the frontend toolchain?"

    The repo doesn't contain a build of the frontend. You'll be responsible for rebuilding it every time you pull in new changes, or run it in dev mode (which incurs a substantial performance penalty).

If you just want to use Invoke, you should use the [launcher][launcher link].

!!! warning

@@ -17,84 +15,66 @@ If you just want to use Invoke, you should use the [installer][installer link].

## Setup

1. Run through the [requirements][requirements link].

2. [Fork and clone][forking link] the [InvokeAI repo][repo link].

3. Create a directory for user data (images, models, db, etc). This is typically at `~/invokeai`, but if you already have a non-dev install, you may want to create a separate directory for the dev install.
4. Create a python virtual environment inside the directory you just created:

4. Follow the [manual install][manual install link] guide, with some modifications to the install command:

    - Use `.` instead of `invokeai` to install from the current directory.

    - Add `-e` after the `install` operation to make this an [editable install][editable install link]. That means your changes to the python code will be reflected when you restart the Invoke server.

    - When installing the `invokeai` package, add the `dev`, `test` and `docs` package options to the package specifier. You may or may not need the `xformers` option - follow the manual install guide to figure that out. So, your package specifier will be either `".[dev,test,docs]"` or `".[dev,test,docs,xformers]"`. Note the quotes!

    With the modifications made, the install command should look something like this:

    ```sh
    python3 -m venv .venv --prompt InvokeAI-Dev
    uv pip install -e ".[dev,test,docs,xformers]" --python 3.11 --python-preference only-managed --index=https://download.pytorch.org/whl/cu124 --reinstall
    ```

5. Activate the venv (you'll need to do this every time you want to run the app):
5. At this point, you should have Invoke installed, a venv set up and activated, and the server running. But you will see a warning in the terminal that no UI was found. If you go to the URL for the server, you won't get a UI.

    This is because the UI build is not distributed with the source code. You need to build it manually. End the running server instance.

    If you only want to edit the docs, you can stop here and skip to the **Documentation** section below.

6. Install the frontend dev toolchain:

    - [`nodejs`](https://nodejs.org/) (v20+)
    - [`pnpm`](https://pnpm.io/8.x/installation) (must be v8 - not v9!)

7. Do a production build of the frontend:

    ```sh
    source .venv/bin/activate
    ```

6. Install the repo as an [editable install][editable install link]:

    ```sh
    pip install -e ".[dev,test,xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
    ```

    Refer to the [manual installation][manual install link] instructions for determining the correct install options. `xformers` is optional, but `dev` and `test` are not.

7. Install the frontend dev toolchain:

    - [`nodejs`](https://nodejs.org/) (recommend v20 LTS)
    - [`pnpm`](https://pnpm.io/8.x/installation) (must be v8 - not v9!)

8. Do a production build of the frontend:

    ```sh
    cd PATH_TO_INVOKEAI_REPO/invokeai/frontend/web
    cd <PATH_TO_INVOKEAI_REPO>/invokeai/frontend/web
    pnpm i
    pnpm build
    ```

9. Start the application:

    ```sh
    cd PATH_TO_INVOKEAI_REPO
    python scripts/invokeai-web.py
    ```

10. Access the UI at `localhost:9090`.
8. Restart the server and navigate to the URL. You should get a UI. After making changes to the python code, restart the server to see those changes.

## Updating the UI

You'll need to run `pnpm build` every time you pull in new changes. Another option is to skip the build and instead run the app in dev mode:
You'll need to run `pnpm build` every time you pull in new changes.

Another option is to skip the build and instead run the UI in dev mode:

```sh
pnpm dev
```

This starts a dev server at `localhost:5173`, which you will use instead of `localhost:9090`.
This starts a vite dev server for the UI at `127.0.0.1:5173`, which you will use instead of `127.0.0.1:9090`.

The dev mode is substantially slower than the production build but may be more convenient if you just need to test things out.
The dev mode is substantially slower than the production build but may be more convenient if you just need to test things out. It will hot-reload the UI as you make changes to the frontend code. Sometimes the hot-reload doesn't work, and you need to manually refresh the browser tab.

## Documentation

The documentation is built with `mkdocs`. To preview it locally, you need an additional set of packages installed.
The documentation is built with `mkdocs`. It provides a hot-reload dev server for the docs. Start it with `mkdocs serve`.

```sh
# after activating the venv
pip install -e ".[docs]"
```

Then, you can start a live docs dev server, which will auto-refresh when you edit the docs:

```sh
mkdocs serve
```

On macOS and Linux, there is a `make` target for this:

```sh
make docs
```

[installer link]: ../installation/installer.md
[launcher link]: ../installation/quick_start.md
[forking link]: https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/fork-a-repo
[requirements link]: ../installation/requirements.md
[repo link]: https://github.com/invoke-ai/InvokeAI
@@ -50,11 +50,9 @@ title: Invoke

## Installation

The [installer script](installation/installer.md) is the easiest way to install and update the application.
The [Invoke Launcher](installation/quick_start.md) is the easiest way to install, update and run Invoke on Windows, macOS and Linux.

You can also install Invoke as a python package [via PyPI](installation/manual.md) or [docker](installation/docker.md).

See the [installation section](./installation/index.md) for more information.
You can also install Invoke as a [python package](installation/manual.md) or with [docker](installation/docker.md).

## Help
@@ -4,7 +4,7 @@ title: Docker

!!! warning "macOS users"

    Docker cannot access the GPU on macOS, so your generation speeds will be slow. Use the [installer](./installer.md) instead.
    Docker cannot access the GPU on macOS, so your generation speeds will be slow. Use the [launcher](./quick_start.md) instead.

!!! tip "Linux and Windows Users"
@@ -1,36 +0,0 @@
|
||||
# Installation and Updating Overview
|
||||
|
||||
Before installing, review the [installation requirements](./requirements.md) to ensure your system is set up properly.
|
||||
|
||||
See the [FAQ](../faq.md) for frequently-encountered installation issues.
|
||||
|
||||
If you need more help, join our [discord](https://discord.gg/ZmtBAhwWhy) or [create a GitHub issue](https://github.com/invoke-ai/InvokeAI/issues).
|
||||
|
||||
## Automated Installer & Updates
|
||||
|
||||
✅ The automated [installer](./installer.md) is the best way to install Invoke.
|
||||
|
||||
⬆️ The same installer is also the best way to update Invoke - simply rerun it for the same folder you installed to.
|
||||
|
||||
The installation process simply manages installation for the core libraries & application dependencies that run Invoke.
|
||||
|
||||
Models, images, or other assets in the Invoke root folder won't be affected by the installation process.
|
||||
|
||||
## Manual Install
|
||||
|
||||
If you are familiar with python and want more control over the packages that are installed, you can [install Invoke manually via PyPI](./manual.md).
|
||||
|
||||
Updates are managed by reinstalling the latest version through PyPi.
|
||||
|
||||
## Developer Install
|
||||
|
||||
If you want to contribute to InvokeAI, you'll need to set up a [dev environment](../contributing/dev-environment.md).
|
||||
|
||||
## Docker
|
||||
|
||||
Invoke publishes docker images. See the [docker installation guide](./docker.md) for details.
|
||||
|
||||
## Other Installation Guides
|
||||
|
||||
- [PyPatchMatch](./patchmatch.md)
|
||||
- [Installing Models](./models.md)
|
||||
@@ -1,4 +1,10 @@

# Automatic Install & Updates
# Legacy Scripts

!!! warning "Legacy Scripts"

    We recommend using the Invoke Launcher to install and update Invoke. It's a desktop application for Windows, macOS and Linux. It takes care of a lot of the nitty-gritty details for you.

    Follow the [quick start guide](./quick_start.md) to get started.

!!! tip "Use the installer to update"
@@ -4,11 +4,11 @@

**Python experience is mandatory.**

If you want to use Invoke locally, you should probably use the [installer](./installer.md).
If you want to use Invoke locally, you should probably use the [launcher](./quick_start.md).

If you want to contribute to Invoke, instead follow the [dev environment](../contributing/dev-environment.md) guide.
If you want to contribute to Invoke or run the app on the latest dev branch, instead follow the [dev environment](../contributing/dev-environment.md) guide.

InvokeAI is distributed as a python package on PyPI, installable with `pip`. There are a few things that are handled by the installer and launcher that you'll need to manage manually, described in this guide.
InvokeAI is distributed as a python package on PyPI, installable with `pip`. There are a few things that are handled by the launcher that you'll need to manage manually, described in this guide.

## Requirements

@@ -16,43 +16,39 @@ Before you start, go through the [installation requirements](./requirements.md).

## Walkthrough

1. Create a directory to contain your InvokeAI library, configuration files, and models. This is known as the "runtime" or "root" directory, and typically lives in your home directory under the name `invokeai`.
We'll use [`uv`](https://github.com/astral-sh/uv) to install python and create a virtual environment, then install the `invokeai` package. `uv` is a modern, very fast alternative to `pip`.

The following commands vary depending on the version of Invoke being installed and the system onto which it is being installed.

1. Install `uv` as described in its [docs](https://docs.astral.sh/uv/getting-started/installation/#standalone-installer). We suggest using the standalone installer method.

    Run `uv --version` to confirm that `uv` is installed and working. After installation, you may need to restart your terminal to get access to `uv`.

2. Create a directory for your installation, typically in your home directory (e.g. `~/invokeai` or `$Home/invokeai`):

    === "Linux/macOS"

        ```bash
        mkdir ~/invokeai
        cd ~/invokeai
        ```

    === "Windows (PowerShell)"

        ```bash
        mkdir $Home/invokeai
        ```

1. Enter the root directory and create a virtual Python environment within it named `.venv`.

    !!! warning "Virtual Environment Location"

        While you may create the virtual environment anywhere in the file system, we recommend that you create it within the root directory as shown here. This allows the application to automatically detect its data directories.

        If you choose a different location for the venv, then you _must_ set the `INVOKEAI_ROOT` environment variable or specify the root directory using the `--root` CLI arg.

    === "Linux/macOS"

        ```bash
        cd ~/invokeai
        python3 -m venv .venv --prompt InvokeAI
        ```

    === "Windows (PowerShell)"

        ```bash
        cd $Home/invokeai
        python3 -m venv .venv --prompt InvokeAI
        ```

1. Activate the new environment:
3. Create a virtual environment in that directory:

    ```sh
    uv venv --relocatable --prompt invoke --python 3.11 --python-preference only-managed .venv
    ```

    This command creates a portable virtual environment at `.venv` complete with a portable python 3.11. It doesn't matter if your system has no python installed, or has a different version - `uv` will handle everything.

4. Activate the virtual environment:

    === "Linux/macOS"

@@ -60,41 +56,48 @@ Before you start, go through the [installation requirements](./requirements.md).

        ```bash
        source .venv/bin/activate
        ```

    === "Windows"
    === "Windows (PowerShell)"

        ```ps
        .venv\Scripts\activate
        ```

    !!! info "Permissions Error (Windows)"
5. Choose a version to install. Review the [GitHub releases page](https://github.com/invoke-ai/InvokeAI/releases).

        If you get a permissions error at this point, run this command and try again.
6. Determine the package specifier to use when installing. This is a performance optimization.

        `Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser`
    - If you have an Nvidia 20xx series GPU or older, use `invokeai[xformers]`.
    - If you have an Nvidia 30xx series GPU or newer, or do not have an Nvidia GPU, use `invokeai`.

    The command-line prompt should change to show `(InvokeAI)`, indicating the venv is active.
7. Determine the `PyPI` index URL to use for installation, if any. This is necessary to get the right version of torch installed.

1. Make sure that pip is installed in your virtual environment and up to date:
    === "Invoke v5 or later"

        ```bash
        python3 -m pip install --upgrade pip
        ```

        - If you are on Windows with an Nvidia GPU, use `https://download.pytorch.org/whl/cu124`.
        - If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
        - If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm62`.
        - **In all other cases, do not use an index.**

    === "Invoke v4"

        - If you are on Windows with an Nvidia GPU, use `https://download.pytorch.org/whl/cu124`.
        - If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
        - If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm52`.
        - **In all other cases, do not use an index.**

8. Install the `invokeai` package. Substitute the package specifier and version.

    ```sh
    uv pip install <PACKAGE_SPECIFIER>=<VERSION> --python 3.11 --python-preference only-managed --force-reinstall
    ```

1. Install the InvokeAI Package. The base command is `pip install InvokeAI --use-pep517`, but you may need to change this depending on your system and the desired features.
    If you determined you needed to use a `PyPI` index URL in the previous step, you'll need to add `--index=<INDEX_URL>` like this:

    - You may need to provide an [extra index URL](https://pip.pypa.io/en/stable/cli/pip_install/#cmdoption-extra-index-url). Select your platform configuration using [this tool on the PyTorch website](https://pytorch.org/get-started/locally/). Copy the `--extra-index-url` string from this and append it to your install command.

    ```sh
    uv pip install <PACKAGE_SPECIFIER>=<VERSION> --python 3.11 --python-preference only-managed --index=<INDEX_URL> --force-reinstall
    ```

    ```bash
    pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
    ```

    - If you have a CUDA GPU and want to install with `xformers`, you need to add an option to the package name. Note that `xformers` is not strictly necessary. PyTorch includes an implementation of the SDP attention algorithm with similar performance for most GPUs.

    ```bash
    pip install "InvokeAI[xformers]" --use-pep517
    ```

1. Deactivate and reactivate your venv so that the invokeai-specific commands become available in the environment:
9. Deactivate and reactivate your venv so that the invokeai-specific commands become available in the environment:

    === "Linux/macOS"

@@ -102,17 +105,31 @@ Before you start, go through the [installation requirements](./requirements.md).

        ```bash
        deactivate && source .venv/bin/activate
        ```

    === "Windows"
    === "Windows (PowerShell)"

        ```ps
        deactivate
        .venv\Scripts\activate
        ```

1. Run the application:
10. Run the application, specifying the directory you created earlier as the root directory:

    Run `invokeai-web` to start the UI. You must activate the virtual environment before running the app.

    !!! warning

        If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable.

    === "Linux/macOS"

        ```bash
        invokeai-web --root ~/invokeai
        ```

    === "Windows (PowerShell)"

        ```bash
        invokeai-web --root $Home/invokeai
        ```

## Headless Install and Launch Scripts

If you run Invoke on a headless server, you might want to install and run Invoke on the command line.

We do not plan to maintain scripts to do this moving forward, instead focusing our dev resources on the GUI [launcher](../installation/quick_start.md).

You can create your own scripts for this by copying the handful of commands in this guide. `uv`'s [`pip` interface docs](https://docs.astral.sh/uv/reference/cli/#uv-pip-install) may be useful. A minimal sketch follows.
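The sketch below simply assembles the commands from this guide into one script; the version placeholder stays as-is, and any `--index=<INDEX_URL>` is yours to add per steps 5-7:

```bash
#!/usr/bin/env bash
set -euo pipefail

# Create the root directory and a managed-python venv inside it.
mkdir -p ~/invokeai && cd ~/invokeai
uv venv --relocatable --prompt invoke --python 3.11 --python-preference only-managed .venv
source .venv/bin/activate

# Install invokeai; substitute the package specifier and version (and index URL, if needed).
uv pip install "invokeai==<VERSION>" --python 3.11 --python-preference only-managed --force-reinstall

# Launch the UI, pointing at the root directory.
invokeai-web --root ~/invokeai
```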
114  docs/installation/quick_start.md (new file)
@@ -0,0 +1,114 @@
# Invoke Community Edition Quick Start

Welcome to Invoke! Follow these steps to install, update, and get started creating.

## Step 1: System Requirements

Invoke runs on Windows 10+, macOS 14+ and Linux (Ubuntu 20.04+ is well-tested).

Hardware requirements vary significantly depending on model and image output size. The requirements below are rough guidelines.

- All Apple Silicon (M1, M2, etc) Macs work, but 16GB+ memory is recommended.
- AMD GPUs are supported on Linux only. The VRAM requirements are the same as Nvidia GPUs.

!!! info "Hardware Requirements (Windows/Linux)"

    === "SD1.5 - 512×512"

        - GPU: Nvidia 10xx series or later, 4GB+ VRAM.
        - Memory: At least 8GB RAM.
        - Disk: 10GB for base installation plus 30GB for models.

    === "SDXL - 1024×1024"

        - GPU: Nvidia 20xx series or later, 8GB+ VRAM.
        - Memory: At least 16GB RAM.
        - Disk: 10GB for base installation plus 100GB for models.

    === "FLUX - 1024×1024"

        - GPU: Nvidia 20xx series or later, 10GB+ VRAM.
        - Memory: At least 32GB RAM.
        - Disk: 10GB for base installation plus 200GB for models.

More detail on system requirements can be found [here](./requirements.md).

## Step 2: Download

Download the latest launcher for your operating system:

- [Download for Windows](https://download.invoke.ai/Invoke%20Community%20Edition.exe)
- [Download for macOS](https://download.invoke.ai/Invoke%20Community%20Edition.dmg)
- [Download for Linux](https://download.invoke.ai/Invoke%20Community%20Edition.AppImage)

## Step 3: Install or Update

Run the launcher you just downloaded, click **Install** and follow the instructions to get set up.

If you have an existing Invoke installation, you can select it and let the launcher manage the install. You'll be able to update or launch the installation.

!!! warning "Problem running the launcher on macOS"

    macOS may not allow you to run the launcher. We are working to resolve this by signing the launcher executable. Until that is done, you can either use the [legacy scripts](./legacy_scripts.md) to install, or manually flag the launcher as safe:

    - Open the **Invoke-Installer-mac-arm64.dmg** file.
    - Drag the launcher to **Applications**.
    - Open a terminal.
    - Run `xattr -cr /Applications/Invoke-Installer.app`.

    You should now be able to run the launcher.

## Step 4: Launch

Once installed, click **Finish**, then **Launch** to start Invoke.

The very first run after an installation or update will take a few extra moments to get ready.

!!! tip "Server Mode"

    The launcher runs Invoke as a desktop application. You can enable **Server Mode** in the launcher's settings to disable this and instead access the UI through your web browser.

## Step 5: Install Models

With Invoke started up, you'll need to install some models.

The quickest way to get started is to install a **Starter Model** bundle. If you already have a model collection, Invoke can use it.

!!! info "Install Models"

    === "Install a Starter Model bundle"

        1. Go to the **Models** tab.
        2. Click **Starter Models** on the right.
        3. Click one of the bundles to install its models. Refer to the [system requirements](#step-1-system-requirements) if you're unsure which model architecture will work for your system.

    === "Use my model collection"

        1. Go to the **Models** tab.
        2. Click **Scan Folder** on the right.
        3. Paste the path to your models collection and click **Scan Folder**.
        4. With **In-place install** enabled, Invoke will leave the model files where they are. If you disable this, **Invoke will move the models into its own folders**.

You're now ready to start creating!

## Step 6: Learn the Basics

We recommend watching our [Getting Started Playlist](https://www.youtube.com/playlist?list=PLvWK1Kc8iXGrQy8r9TYg6QdUuJ5MMx-ZO). It covers essential features and workflows, including:

- Generating your first image.
- Using control layers and reference guides.
- Refining images with advanced workflows.

## Other Installation Methods

- You can install the Invoke application as a python package. See our [manual install](./manual.md) docs.
- You can run Invoke with docker. See our [docker install](./docker.md) docs.
- You can still use our legacy scripts to install and run Invoke. See the [legacy scripts](./legacy_scripts.md) docs.

## Need Help?

- Visit our [Support Portal](https://support.invoke.ai).
- Watch the [Getting Started Playlist](https://www.youtube.com/playlist?list=PLvWK1Kc8iXGrQy8r9TYg6QdUuJ5MMx-ZO).
- Join the conversation on [Discord][discord link].

[discord link]: https://discord.gg/ZmtBAhwWhy
@@ -1,90 +1,33 @@

# Requirements

## GPU
Invoke runs on Windows 10+, macOS 14+ and Linux (Ubuntu 20.04+ is well-tested).

!!! warning "Problematic Nvidia GPUs"
## Hardware

    We do not recommend these GPUs. They cannot operate with half precision, but have insufficient VRAM to generate 512x512 images at full precision.
Hardware requirements vary significantly depending on model and image output size. The requirements below are rough guidelines.

    - NVIDIA 10xx series cards such as the 1080 TI
    - GTX 1650 series cards
    - GTX 1660 series cards
- All Apple Silicon (M1, M2, etc) Macs work, but 16GB+ memory is recommended.
- AMD GPUs are supported on Linux only. The VRAM requirements are the same as Nvidia GPUs.

Invoke runs best with a dedicated GPU, but will fall back to running on CPU, albeit much slower. You'll need a beefier GPU for SDXL.
!!! info "Hardware Requirements (Windows/Linux)"

!!! example "Stable Diffusion 1.5"
    === "SD1.5 - 512×512"

    === "Nvidia"
        - GPU: Nvidia 10xx series or later, 4GB+ VRAM.
        - Memory: At least 8GB RAM.
        - Disk: 10GB for base installation plus 30GB for models.

        ```
        Any GPU with at least 4GB VRAM.
        ```
    === "SDXL - 1024×1024"

    === "AMD"
        - GPU: Nvidia 20xx series or later, 8GB+ VRAM.
        - Memory: At least 16GB RAM.
        - Disk: 10GB for base installation plus 100GB for models.

        ```
        Any GPU with at least 4GB VRAM. Linux only.
        ```
    === "FLUX - 1024×1024"

    === "Mac"

        ```
        Any Apple Silicon Mac with at least 8GB memory.
        ```

!!! example "Stable Diffusion XL"

    === "Nvidia"

        ```
        Any GPU with at least 8GB VRAM.
        ```

    === "AMD"

        ```
        Any GPU with at least 16GB VRAM. Linux only.
        ```

    === "Mac"

        ```
        Any Apple Silicon Mac with at least 16GB memory.
        ```

## RAM

At least 12GB of RAM.

## Disk

SSDs will, of course, offer the best performance.

The base application disk usage depends on the torch backend.

!!! example "Disk"

    === "Nvidia (CUDA)"

        ```
        ~6.5GB
        ```

    === "AMD (ROCm)"

        ```
        ~12GB
        ```

    === "Mac (MPS)"

        ```
        ~3.5GB
        ```

You'll need to set aside some space for images, depending on how much you generate. A couple GB is enough to get started.

You'll need a good chunk of space for models. Even if you only install the most popular models and the usual support models (ControlNet, IP Adapter, etc.), you will quickly hit 50GB of models.
        - GPU: Nvidia 20xx series or later, 10GB+ VRAM.
        - Memory: At least 32GB RAM.
        - Disk: 10GB for base installation plus 200GB for models.

!!! info "`tmpfs` on Linux"

@@ -92,26 +35,32 @@ You'll need a good chunk of space for models. Even if you only install the most

## Python

!!! tip "The launcher installs python for you"

    You don't need to do this if you are installing with the [Invoke Launcher](./quick_start.md).

Invoke requires python 3.10 or 3.11. If you don't already have one of these versions installed, we suggest installing 3.11, as it will be supported for longer.

Check that your system has an up-to-date Python installed by running `python --version` in the terminal (Linux, macOS) or cmd/powershell (Windows).
Check that your system has an up-to-date Python installed by running `python3 --version` in the terminal (Linux, macOS) or cmd/powershell (Windows).

<h3>Installing Python (Windows)</h3>
!!! info "Installing Python"

- Install python 3.11 with [an official installer].
- The installer includes an option to add python to your PATH. Be sure to enable this. If you missed it, re-run the installer, choose to modify an existing installation, and tick that checkbox.
- You may need to install [Microsoft Visual C++ Redistributable].
    === "Windows"

<h3>Installing Python (macOS)</h3>
        - Install python 3.11 with [an official installer].
        - The installer includes an option to add python to your PATH. Be sure to enable this. If you missed it, re-run the installer, choose to modify an existing installation, and tick that checkbox.
        - You may need to install [Microsoft Visual C++ Redistributable].

- Install python 3.11 with [an official installer].
- If model installs fail with a certificate error, you may need to run this command (changing the python version to match what you have installed): `/Applications/Python\ 3.10/Install\ Certificates.command`
- If you haven't already, you will need to install the XCode CLI Tools by running `xcode-select --install` in a terminal.
    === "macOS"

<h3>Installing Python (Linux)</h3>
        - Install python 3.11 with [an official installer].
        - If model installs fail with a certificate error, you may need to run this command (changing the python version to match what you have installed): `/Applications/Python\ 3.10/Install\ Certificates.command`
        - If you haven't already, you will need to install the XCode CLI Tools by running `xcode-select --install` in a terminal.

- Follow the [linux install instructions], being sure to install python 3.11.
- You'll need to install `libglib2.0-0` and `libgl1-mesa-glx` for OpenCV to work. For example, on a Debian system: `sudo apt update && sudo apt install -y libglib2.0-0 libgl1-mesa-glx`
    === "Linux"

        - Installing python varies depending on your system. On Ubuntu, you can use the [deadsnakes PPA](https://launchpad.net/~deadsnakes/+archive/ubuntu/ppa).
        - You'll need to install `libglib2.0-0` and `libgl1-mesa-glx` for OpenCV to work. For example, on a Debian system: `sudo apt update && sudo apt install -y libglib2.0-0 libgl1-mesa-glx`

## Drivers

@@ -175,7 +124,4 @@ An alternative to installing ROCm locally is to use a [ROCm docker container] to

[ROCm Documentation]: https://rocm.docs.amd.com/projects/install-on-linux/en/latest/tutorial/quick-start.html
[cuDNN support matrix]: https://docs.nvidia.com/deeplearning/cudnn/support-matrix/index.html
[Nvidia Container Runtime]: https://developer.nvidia.com/container-runtime
[linux install instructions]: https://docs.python-guide.org/starting/install3/linux/
[Microsoft Visual C++ Redistributable]: https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170
[an official installer]: https://www.python.org/downloads/
[CUDA Toolkit Downloads]: https://developer.nvidia.com/cuda-downloads
@@ -49,6 +49,7 @@ To use a community workflow, download the `.json` node graph file and load it in

+ [BriaAI Background Remove](#briaai-remove-background)
+ [Remove Background](#remove-background)
+ [Retroize](#retroize)
+ [Stereogram](#stereogram-nodes)
+ [Size Stepper Nodes](#size-stepper-nodes)
+ [Simple Skin Detection](#simple-skin-detection)
+ [Text font to Image](#text-font-to-image)

@@ -526,6 +527,16 @@ View:

<img src="https://github.com/Ar7ific1al/InvokeAI_nodes_retroize/assets/2306586/de8b4fa6-324c-4c2d-b36c-297600c73974" width="500" />

--------------------------------
### Stereogram Nodes

**Description:** A set of custom nodes for InvokeAI to create cross-view or parallel-view stereograms. Stereograms are 2D images that, when viewed properly, reveal a 3D scene. Check out [r/crossview](https://www.reddit.com/r/CrossView/) for tutorials.

**Node Link:** https://github.com/simonfuhrmann/invokeai-stereo

**Example Workflow and Output**

<br><img src="https://github.com/simonfuhrmann/invokeai-stereo/blob/main/docs/example_promo_03.jpg" width="500" />

--------------------------------
### Simple Skin Detection
```diff
@@ -37,7 +37,7 @@ from invokeai.backend.model_manager.config import (
     ModelFormat,
     ModelType,
 )
-from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats
+from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
 from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
 from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
 from invokeai.backend.model_manager.search import ModelSearch
```
```diff
@@ -59,11 +59,32 @@ logger.info(f"Using torch device: {torch_device_name}")

 loop = asyncio.new_event_loop()

+# We may change the port if the default is in use, this global variable is used to store the port so that we can log
+# the correct port when the server starts in the lifespan handler.
+port = app_config.port
+

 @asynccontextmanager
 async def lifespan(app: FastAPI):
     # Add startup event to load dependencies
     ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, loop=loop, logger=logger)
+
+    # Log the server address when it starts - in case the network log level is not high enough to see the startup log
+    proto = "https" if app_config.ssl_certfile else "http"
+    msg = f"Invoke running on {proto}://{app_config.host}:{port} (Press CTRL+C to quit)"
+
+    # Logging this way ignores the logger's log level and _always_ logs the message
+    record = logger.makeRecord(
+        name=logger.name,
+        level=logging.INFO,
+        fn="",
+        lno=0,
+        msg=msg,
+        args=(),
+        exc_info=None,
+    )
+    logger.handle(record)
+
     yield
     # Shut down threads
     ApiDependencies.shutdown()
@@ -206,6 +227,7 @@ def invoke_api() -> None:
     else:
         jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)

+    global port
     port = find_port(app_config.port)
     if port != app_config.port:
         logger.warn(f"Port {app_config.port} in use, using port {port}")
@@ -217,18 +239,17 @@ def invoke_api() -> None:
         host=app_config.host,
         port=port,
         loop="asyncio",
-        log_level=app_config.log_level,
+        log_level=app_config.log_level_network,
         ssl_certfile=app_config.ssl_certfile,
         ssl_keyfile=app_config.ssl_keyfile,
     )
     server = uvicorn.Server(config)

     # replace uvicorn's loggers with InvokeAI's for consistent appearance
-    for logname in ["uvicorn.access", "uvicorn"]:
-        log = InvokeAILogger.get_logger(logname)
-        log.handlers.clear()
-        for ch in logger.handlers:
-            log.addHandler(ch)
+    uvicorn_logger = InvokeAILogger.get_logger("uvicorn")
+    uvicorn_logger.handlers.clear()
+    for hdlr in logger.handlers:
+        uvicorn_logger.addHandler(hdlr)

     loop.run_until_complete(server.serve())
```
```diff
@@ -19,9 +19,9 @@ from invokeai.app.invocations.model import CLIPField
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
-from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
-from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_patcher import ModelPatcher
+from invokeai.backend.patches.layer_patcher import LayerPatcher
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
     ConditioningFieldData,
@@ -66,10 +66,10 @@ class CompelInvocation(BaseInvocation):
         tokenizer_info = context.models.load(self.clip.tokenizer)
         text_encoder_info = context.models.load(self.clip.text_encoder)

-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+        def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
             for lora in self.clip.loras:
                 lora_info = context.models.load(lora.lora)
-                assert isinstance(lora_info.model, LoRAModelRaw)
+                assert isinstance(lora_info.model, ModelPatchRaw)
                 yield (lora_info.model, lora.weight)
                 del lora_info
             return
@@ -82,10 +82,11 @@ class CompelInvocation(BaseInvocation):
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            LoRAPatcher.apply_lora_patches(
+            LayerPatcher.apply_smart_model_patches(
                 model=text_encoder,
                 patches=_lora_loader(),
                 prefix="lora_te_",
+                dtype=text_encoder.dtype,
                 cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -162,11 +163,11 @@ class SDXLPromptInvocationBase:
                 c_pooled = None
             return c, c_pooled

-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+        def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
             for lora in clip_field.loras:
                 lora_info = context.models.load(lora.lora)
                 lora_model = lora_info.model
-                assert isinstance(lora_model, LoRAModelRaw)
+                assert isinstance(lora_model, ModelPatchRaw)
                 yield (lora_model, lora.weight)
                 del lora_info
             return
@@ -179,10 +180,11 @@ class SDXLPromptInvocationBase:
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            LoRAPatcher.apply_lora_patches(
-                text_encoder,
+            LayerPatcher.apply_smart_model_patches(
+                model=text_encoder,
                 patches=_lora_loader(),
                 prefix=lora_prefix,
+                dtype=text_encoder.dtype,
                 cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
```
```diff
@@ -6,7 +6,6 @@ from PIL import Image
 from torchvision.transforms.functional import resize as tv_resize

 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.constants import DEFAULT_PRECISION
 from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField
 from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
 from invokeai.app.invocations.model import VAEField
@@ -29,11 +28,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
     image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
     mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
     tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
-    fp32: bool = InputField(
-        default=DEFAULT_PRECISION == torch.float32,
-        description=FieldDescriptions.fp32,
-        ui_order=4,
-    )
+    fp32: bool = InputField(default=False, description=FieldDescriptions.fp32, ui_order=4)

     def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
         if mask_image.mode != "L":
```
```diff
@@ -7,7 +7,6 @@ from PIL import Image, ImageFilter
 from torchvision.transforms.functional import resize as tv_resize

 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
-from invokeai.app.invocations.constants import DEFAULT_PRECISION
 from invokeai.app.invocations.fields import (
     DenoiseMaskField,
     FieldDescriptions,
@@ -76,11 +75,7 @@ class CreateGradientMaskInvocation(BaseInvocation):
         ui_order=7,
     )
     tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
-    fp32: bool = InputField(
-        default=DEFAULT_PRECISION == torch.float32,
-        description=FieldDescriptions.fp32,
-        ui_order=9,
-    )
+    fp32: bool = InputField(default=False, description=FieldDescriptions.fp32, ui_order=9)

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> GradientMaskOutput:
```
```diff
@@ -37,10 +37,10 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
-from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager import BaseModelType, ModelVariantType
 from invokeai.backend.model_patcher import ModelPatcher
+from invokeai.backend.patches.layer_patcher import LayerPatcher
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 from invokeai.backend.stable_diffusion import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
 from invokeai.backend.stable_diffusion.diffusers_pipeline import (
@@ -987,10 +987,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
         def step_callback(state: PipelineIntermediateState) -> None:
             context.util.sd_step_callback(state, unet_config.base)

-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+        def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
             for lora in self.unet.loras:
                 lora_info = context.models.load(lora.lora)
-                assert isinstance(lora_info.model, LoRAModelRaw)
+                assert isinstance(lora_info.model, ModelPatchRaw)
                 yield (lora_info.model, lora.weight)
                 del lora_info
             return
@@ -1003,10 +1003,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
             SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
-            LoRAPatcher.apply_lora_patches(
+            LayerPatcher.apply_smart_model_patches(
                 model=unet,
                 patches=_lora_loader(),
                 prefix="lora_unet_",
+                dtype=unet.dtype,
                 cached_weights=cached_weights,
             ),
         ):
```
```diff
@@ -56,6 +56,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
     CLIPLEmbedModel = "CLIPLEmbedModelField"
     CLIPGEmbedModel = "CLIPGEmbedModelField"
     SpandrelImageToImageModel = "SpandrelImageToImageModelField"
+    ControlLoRAModel = "ControlLoRAModelField"
     # endregion

     # region Misc Field Types
@@ -143,6 +144,7 @@ class FieldDescriptions:
     controlnet_model = "ControlNet model to load"
     vae_model = "VAE model to load"
     lora_model = "LoRA model to load"
+    control_lora_model = "Control LoRA model to load"
     main_model = "Main model (UNet, VAE, CLIP) to load"
     flux_model = "Flux model (Transformer) to load"
     sd3_model = "SD3 model (MMDiTX) to load"
```
49  invokeai/app/invocations/flux_control_lora_loader.py (new file)
@@ -0,0 +1,49 @@
```python
from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    Classification,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.model import ControlLoRAField, ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext


@invocation_output("flux_control_lora_loader_output")
class FluxControlLoRALoaderOutput(BaseInvocationOutput):
    """Flux Control LoRA Loader Output"""

    control_lora: ControlLoRAField = OutputField(
        title="Flux Control LoRA", description="Control LoRAs to apply on model loading", default=None
    )


@invocation(
    "flux_control_lora_loader",
    title="Flux Control LoRA",
    tags=["lora", "model", "flux"],
    category="model",
    version="1.1.0",
    classification=Classification.Prototype,
)
class FluxControlLoRALoaderInvocation(BaseInvocation):
    """LoRA model and Image to use with FLUX transformer generation."""

    lora: ModelIdentifierField = InputField(
        description=FieldDescriptions.control_lora_model, title="Control LoRA", ui_type=UIType.ControlLoRAModel
    )
    image: ImageField = InputField(description="The image to encode.")
    weight: float = InputField(description="The weight of the LoRA.", default=1.0)

    def invoke(self, context: InvocationContext) -> FluxControlLoRALoaderOutput:
        if not context.models.exists(self.lora.key):
            raise ValueError(f"Unknown lora: {self.lora.key}!")

        return FluxControlLoRALoaderOutput(
            control_lora=ControlLoRAField(
                lora=self.lora,
                img=self.image,
                weight=self.weight,
            )
        )
```
@@ -1,10 +1,12 @@
|
||||
from contextlib import ExitStack
|
||||
from typing import Callable, Iterator, Optional, Tuple
|
||||
from typing import Callable, Iterator, Optional, Tuple, Union
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
import torchvision.transforms as tv_transforms
|
||||
from PIL import Image
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
@@ -21,8 +23,9 @@ from invokeai.app.invocations.fields import (
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
|
||||
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
|
||||
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
||||
from invokeai.app.invocations.model import TransformerField, VAEField
|
||||
from invokeai.app.invocations.model import ControlLoRAField, LoRAField, TransformerField, VAEField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
|
||||
@@ -44,10 +47,10 @@ from invokeai.backend.flux.sampling_utils import (
|
||||
unpack,
|
||||
)
|
||||
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
|
||||
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoRAPatcher
|
||||
from invokeai.backend.model_manager.config import ModelFormat
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
@@ -89,6 +92,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
input=Input.Connection,
|
||||
title="Transformer",
|
||||
)
|
||||
control_lora: Optional[ControlLoRAField] = InputField(
|
||||
description=FieldDescriptions.control_lora_model, input=Input.Connection, title="Control LoRA", default=None
|
||||
)
|
||||
positive_text_conditioning: FluxConditioningField | list[FluxConditioningField] = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
@@ -194,7 +200,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
)
|
||||
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
is_schnell = "schnell" in transformer_info.config.config_path
|
||||
is_schnell = "schnell" in getattr(transformer_info.config, "config_path", "")
|
||||
|
||||
# Calculate the timestep schedule.
|
||||
timesteps = get_schedule(
|
||||
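The switch to `getattr(..., "config_path", "")` makes the schnell check tolerant of config classes that have no `config_path` attribute at all; `getattr` with a default never raises. An illustrative snippet (the config class here is hypothetical):

    class QuantizedConfig:  # hypothetical config type without a config_path
        pass

    cfg = QuantizedConfig()
    is_schnell = "schnell" in getattr(cfg, "config_path", "")  # False, no AttributeError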
@@ -234,6 +240,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
if len(timesteps) <= 1:
|
||||
return x
|
||||
|
||||
if is_schnell and self.control_lora:
|
||||
raise ValueError("Control LoRAs cannot be used with FLUX Schnell")
|
||||
|
||||
# Prepare the extra image conditioning tensor if a FLUX structural control image is provided.
|
||||
img_cond = self._prep_structural_control_img_cond(context)
|
||||
|
||||
inpaint_mask = self._prep_inpaint_mask(context, x)
|
||||
|
||||
img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)
|
||||
@@ -241,6 +253,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# Pack all latent tensors.
|
||||
init_latents = pack(init_latents) if init_latents is not None else None
|
||||
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
|
||||
img_cond = pack(img_cond) if img_cond is not None else None
|
||||
noise = pack(noise)
|
||||
x = pack(x)
|
||||
|
||||
@@ -291,36 +304,33 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
config = transformer_info.config
|
||||
assert config is not None
|
||||
|
||||
# Apply LoRA models to the transformer.
|
||||
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
|
||||
# Determine if the model is quantized.
|
||||
# If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in
|
||||
# slower inference than direct patching, but is agnostic to the quantization format.
|
||||
if config.format in [ModelFormat.Checkpoint]:
|
||||
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
|
||||
exit_stack.enter_context(
|
||||
LoRAPatcher.apply_lora_patches(
|
||||
model=transformer,
|
||||
patches=self._lora_iterator(context),
|
||||
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
|
||||
cached_weights=cached_weights,
|
||||
)
|
||||
)
|
||||
model_is_quantized = False
|
||||
elif config.format in [
|
||||
ModelFormat.BnbQuantizedLlmInt8b,
|
||||
ModelFormat.BnbQuantizednf4b,
|
||||
ModelFormat.GGUFQuantized,
|
||||
]:
|
||||
# The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference,
|
||||
# than directly patching the weights, but is agnostic to the quantization format.
|
||||
exit_stack.enter_context(
|
||||
LoRAPatcher.apply_lora_sidecar_patches(
|
||||
model=transformer,
|
||||
patches=self._lora_iterator(context),
|
||||
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
|
||||
dtype=inference_dtype,
|
||||
)
|
||||
)
|
||||
model_is_quantized = True
|
||||
else:
|
||||
raise ValueError(f"Unsupported model format: {config.format}")
|
||||
|
||||
# Apply LoRA models to the transformer.
|
||||
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
|
||||
exit_stack.enter_context(
|
||||
LayerPatcher.apply_smart_model_patches(
|
||||
model=transformer,
|
||||
patches=self._lora_iterator(context),
|
||||
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
|
||||
dtype=inference_dtype,
|
||||
cached_weights=cached_weights,
|
||||
force_sidecar_patching=model_is_quantized,
|
||||
)
|
||||
)
|
||||
|
||||
# Prepare IP-Adapter extensions.
|
||||
pos_ip_adapter_extensions, neg_ip_adapter_extensions = self._prep_ip_adapter_extensions(
|
||||
pos_image_prompt_clip_embeds=pos_image_prompt_clip_embeds,
|
||||
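The distinction the comments draw is worth spelling out: direct patching adds the LoRA delta into the weight tensors themselves (fast at inference, but only possible when the weights are plain floats), while sidecar patching leaves the base layer untouched and adds the LoRA contribution at forward time — which is why it works over bnb- or GGUF-quantized layers whose weights cannot be meaningfully added to. A minimal sketch of the sidecar idea, mirroring the `LoRASidecarModule` removed later in this diff (not the project's API):

    import torch

    class SidecarLinear(torch.nn.Module):
        """Wraps a base layer and adds a LoRA branch at forward time (sketch)."""

        def __init__(self, base: torch.nn.Module, lora: torch.nn.Module, weight: float = 1.0):
            super().__init__()
            self.base = base      # may be quantized; never modified
            self.lora = lora
            self.weight = weight

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.base(x) + self.weight * self.lora(x)

With `apply_smart_model_patches`, that choice is made per layer inside the patcher, and `force_sidecar_patching=model_is_quantized` forces the slower-but-safe path for quantized models.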
@@ -345,6 +355,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
controlnet_extensions=controlnet_extensions,
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
neg_ip_adapter_extensions=neg_ip_adapter_extensions,
img_cond=img_cond,
)

x = unpack(x.float(), self.height, self.width)

@@ -575,6 +586,29 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):

return controlnet_extensions

def _prep_structural_control_img_cond(self, context: InvocationContext) -> torch.Tensor | None:
if self.control_lora is None:
return None

if not self.controlnet_vae:
raise ValueError("controlnet_vae must be set when using a FLUX Control LoRA.")

# Load the conditioning image and resize it to the target image size.
cond_img = context.images.get_pil(self.control_lora.img.image_name)
cond_img = cond_img.convert("RGB")
cond_img = cond_img.resize((self.width, self.height), Image.Resampling.BICUBIC)
cond_img = np.array(cond_img)

# Normalize the conditioning image to the range [-1, 1].
# This normalization is based on the original implementations here:
# https://github.com/black-forest-labs/flux/blob/805da8571a0b49b6d4043950bd266a65328c243b/src/flux/modules/image_embedders.py#L34
# https://github.com/black-forest-labs/flux/blob/805da8571a0b49b6d4043950bd266a65328c243b/src/flux/modules/image_embedders.py#L60
img_cond = torch.from_numpy(cond_img).float() / 127.5 - 1.0
img_cond = einops.rearrange(img_cond, "h w c -> 1 c h w")

vae_info = context.models.load(self.controlnet_vae.vae)
return FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=img_cond)
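For reference, `x / 127.5 - 1.0` maps uint8 pixel values 0 to -1.0, ~128 to ~0.0, and 255 to 1.0, and the einops pattern turns HWC into NCHW with a singleton batch axis. A tiny worked check:

    import einops
    import numpy as np
    import torch

    pixels = np.array([[[0, 128, 255]]], dtype=np.uint8)   # shape (1, 1, 3), HWC
    t = torch.from_numpy(pixels).float() / 127.5 - 1.0     # values approx [-1.0, 0.004, 1.0]
    t = einops.rearrange(t, "h w c -> 1 c h w")            # shape (1, 3, 1, 1), NCHW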
def _normalize_ip_adapter_fields(self) -> list[IPAdapterField]:
if self.ip_adapter is None:
return []

@@ -681,10 +715,15 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):

return pos_ip_adapter_extensions, neg_ip_adapter_extensions

def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.transformer.loras:
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
loras: list[Union[LoRAField, ControlLoRAField]] = [*self.transformer.loras]
if self.control_lora:
# Note: Since FLUX structural control LoRAs modify the shape of some weights, it is important that they are
# applied last.
loras.append(self.control_lora)
for lora in loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info
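A toy illustration of why the shape-changing patch must come last: once a weight has been widened, deltas computed against the old shape no longer broadcast.

    import torch

    w = torch.zeros(4, 8)                          # a weight an ordinary LoRA knows how to patch
    delta = torch.ones(4, 8)                       # the LoRA's additive delta, shaped (4, 8)
    w = w + delta                                  # fine while the weight is still (4, 8)
    w = torch.cat([w, torch.zeros(4, 8)], dim=1)   # a shape-changing patch widens it to (4, 16)
    # Adding `delta` now would raise a size-mismatch error: (4, 16) vs (4, 8).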
@@ -17,10 +17,10 @@ from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo

@@ -111,10 +111,11 @@ class FluxTextEncoderInvocation(BaseInvocation):
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LoRAPatcher.apply_lora_patches(
LayerPatcher.apply_smart_model_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context),
prefix=FLUX_LORA_CLIP_PREFIX,
dtype=clip_text_encoder.dtype,
cached_weights=cached_weights,
)
)

@@ -130,9 +131,9 @@
assert isinstance(pooled_prompt_embeds, torch.Tensor)
return pooled_prompt_embeds

def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info

@@ -13,7 +13,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,

@@ -49,7 +49,7 @@ class ImageToLatentsInvocation(BaseInvocation):
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
# offer a way to directly set None values.
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)

@staticmethod
def vae_encode(

@@ -12,7 +12,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,

@@ -51,7 +51,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
# offer a way to directly set None values.
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)

@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:

@@ -10,7 +10,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import (

@@ -65,11 +65,6 @@ class CLIPField(BaseModel):
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")


class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")


class T5EncoderField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")

@@ -80,6 +75,15 @@ class VAEField(BaseModel):
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')


class ControlLoRAField(LoRAField):
img: ImageField = Field(description="Image to use in structural conditioning")


class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")


@invocation_output("unet_output")
class UNetOutput(BaseInvocationOutput):
"""Base class for invocations that output a UNet field."""

@@ -16,10 +16,10 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import SD3ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo

# The SD3 T5 Max Sequence Length set based on the default in diffusers.

@@ -150,10 +150,11 @@ class Sd3TextEncoderInvocation(BaseInvocation):
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LoRAPatcher.apply_lora_patches(
LayerPatcher.apply_smart_model_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context, clip_model),
prefix=FLUX_LORA_CLIP_PREFIX,
dtype=clip_text_encoder.dtype,
cached_weights=cached_weights,
)
)

@@ -193,9 +194,9 @@

def _clip_lora_iterator(
self, context: InvocationContext, clip_model: CLIPField
) -> Iterator[Tuple[LoRAModelRaw, float]]:
) -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in clip_model.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info

@@ -22,8 +22,8 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline,

@@ -194,10 +194,10 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
context.util.sd_step_callback(state, unet_config.base)

# Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info

@@ -207,7 +207,9 @@
with (
ExitStack() as exit_stack,
unet_info as unet,
LoRAPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
LayerPatcher.apply_smart_model_patches(
model=unet, patches=_lora_loader(), prefix="lora_unet_", dtype=unet.dtype
),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)

@@ -57,8 +57,10 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:


class BoardChanges(BaseModel, extra="forbid"):
board_name: Optional[str] = Field(default=None, description="The board's new name.")
cover_image_name: Optional[str] = Field(default=None, description="The name of the board's new cover image.")
board_name: Optional[str] = Field(default=None, description="The board's new name.", max_length=255)
cover_image_name: Optional[str] = Field(
default=None, description="The name of the board's new cover image.", max_length=255
)
archived: Optional[bool] = Field(default=None, description="Whether or not the board is archived")
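The added `max_length=255` constraints are enforced by pydantic at validation time, so over-long names are rejected before they reach the database. A quick illustration (assuming pydantic v2, which this codebase uses):

    from typing import Optional
    from pydantic import BaseModel, Field, ValidationError

    class Changes(BaseModel, extra="forbid"):
        board_name: Optional[str] = Field(default=None, max_length=255)

    Changes(board_name="a" * 255)      # ok
    try:
        Changes(board_name="a" * 256)  # rejected at validation time
    except ValidationError as e:
        print(e.errors()[0]["type"])   # 'string_too_long'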
@@ -97,6 +97,7 @@ class InvokeAIAppConfig(BaseSettings):
log_format: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.<br>Valid values: `plain`, `color`, `syslog`, `legacy`
log_level: Emit logging messages at this level or higher.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
log_sql: Log SQL queries. `log_level` must be `debug` for this to do anything. Extremely verbose.
log_level_network: Log level for network-related messages. 'info' and 'debug' are very verbose.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
use_memory_db: Use in-memory database. Useful for development.
dev_reload: Automatically reload when Python sources are changed. Does not reload node definitions.
profile_graphs: Enable graph profiling using `cProfile`.

@@ -163,6 +164,7 @@
log_format: LOG_FORMAT = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.')
log_level: LOG_LEVEL = Field(default="info", description="Emit logging messages at this level or higher.")
log_sql: bool = Field(default=False, description="Log SQL queries. `log_level` must be `debug` for this to do anything. Extremely verbose.")
log_level_network: LOG_LEVEL = Field(default='warning', description="Log level for network-related messages. 'info' and 'debug' are very verbose.")

# Development
use_memory_db: bool = Field(default=False, description="Use in-memory database. Useful for development.")

@@ -8,7 +8,7 @@ import time
import traceback
from pathlib import Path
from queue import Empty, PriorityQueue
from typing import Any, Dict, List, Literal, Optional, Set
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set

import requests
from pydantic.networks import AnyHttpUrl

@@ -28,11 +28,13 @@ from invokeai.app.services.download.download_base import (
ServiceInactiveException,
UnknownJobIDException,
)
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.backend.model_manager.metadata import RemoteModelFile
from invokeai.backend.util.logging import InvokeAILogger

if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase

# Maximum number of bytes to download during each call to requests.iter_content()
DOWNLOAD_CHUNK_SIZE = 100000
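Moving the `EventServiceBase` import under `if TYPE_CHECKING:` keeps it visible to type checkers while skipping the runtime import — the usual way to break an import cycle. The name may then only appear in annotations (as strings, or lazily via `from __future__ import annotations`). A minimal self-contained sketch of the idiom (module path hypothetical):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Imported for type checking only; never executed at runtime.
        from some_package.events import EventServiceBase  # hypothetical module path

    class DownloadQueue:
        def __init__(self, event_bus: EventServiceBase | None = None) -> None:
            self._event_bus = event_bus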
@@ -1 +0,0 @@
from .events_base import EventServiceBase # noqa F401

@@ -4,6 +4,7 @@ from fastapi_events.handlers.local import local_handler
from fastapi_events.registry.payload_schema import registry as payload_schema
from pydantic import BaseModel, ConfigDict, Field

from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,

@@ -18,7 +19,7 @@ from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType

if TYPE_CHECKING:
from invokeai.app.services.download.download_base import DownloadJob
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource


class EventBase(BaseModel):

@@ -422,7 +423,7 @@ class ModelInstallDownloadStartedEvent(ModelEventBase):
__event_name__ = "model_install_download_started"

id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
source: ModelSource = Field(description="Source of the model; local path, repo_id or url")
local_path: str = Field(description="Where model is downloading to")
bytes: int = Field(description="Number of bytes downloaded so far")
total_bytes: int = Field(description="Total size of download, including all files")

@@ -443,7 +444,7 @@ class ModelInstallDownloadStartedEvent(ModelEventBase):
]
return cls(
id=job.id,
source=str(job.source),
source=job.source,
local_path=job.local_path.as_posix(),
parts=parts,
bytes=job.bytes,

@@ -458,7 +459,7 @@ class ModelInstallDownloadProgressEvent(ModelEventBase):
__event_name__ = "model_install_download_progress"

id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
source: ModelSource = Field(description="Source of the model; local path, repo_id or url")
local_path: str = Field(description="Where model is downloading to")
bytes: int = Field(description="Number of bytes downloaded so far")
total_bytes: int = Field(description="Total size of download, including all files")

@@ -479,7 +480,7 @@ class ModelInstallDownloadProgressEvent(ModelEventBase):
]
return cls(
id=job.id,
source=str(job.source),
source=job.source,
local_path=job.local_path.as_posix(),
parts=parts,
bytes=job.bytes,

@@ -494,11 +495,11 @@ class ModelInstallDownloadsCompleteEvent(ModelEventBase):
__event_name__ = "model_install_downloads_complete"

id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
source: ModelSource = Field(description="Source of the model; local path, repo_id or url")

@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadsCompleteEvent":
return cls(id=job.id, source=str(job.source))
return cls(id=job.id, source=job.source)


@payload_schema.register

@@ -508,11 +509,11 @@ class ModelInstallStartedEvent(ModelEventBase):
__event_name__ = "model_install_started"

id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
source: ModelSource = Field(description="Source of the model; local path, repo_id or url")

@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallStartedEvent":
return cls(id=job.id, source=str(job.source))
return cls(id=job.id, source=job.source)


@payload_schema.register

@@ -522,14 +523,14 @@ class ModelInstallCompleteEvent(ModelEventBase):
__event_name__ = "model_install_complete"

id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
source: ModelSource = Field(description="Source of the model; local path, repo_id or url")
key: str = Field(description="Model config record key")
total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)")

@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallCompleteEvent":
assert job.config_out is not None
return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes)
return cls(id=job.id, source=job.source, key=(job.config_out.key), total_bytes=job.total_bytes)


@payload_schema.register

@@ -539,11 +540,11 @@ class ModelInstallCancelledEvent(ModelEventBase):
__event_name__ = "model_install_cancelled"

id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
source: ModelSource = Field(description="Source of the model; local path, repo_id or url")

@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallCancelledEvent":
return cls(id=job.id, source=str(job.source))
return cls(id=job.id, source=job.source)


@payload_schema.register

@@ -553,7 +554,7 @@ class ModelInstallErrorEvent(ModelEventBase):
__event_name__ = "model_install_error"

id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
source: ModelSource = Field(description="Source of the model; local path, repo_id or url")
error_type: str = Field(description="The name of the exception")
error: str = Field(description="A text description of the exception")

@@ -561,7 +562,7 @@
def build(cls, job: "ModelInstallJob") -> "ModelInstallErrorEvent":
assert job.error_type is not None
assert job.error is not None
return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error)
return cls(id=job.id, source=job.source, error_type=job.error_type, error=job.error)


class BulkDownloadEventBase(EventBase):
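The recurring `source: str` to `source: ModelSource` change means these events now carry the structured source object rather than its stringified form, so pydantic serializes its individual fields for consumers. A toy version of the difference (these models are illustrative, not the project's `ModelSource` union):

    from typing import Union
    from pydantic import BaseModel

    class LocalSource(BaseModel):
        path: str

    class UrlSource(BaseModel):
        url: str

    Source = Union[LocalSource, UrlSource]

    class InstallEvent(BaseModel):
        id: int
        source: Source

    evt = InstallEvent(id=1, source=UrlSource(url="https://example.com/model.safetensors"))
    print(evt.model_dump())  # {'id': 1, 'source': {'url': 'https://example.com/model.safetensors'}}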
@@ -20,7 +20,7 @@ from invokeai.app.services.invocation_stats.invocation_stats_common import (
NodeExecutionStatsSummary,
)
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load.model_cache import CacheStats
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats

# Size of 1GB in bytes.
GB = 2**30

@@ -3,18 +3,20 @@

from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union

from pydantic.networks import AnyHttpUrl

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import ModelRecordChanges, ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig

if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase


class ModelInstallServiceBase(ABC):
"""Abstract base class for InvokeAI model installation."""

@@ -9,7 +9,7 @@ from pathlib import Path
from queue import Empty, Queue
from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union

import torch
import yaml

@@ -20,7 +20,6 @@ from requests import Session

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
from invokeai.app.services.model_install.model_install_common import (

@@ -57,6 +56,10 @@ from invokeai.backend.util.catch_sigint import catch_sigint
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.util import slugify

if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase


TMPDIR_PREFIX = "tmpinstall_"

@@ -438,9 +441,10 @@ class ModelInstallService(ModelInstallServiceBase):
variants = "|".join(ModelRepoVariant.__members__.values())
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
source_obj: Optional[StringLikeSource] = None
source_stripped = source.strip('"')

if Path(source).exists(): # A local file or directory
source_obj = LocalModelSource(path=Path(source))
if Path(source_stripped).exists(): # A local file or directory
source_obj = LocalModelSource(path=Path(source_stripped))
elif match := re.match(hf_repoid_re, source):
source_obj = HFModelSource(
repo_id=match.group(1),
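The `strip('"')` guards against local paths pasted with surrounding quotes (as produced, for example, by Windows Explorer's "Copy as path"), which would otherwise fail the `Path(...).exists()` check:

    from pathlib import Path

    source = '"C:\\models\\my-lora.safetensors"'   # pasted with quotes
    Path(source).exists()                          # always False: the quotes are part of the name
    source_stripped = source.strip('"')            # 'C:\\models\\my-lora.safetensors'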
@@ -7,7 +7,7 @@ from typing import Callable, Optional

from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache


class ModelLoadServiceBase(ABC):

@@ -24,7 +24,7 @@ class ModelLoadServiceBase(ABC):

@property
@abstractmethod
def ram_cache(self) -> ModelCacheBase[AnyModel]:
def ram_cache(self) -> ModelCache:
"""Return the RAM cache used by this loader."""

@abstractmethod

@@ -18,7 +18,7 @@ from invokeai.backend.model_manager.load import (
ModelLoaderRegistry,
ModelLoaderRegistryBase,
)
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger

@@ -30,7 +30,7 @@ class ModelLoadService(ModelLoadServiceBase):
def __init__(
self,
app_config: InvokeAIAppConfig,
ram_cache: ModelCacheBase[AnyModel],
ram_cache: ModelCache,
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
):
"""Initialize the model load service."""

@@ -45,7 +45,7 @@ class ModelLoadService(ModelLoadServiceBase):
self._invoker = invoker

@property
def ram_cache(self) -> ModelCacheBase[AnyModel]:
def ram_cache(self) -> ModelCache:
"""Return the RAM cache used by this loader."""
return self._ram_cache

@@ -78,9 +78,8 @@ class ModelLoadService(ModelLoadServiceBase):
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
) -> LoadedModelWithoutConfig:
cache_key = str(model_path)
ram_cache = self.ram_cache
try:
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)
except IndexError:
pass

@@ -109,5 +108,5 @@
)
assert loader is not None
raw_model = loader(model_path)
ram_cache.put(key=cache_key, model=raw_model)
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
self._ram_cache.put(key=cache_key, model=raw_model)
return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)

@@ -16,7 +16,8 @@ from invokeai.app.services.model_load.model_load_base import ModelLoadServiceBas
from invokeai.app.services.model_load.model_load_default import ModelLoadService
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordServiceBase
from invokeai.backend.model_manager.load import ModelCache, ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger

@@ -439,7 +439,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
poll_now_event.wait(self._polling_interval)
continue

self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
self._invoker.services.logger.info(
f"Executing queue item {self._queue_item.item_id}, session {self._queue_item.session_id}"
)
cancel_event.clear()

# Run the graph

@@ -30,6 +30,8 @@ def denoise(
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
pos_ip_adapter_extensions: list[XLabsIPAdapterExtension],
neg_ip_adapter_extensions: list[XLabsIPAdapterExtension],
# extra img tokens
img_cond: torch.Tensor | None,
):
# step 0 is the initial state
total_steps = len(timesteps) - 1

@@ -69,9 +71,9 @@
# controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
# tensors. Calculating the sum materializes each tensor into its own instance.
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)

pred_img = torch.cat((img, img_cond), dim=-1) if img_cond is not None else img
pred = model(
img=img,
img=pred_img,
img_ids=img_ids,
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
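When a structural control image is present, its packed latents are concatenated onto the packed image latents along the channel (last) axis, so the transformer receives twice the per-token channels while the prediction stays at the base width — this is what the new `out_channels` parameter on `Flux` below accommodates. Toy shapes (the 4096/64 figures are illustrative of packed FLUX latents, not load-bearing):

    import torch

    img = torch.randn(1, 4096, 64)        # packed latents: (batch, tokens, channels)
    img_cond = torch.randn(1, 4096, 64)   # packed control-image latents, same shape
    pred_img = torch.cat((img, img_cond), dim=-1)
    print(pred_img.shape)                 # torch.Size([1, 4096, 128])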
@@ -1,6 +1,7 @@
# Initially pulled from https://github.com/black-forest-labs/flux

from dataclasses import dataclass
from typing import Optional

import torch
from torch import Tensor, nn

@@ -35,6 +36,7 @@ class FluxParams:
theta: int
qkv_bias: bool
guidance_embed: bool
out_channels: Optional[int] = None


class Flux(nn.Module):

@@ -47,7 +49,7 @@ class Flux(nn.Module):

self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
self.out_channels = params.out_channels or self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
pe_dim = params.hidden_size // params.num_heads

@@ -1,11 +0,0 @@
from typing import Union

from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.full_layer import FullLayer
from invokeai.backend.lora.layers.ia3_layer import IA3Layer
from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer

AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer]

@@ -1,302 +0,0 @@
from contextlib import contextmanager
from typing import Dict, Iterable, Optional, Tuple

import torch

from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
ConcatenatedLoRALinearSidecarLayer,
)
from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


class LoRAPatcher:
@staticmethod
@torch.no_grad()
@contextmanager
def apply_lora_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
):
"""Apply one or more LoRA patches to a model within a context manager.

Args:
model (torch.nn.Module): The model to patch.
patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and
associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
all at once.
prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
cached_weights (Optional[Dict[str, torch.Tensor]], optional): Read-only copy of the model's state dict in
CPU RAM, for efficient unpatching purposes.
"""
original_weights = OriginalWeightsStorage(cached_weights)
try:
for patch, patch_weight in patches:
LoRAPatcher.apply_lora_patch(
model=model,
prefix=prefix,
patch=patch,
patch_weight=patch_weight,
original_weights=original_weights,
)
del patch

yield
finally:
for param_key, weight in original_weights.get_changed_weights():
model.get_parameter(param_key).copy_(weight)

@staticmethod
@torch.no_grad()
def apply_lora_patch(
model: torch.nn.Module,
prefix: str,
patch: LoRAModelRaw,
patch_weight: float,
original_weights: OriginalWeightsStorage,
):
"""Apply a single LoRA patch to a model.

Args:
model (torch.nn.Module): The model to patch.
prefix (str): A string prefix that precedes keys used in the LoRAs weight layers.
patch (LoRAModelRaw): The LoRA model to patch in.
patch_weight (float): The weight of the LoRA patch.
original_weights (OriginalWeightsStorage): Storage for the original weights of the model, for unpatching.
"""
if patch_weight == 0:
return

# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
# without searching, but some legacy code still uses flattened keys.
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))

prefix_len = len(prefix)

for layer_key, layer in patch.layers.items():
if not layer_key.startswith(prefix):
continue

module_key, module = LoRAPatcher._get_submodule(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)

# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype

layer_scale = layer.scale()

# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)

# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, lora_param_weight in layer.get_parameters(module).items():
param_key = module_key + "." + param_name
module_param = module.get_parameter(param_name)

# Save original weight
original_weights.save(param_key, module_param)

if module_param.shape != lora_param_weight.shape:
lora_param_weight = lora_param_weight.reshape(module_param.shape)

lora_param_weight *= patch_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype)

layer.to(device=TorchDevice.CPU_DEVICE)

@staticmethod
@torch.no_grad()
@contextmanager
def apply_lora_sidecar_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
dtype: torch.dtype,
):
"""Apply one or more LoRA sidecar patches to a model within a context manager. Sidecar patches incur some
overhead compared to normal LoRA patching, but they allow for LoRA layers to applied to base layers in any
quantization format.

Args:
model (torch.nn.Module): The model to patch.
patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and
associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
all at once.
prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
dtype (torch.dtype): The compute dtype of the sidecar layers. This cannot easily be inferred from the model,
since the sidecar layers are typically applied on top of quantized layers whose weight dtype is
different from their compute dtype.
"""
original_modules: dict[str, torch.nn.Module] = {}
try:
for patch, patch_weight in patches:
LoRAPatcher._apply_lora_sidecar_patch(
model=model,
prefix=prefix,
patch=patch,
patch_weight=patch_weight,
original_modules=original_modules,
dtype=dtype,
)
yield
finally:
# Restore original modules.
# Note: This logic assumes no nested modules in original_modules.
for module_key, orig_module in original_modules.items():
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
parent_module = model.get_submodule(module_parent_key)
LoRAPatcher._set_submodule(parent_module, module_name, orig_module)

@staticmethod
def _apply_lora_sidecar_patch(
model: torch.nn.Module,
patch: LoRAModelRaw,
patch_weight: float,
prefix: str,
original_modules: dict[str, torch.nn.Module],
dtype: torch.dtype,
):
"""Apply a single LoRA sidecar patch to a model."""

if patch_weight == 0:
return

# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
# without searching, but some legacy code still uses flattened keys.
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))

prefix_len = len(prefix)

for layer_key, layer in patch.layers.items():
if not layer_key.startswith(prefix):
continue

module_key, module = LoRAPatcher._get_submodule(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)

# Initialize the LoRA sidecar layer.
lora_sidecar_layer = LoRAPatcher._initialize_lora_sidecar_layer(module, layer, patch_weight)

# Replace the original module with a LoRASidecarModule if it has not already been done.
if module_key in original_modules:
# The module has already been patched with a LoRASidecarModule. Append to it.
assert isinstance(module, LoRASidecarModule)
lora_sidecar_module = module
else:
# The module has not yet been patched with a LoRASidecarModule. Create one.
lora_sidecar_module = LoRASidecarModule(module, [])
original_modules[module_key] = module
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
module_parent = model.get_submodule(module_parent_key)
LoRAPatcher._set_submodule(module_parent, module_name, lora_sidecar_module)

# Move the LoRA sidecar layer to the same device/dtype as the orig module.
# TODO(ryand): Experiment with moving to the device first, then casting. This could be faster.
lora_sidecar_layer.to(device=lora_sidecar_module.orig_module.weight.device, dtype=dtype)

# Add the LoRA sidecar layer to the LoRASidecarModule.
lora_sidecar_module.add_lora_layer(lora_sidecar_layer)

@staticmethod
def _split_parent_key(module_key: str) -> tuple[str, str]:
"""Split a module key into its parent key and module name.

Args:
module_key (str): The module key to split.

Returns:
tuple[str, str]: A tuple containing the parent key and module name.
"""
split_key = module_key.rsplit(".", 1)
if len(split_key) == 2:
return tuple(split_key)
elif len(split_key) == 1:
return "", split_key[0]
else:
raise ValueError(f"Invalid module key: {module_key}")

@staticmethod
def _initialize_lora_sidecar_layer(orig_layer: torch.nn.Module, lora_layer: AnyLoRALayer, patch_weight: float):
# TODO(ryand): Add support for more original layer types and LoRA layer types.
if isinstance(orig_layer, torch.nn.Linear) or (
isinstance(orig_layer, LoRASidecarModule) and isinstance(orig_layer.orig_module, torch.nn.Linear)
):
if isinstance(lora_layer, LoRALayer):
return LoRALinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight)
elif isinstance(lora_layer, ConcatenatedLoRALayer):
return ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer=lora_layer, weight=patch_weight)
else:
raise ValueError(f"Unsupported Linear LoRA layer type: {type(lora_layer)}")
else:
raise ValueError(f"Unsupported layer type: {type(orig_layer)}")

@staticmethod
def _set_submodule(parent_module: torch.nn.Module, module_name: str, submodule: torch.nn.Module):
try:
submodule_index = int(module_name)
# If the module name is an integer, then we use the __setitem__ method to set the submodule.
parent_module[submodule_index] = submodule # type: ignore
except ValueError:
# If the module name is not an integer, then we use the setattr method to set the submodule.
setattr(parent_module, module_name, submodule)

@staticmethod
def _get_submodule(
model: torch.nn.Module, layer_key: str, layer_key_is_flattened: bool
) -> tuple[str, torch.nn.Module]:
"""Get the submodule corresponding to the given layer key.

Args:
model (torch.nn.Module): The model to search.
layer_key (str): The layer key to search for.
layer_key_is_flattened (bool): Whether the layer key is flattened. If flattened, then all '.' have been
replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed
directly without searching, but some legacy code still uses flattened keys.

Returns:
tuple[str, torch.nn.Module]: A tuple containing the module key and the submodule.
"""
if not layer_key_is_flattened:
return layer_key, model.get_submodule(layer_key)

# Handle flattened keys.
assert "." not in layer_key

module = model
module_key = ""
key_parts = layer_key.split("_")

submodule_name = key_parts.pop(0)

while len(key_parts) > 0:
try:
module = module.get_submodule(submodule_name)
module_key += "." + submodule_name
submodule_name = key_parts.pop(0)
except Exception:
submodule_name += "_" + key_parts.pop(0)

module = module.get_submodule(submodule_name)
module_key = (module_key + "." + submodule_name).lstrip(".")

return module_key, module
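The removed `_get_submodule` resolved flattened keys greedily: try each underscore-separated fragment as a submodule name, and on failure glue the next fragment back on with an underscore. For an illustrative flattened key `down_blocks_0_attn` against a module tree `down_blocks[0].attn`, the walk goes `down` (fails) → `down_blocks` (ok) → `0` (ok) → `attn` (final), recovering the dotted key `down_blocks.0.attn`:

    import torch

    model = torch.nn.Module()
    model.down_blocks = torch.nn.ModuleList([torch.nn.Module()])
    model.down_blocks[0].attn = torch.nn.Linear(4, 4)

    # The dotted key the greedy walk recovers from "down_blocks_0_attn":
    print(model.get_submodule("down_blocks.0.attn"))  # Linear(in_features=4, out_features=4, ...)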
@@ -1,34 +0,0 @@
-import torch
-
-from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
-
-
-class ConcatenatedLoRALinearSidecarLayer(torch.nn.Module):
-    def __init__(
-        self,
-        concatenated_lora_layer: ConcatenatedLoRALayer,
-        weight: float,
-    ):
-        super().__init__()
-
-        self._concatenated_lora_layer = concatenated_lora_layer
-        self._weight = weight
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        x_chunks: list[torch.Tensor] = []
-        for lora_layer in self._concatenated_lora_layer.lora_layers:
-            x_chunk = torch.nn.functional.linear(input, lora_layer.down)
-            if lora_layer.mid is not None:
-                x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
-            x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
-            x_chunk *= self._weight * lora_layer.scale()
-            x_chunks.append(x_chunk)
-
-        # TODO(ryand): Generalize to support concat_axis != 0.
-        assert self._concatenated_lora_layer.concat_axis == 0
-        x = torch.cat(x_chunks, dim=-1)
-        return x
-
-    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
-        self._concatenated_lora_layer.to(device=device, dtype=dtype)
-        return self
@@ -1,27 +0,0 @@
-import torch
-
-from invokeai.backend.lora.layers.lora_layer import LoRALayer
-
-
-class LoRALinearSidecarLayer(torch.nn.Module):
-    def __init__(
-        self,
-        lora_layer: LoRALayer,
-        weight: float,
-    ):
-        super().__init__()
-
-        self._lora_layer = lora_layer
-        self._weight = weight
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = torch.nn.functional.linear(x, self._lora_layer.down)
-        if self._lora_layer.mid is not None:
-            x = torch.nn.functional.linear(x, self._lora_layer.mid)
-        x = torch.nn.functional.linear(x, self._lora_layer.up, bias=self._lora_layer.bias)
-        x *= self._weight * self._lora_layer.scale()
-        return x
-
-    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
-        self._lora_layer.to(device=device, dtype=dtype)
-        return self
@@ -1,24 +0,0 @@
-import torch
-
-
-class LoRASidecarModule(torch.nn.Module):
-    """A LoRA sidecar module that wraps an original module and adds LoRA layers to it."""
-
-    def __init__(self, orig_module: torch.nn.Module, lora_layers: list[torch.nn.Module]):
-        super().__init__()
-        self.orig_module = orig_module
-        self._lora_layers = lora_layers
-
-    def add_lora_layer(self, lora_layer: torch.nn.Module):
-        self._lora_layers.append(lora_layer)
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        x = self.orig_module(input)
-        for lora_layer in self._lora_layers:
-            x += lora_layer(input)
-        return x
-
-    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
-        self._orig_module.to(device=device, dtype=dtype)
-        for lora_layer in self._lora_layers:
-            lora_layer.to(device=device, dtype=dtype)
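The three files deleted above all implemented the same sidecar contract, output = orig(x) + sum(w_i * lora_i(x)), which the custom modules added later in this diff now provide. A minimal sketch of that contract, with illustrative names:

```python
import torch

orig = torch.nn.Linear(8, 8)
lora_down = torch.nn.Linear(8, 2, bias=False)  # low-rank down-projection
lora_up = torch.nn.Linear(2, 8, bias=False)    # low-rank up-projection
weight = 0.75                                  # patch strength

x = torch.randn(1, 8)
y = orig(x) + weight * lora_up(lora_down(x))   # sidecar-style forward pass
```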
@@ -67,6 +67,7 @@ class ModelType(str, Enum):
     Main = "main"
     VAE = "vae"
     LoRA = "lora"
+    ControlLoRa = "control_lora"
     ControlNet = "controlnet"  # used by model_probe
     TextualInversion = "embedding"
     IPAdapter = "ip_adapter"
@@ -273,6 +274,36 @@ class LoRALyCORISConfig(LoRAConfigBase):
         return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
 
 
+class ControlAdapterConfigBase(BaseModel):
+    default_settings: Optional[ControlAdapterDefaultSettings] = Field(
+        description="Default settings for this model", default=None
+    )
+
+
+class ControlLoRALyCORISConfig(ModelConfigBase, ControlAdapterConfigBase):
+    """Model config for Control LoRA models."""
+
+    type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
+    trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
+    format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.ControlLoRa.value}.{ModelFormat.LyCORIS.value}")
+
+
+class ControlLoRADiffusersConfig(ModelConfigBase, ControlAdapterConfigBase):
+    """Model config for Control LoRA models."""
+
+    type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
+    trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
+    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.ControlLoRa.value}.{ModelFormat.Diffusers.value}")
+
+
 class LoRADiffusersConfig(LoRAConfigBase):
     """Model config for LoRA/Diffusers models."""
 
@@ -304,12 +335,6 @@ class VAEDiffusersConfig(ModelConfigBase):
         return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
 
 
-class ControlAdapterConfigBase(BaseModel):
-    default_settings: Optional[ControlAdapterDefaultSettings] = Field(
-        description="Default settings for this model", default=None
-    )
-
-
 class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase):
     """Model config for ControlNet models (diffusers version)."""
 
@@ -535,6 +560,8 @@ AnyModelConfig = Annotated[
     Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
     Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
     Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
+    Annotated[ControlLoRALyCORISConfig, ControlLoRALyCORISConfig.get_tag()],
+    Annotated[ControlLoRADiffusersConfig, ControlLoRADiffusersConfig.get_tag()],
     Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
     Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
     Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],
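The `Tag` values wired up above drive pydantic's tagged-union dispatch. A self-contained sketch of the mechanism, with simplified stand-in classes rather than InvokeAI's actual configs:

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


class LoRAConfig(BaseModel):
    type: Literal["lora"] = "lora"
    format: Literal["lycoris"] = "lycoris"


class ControlLoRAConfig(BaseModel):
    type: Literal["control_lora"] = "control_lora"
    format: Literal["lycoris"] = "lycoris"


def get_discriminator_value(v) -> str:
    # Mirrors the "type.format" convention used by the get_tag() methods above.
    if isinstance(v, dict):
        return f"{v['type']}.{v['format']}"
    return f"{v.type}.{v.format}"


AnyConfig = Annotated[
    Union[
        Annotated[LoRAConfig, Tag("lora.lycoris")],
        Annotated[ControlLoRAConfig, Tag("control_lora.lycoris")],
    ],
    Discriminator(get_discriminator_value),
]

adapter = TypeAdapter(AnyConfig)
cfg = adapter.validate_python({"type": "control_lora", "format": "lycoris"})
assert isinstance(cfg, ControlLoRAConfig)  # tag resolved to the new config class
```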
@@ -8,7 +8,7 @@ from pathlib import Path
 
 from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
 from invokeai.backend.model_manager.load.load_default import ModelLoader
-from invokeai.backend.model_manager.load.model_cache.model_cache_default import ModelCache
+from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
 
 # This registers the subclasses that implement loaders of specific model types
@@ -5,7 +5,6 @@ Base class for model loading in InvokeAI.
 
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from dataclasses import dataclass
 from logging import Logger
 from pathlib import Path
 from typing import Any, Dict, Generator, Optional, Tuple
@@ -18,19 +17,17 @@ from invokeai.backend.model_manager.config import (
     AnyModelConfig,
     SubModelType,
 )
-from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
+from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
+from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
 
 
-@dataclass
 class LoadedModelWithoutConfig:
-    """
-    Context manager object that mediates transfer from RAM<->VRAM.
+    """Context manager object that mediates transfer from RAM<->VRAM.
 
     This is a context manager object that has two distinct APIs:
 
     1. Older API (deprecated):
-    Use the LoadedModel object directly as a context manager.
-    It will move the model into VRAM (on CUDA devices), and
+    Use the LoadedModel object directly as a context manager. It will move the model into VRAM (on CUDA devices), and
     return the model in a form suitable for passing to torch.
     Example:
     ```
@@ -40,13 +37,9 @@ class LoadedModelWithoutConfig:
     ```
 
     2. Newer API (recommended):
-    Call the LoadedModel's `model_on_device()` method in a
-    context. It returns a tuple consisting of a copy of
-    the model's state dict in CPU RAM followed by a copy
-    of the model in VRAM. The state dict is provided to allow
-    LoRAs and other model patchers to return the model to
-    its unpatched state without expensive copy and restore
-    operations.
+    Call the LoadedModel's `model_on_device()` method in a context. It returns a tuple consisting of a copy of the
+    model's state dict in CPU RAM followed by a copy of the model in VRAM. The state dict is provided to allow LoRAs and
+    other model patchers to return the model to its unpatched state without expensive copy and restore operations.
 
     Example:
     ```
@@ -55,43 +48,42 @@ class LoadedModelWithoutConfig:
             image = vae.decode(latents)[0]
     ```
 
-    The state_dict should be treated as a read-only object and
-    never modified. Also be aware that some loadable models do
-    not have a state_dict, in which case this value will be None.
+    The state_dict should be treated as a read-only object and never modified. Also be aware that some loadable models
+    do not have a state_dict, in which case this value will be None.
     """
 
-    _locker: ModelLockerBase
+    def __init__(self, cache_record: CacheRecord, cache: ModelCache):
+        self._cache_record = cache_record
+        self._cache = cache
 
     def __enter__(self) -> AnyModel:
         """Context entry."""
-        self._locker.lock()
+        self._cache.lock(self._cache_record.key)
         return self.model
 
     def __exit__(self, *args: Any, **kwargs: Any) -> None:
         """Context exit."""
-        self._locker.unlock()
+        self._cache.unlock(self._cache_record.key)
 
     @contextmanager
     def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
         """Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
-        locked_model = self._locker.lock()
+        self._cache.lock(self._cache_record.key)
        try:
-            state_dict = self._locker.get_state_dict()
-            yield (state_dict, locked_model)
+            yield (self._cache_record.state_dict, self._cache_record.model)
         finally:
-            self._locker.unlock()
+            self._cache.unlock(self._cache_record.key)
 
     @property
     def model(self) -> AnyModel:
         """Return the model without locking it."""
-        return self._locker.model
+        return self._cache_record.model
 
 
-@dataclass
 class LoadedModel(LoadedModelWithoutConfig):
     """Context manager object that mediates transfer from RAM<->VRAM."""
 
-    config: Optional[AnyModelConfig] = None
+    def __init__(self, config: Optional[AnyModelConfig], cache_record: CacheRecord, cache: ModelCache):
+        super().__init__(cache_record=cache_record, cache=cache)
+        self.config = config
 
 
 # TODO(MM2):
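A self-contained toy that mirrors the two calling conventions described in the docstring above (this is a sketch of the API shape, not InvokeAI code):

```python
from contextlib import contextmanager

import torch


class ToyLoadedModel:
    def __init__(self, model: torch.nn.Module):
        self._model = model

    def __enter__(self) -> torch.nn.Module:  # older API: object is the context manager
        return self._model

    def __exit__(self, *args) -> None:
        pass

    @contextmanager
    def model_on_device(self):  # newer API: also yields the CPU state dict
        yield (self._model.state_dict(), self._model)


loaded = ToyLoadedModel(torch.nn.Linear(4, 4))
with loaded as m:
    _ = m(torch.randn(1, 4))
with loaded.model_on_device() as (state_dict, m):
    _ = m(torch.randn(1, 4))  # state_dict lets patchers restore weights cheaply
```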
@@ -110,7 +102,7 @@ class ModelLoaderBase(ABC):
         self,
         app_config: InvokeAIAppConfig,
         logger: Logger,
-        ram_cache: ModelCacheBase[AnyModel],
+        ram_cache: ModelCache,
     ):
         """Initialize the loader."""
         pass
@@ -138,6 +130,6 @@ class ModelLoaderBase(ABC):
 
     @property
     @abstractmethod
-    def ram_cache(self) -> ModelCacheBase[AnyModel]:
+    def ram_cache(self) -> ModelCache:
         """Return the ram cache associated with this loader."""
         pass
@@ -14,7 +14,8 @@ from invokeai.backend.model_manager import (
 )
 from invokeai.backend.model_manager.config import DiffusersConfigBase
 from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
-from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
+from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
+from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key
 from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.util.devices import TorchDevice
@@ -28,7 +29,7 @@ class ModelLoader(ModelLoaderBase):
         self,
         app_config: InvokeAIAppConfig,
         logger: Logger,
-        ram_cache: ModelCacheBase[AnyModel],
+        ram_cache: ModelCache,
     ):
         """Initialize the loader."""
         self._app_config = app_config
@@ -54,11 +55,11 @@ class ModelLoader(ModelLoaderBase):
             raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
 
         with skip_torch_weight_init():
-            locker = self._load_and_cache(model_config, submodel_type)
-        return LoadedModel(config=model_config, _locker=locker)
+            cache_record = self._load_and_cache(model_config, submodel_type)
+        return LoadedModel(config=model_config, cache_record=cache_record, cache=self._ram_cache)
 
     @property
-    def ram_cache(self) -> ModelCacheBase[AnyModel]:
+    def ram_cache(self) -> ModelCache:
         """Return the ram cache associated with this loader."""
         return self._ram_cache
 
@@ -66,10 +67,10 @@ class ModelLoader(ModelLoaderBase):
         model_base = self._app_config.models_path
         return (model_base / config.path).resolve()
 
-    def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLockerBase:
+    def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> CacheRecord:
         stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")])
         try:
-            return self._ram_cache.get(config.key, submodel_type, stats_name=stats_name)
+            return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name)
         except IndexError:
             pass
@@ -78,16 +79,11 @@ class ModelLoader(ModelLoaderBase):
         loaded_model = self._load_model(config, submodel_type)
 
         self._ram_cache.put(
-            config.key,
-            submodel_type=submodel_type,
+            get_model_cache_key(config.key, submodel_type),
             model=loaded_model,
         )
 
-        return self._ram_cache.get(
-            key=config.key,
-            submodel_type=submodel_type,
-            stats_name=stats_name,
-        )
+        return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name)
 
     def get_size_fs(
         self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
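The composite-key convention used above is small enough to restate standalone (hedged: the real helper takes a `SubModelType` enum and uses its `.value`):

```python
from typing import Optional


def get_model_cache_key(model_key: str, submodel_type: Optional[str] = None) -> str:
    return f"{model_key}:{submodel_type}" if submodel_type else model_key


assert get_model_cache_key("abc123", "vae") == "abc123:vae"  # submodel-qualified key
assert get_model_cache_key("abc123") == "abc123"             # whole-model key unchanged
```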
@@ -1,6 +0,0 @@
-"""Init file for ModelCache."""
-
-from .model_cache_base import ModelCacheBase, CacheStats  # noqa F401
-from .model_cache_default import ModelCache  # noqa F401
-
-_all__ = ["ModelCacheBase", "ModelCache", "CacheStats"]
@@ -0,0 +1,50 @@
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+
+import torch
+
+
+@dataclass
+class CacheRecord:
+    """
+    Elements of the cache:
+
+    key: Unique key for each model, same as used in the models database.
+    model: Model in memory.
+    state_dict: A read-only copy of the model's state dict in RAM. It will be
+        used as a template for creating a copy in the VRAM.
+    size: Size of the model
+    loaded: True if the model's state dict is currently in VRAM
+
+    Before a model is executed, the state_dict template is copied into VRAM,
+    and then injected into the model. When the model is finished, the VRAM
+    copy of the state dict is deleted, and the RAM version is reinjected
+    into the model.
+
+    The state_dict should be treated as a read-only attribute. Do not attempt
+    to patch or otherwise modify it. Instead, patch the copy of the state_dict
+    after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
+    context manager call `model_on_device()`.
+    """
+
+    key: str
+    model: Any
+    device: torch.device
+    state_dict: Optional[Dict[str, torch.Tensor]]
+    size: int
+    loaded: bool = False
+    _locks: int = 0
+
+    def lock(self) -> None:
+        """Lock this record."""
+        self._locks += 1
+
+    def unlock(self) -> None:
+        """Unlock this record."""
+        self._locks -= 1
+        assert self._locks >= 0
+
+    @property
+    def locked(self) -> bool:
+        """Return true if record is locked."""
+        return self._locks > 0
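The `_locks` counter gives `CacheRecord` simple reference-count semantics: several consumers may hold the same record, and it only reads as unlocked at zero. Assuming the class above is importable:

```python
import torch

record = CacheRecord(key="k", model=None, device=torch.device("cpu"), state_dict=None, size=0)
record.lock()
record.lock()
assert record.locked       # two outstanding locks
record.unlock()
record.unlock()
assert not record.locked   # back to zero; a further unlock() would trip the assert
```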
@@ -0,0 +1,15 @@
+from dataclasses import dataclass, field
+from typing import Dict
+
+
+@dataclass
+class CacheStats(object):
+    """Collect statistics on cache performance."""
+
+    hits: int = 0  # cache hits
+    misses: int = 0  # cache misses
+    high_watermark: int = 0  # amount of cache used
+    in_cache: int = 0  # number of models in cache
+    cleared: int = 0  # number of models cleared to make space
+    cache_size: int = 0  # total size of cache
+    loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
@@ -0,0 +1,93 @@
+from typing import Any
+
+import torch
+
+
+class CachedModelOnlyFullLoad:
+    """A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device.
+    Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
+    MPS memory, etc.
+    """
+
+    def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int):
+        """Initialize a CachedModelOnlyFullLoad.
+        Args:
+            model (torch.nn.Module | Any): The model to wrap. Should be on the CPU.
+            compute_device (torch.device): The compute device to move the model to.
+            total_bytes (int): The total size (in bytes) of all the weights in the model.
+        """
+        # model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases.
+        self._model = model
+        self._compute_device = compute_device
+        self._offload_device = torch.device("cpu")
+
+        # A CPU read-only copy of the model's state dict.
+        self._cpu_state_dict: dict[str, torch.Tensor] | None = None
+        if isinstance(model, torch.nn.Module):
+            self._cpu_state_dict = model.state_dict()
+
+        self._total_bytes = total_bytes
+        self._is_in_vram = False
+
+    @property
+    def model(self) -> torch.nn.Module:
+        return self._model
+
+    def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
+        """Get a read-only copy of the model's state dict in RAM."""
+        # TODO(ryand): Document this better.
+        return self._cpu_state_dict
+
+    def total_bytes(self) -> int:
+        """Get the total size (in bytes) of all the weights in the model."""
+        return self._total_bytes
+
+    def cur_vram_bytes(self) -> int:
+        """Get the size (in bytes) of the weights that are currently in VRAM."""
+        if self._is_in_vram:
+            return self._total_bytes
+        else:
+            return 0
+
+    def is_in_vram(self) -> bool:
+        """Return true if the model is currently in VRAM."""
+        return self._is_in_vram
+
+    def full_load_to_vram(self) -> int:
+        """Load all weights into VRAM (if supported by the model).
+        Returns:
+            The number of bytes loaded into VRAM.
+        """
+        if self._is_in_vram:
+            # Already in VRAM.
+            return 0
+
+        if not hasattr(self._model, "to"):
+            # Model doesn't support moving to a device.
+            return 0
+
+        if self._cpu_state_dict is not None:
+            new_state_dict: dict[str, torch.Tensor] = {}
+            for k, v in self._cpu_state_dict.items():
+                new_state_dict[k] = v.to(self._compute_device, copy=True)
+            self._model.load_state_dict(new_state_dict, assign=True)
+        self._model.to(self._compute_device)
+
+        self._is_in_vram = True
+        return self._total_bytes
+
+    def full_unload_from_vram(self) -> int:
+        """Unload all weights from VRAM.
+        Returns:
+            The number of bytes unloaded from VRAM.
+        """
+        if not self._is_in_vram:
+            # Already in RAM.
+            return 0
+
+        if self._cpu_state_dict is not None:
+            self._model.load_state_dict(self._cpu_state_dict, assign=True)
+        self._model.to(self._offload_device)
+
+        self._is_in_vram = False
+        return self._total_bytes
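A hedged usage sketch for the class above (assumes it is importable and that a CUDA device is present):

```python
import torch

model = torch.nn.Linear(16, 16)
total = sum(p.numel() * p.element_size() for p in model.parameters())
cached = CachedModelOnlyFullLoad(model, compute_device=torch.device("cuda"), total_bytes=total)

assert cached.cur_vram_bytes() == 0
loaded = cached.full_load_to_vram()    # first call returns total_bytes
assert loaded == cached.total_bytes()
cached.full_unload_from_vram()         # weights restored from the CPU state-dict copy
```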
@@ -0,0 +1,204 @@
+import torch
+
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)
+from invokeai.backend.util.calc_tensor_size import calc_tensor_size
+from invokeai.backend.util.logging import InvokeAILogger
+
+
+def set_nested_attr(obj: object, attr: str, value: object):
+    """A helper function that extends setattr() to support nested attributes.
+
+    Example:
+        set_nested_attr(model, "module.encoder.conv1.weight", new_conv1_weight)
+    """
+    attrs = attr.split(".")
+    for attr in attrs[:-1]:
+        obj = getattr(obj, attr)
+    setattr(obj, attrs[-1], value)
+
+
+class CachedModelWithPartialLoad:
+    """A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.
+
+    Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
+    MPS memory, etc.
+    """
+
+    def __init__(self, model: torch.nn.Module, compute_device: torch.device):
+        self._model = model
+        self._compute_device = compute_device
+
+        # A CPU read-only copy of the model's state dict.
+        self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
+
+        # TODO(ryand): Handle the case where the model size changes after initial load (e.g. due to dtype casting).
+        # Consider how we should handle this for both self._total_bytes and self._cur_vram_bytes.
+        self._total_bytes = sum(calc_tensor_size(p) for p in self._cpu_state_dict.values())
+        self._cur_vram_bytes: int | None = None
+
+        self._modules_that_support_autocast = self._find_modules_that_support_autocast()
+        self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast()
+
+    def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]:
+        """Find all modules that support autocasting."""
+        return {n: m for n, m in self._model.named_modules() if isinstance(m, CustomModuleMixin)}  # type: ignore
+
+    def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]:
+        keys_in_modules_that_do_not_support_autocast: set[str] = set()
+        for key in self._cpu_state_dict.keys():
+            for module_name in self._modules_that_support_autocast.keys():
+                if key.startswith(module_name):
+                    break
+            else:
+                keys_in_modules_that_do_not_support_autocast.add(key)
+        return keys_in_modules_that_do_not_support_autocast
+
+    def _move_non_persistent_buffers_to_device(self, device: torch.device):
+        """Move the non-persistent buffers to the target device. These buffers are not included in the state dict,
+        so we need to move them manually.
+        """
+        # HACK(ryand): Typically, non-persistent buffers are moved when calling module.to(device). We don't move entire
+        # modules, because we manage the devices of individual tensors using the state dict. Since non-persistent
+        # buffers are not included in the state dict, we need to handle them manually. The only way to do this is by
+        # using private torch.nn.Module attributes.
+        for module in self._model.modules():
+            for name, buffer in module.named_buffers():
+                if name in module._non_persistent_buffers_set:
+                    module._buffers[name] = buffer.to(device, copy=True)
+
+    def _set_autocast_enabled_in_all_modules(self, enabled: bool):
+        """Set autocast_enabled flag in all modules that support device autocasting."""
+        for module in self._modules_that_support_autocast.values():
+            module.set_device_autocasting_enabled(enabled)
+
+    @property
+    def model(self) -> torch.nn.Module:
+        return self._model
+
+    def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
+        """Get a read-only copy of the model's state dict in RAM."""
+        # TODO(ryand): Document this better.
+        return self._cpu_state_dict
+
+    def total_bytes(self) -> int:
+        """Get the total size (in bytes) of all the weights in the model."""
+        return self._total_bytes
+
+    def cur_vram_bytes(self) -> int:
+        """Get the size (in bytes) of the weights that are currently in VRAM."""
+        if self._cur_vram_bytes is None:
+            cur_state_dict = self._model.state_dict()
+            self._cur_vram_bytes = sum(
+                calc_tensor_size(p) for p in cur_state_dict.values() if p.device.type == self._compute_device.type
+            )
+        return self._cur_vram_bytes
+
+    def full_load_to_vram(self) -> int:
+        """Load all weights into VRAM."""
+        return self.partial_load_to_vram(self.total_bytes())
+
+    def full_unload_from_vram(self) -> int:
+        """Unload all weights from VRAM."""
+        return self.partial_unload_from_vram(self.total_bytes())
+
+    @torch.no_grad()
+    def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
+        """Load more weights into VRAM without exceeding vram_bytes_to_load.
+
+        Returns:
+            The number of bytes loaded into VRAM.
+        """
+        # TODO(ryand): Handle the case where an exception is thrown while loading or unloading weights. At the very
+        # least, we should reset self._cur_vram_bytes to None.
+
+        vram_bytes_loaded = 0
+
+        cur_state_dict = self._model.state_dict()
+
+        # First, process the keys that *must* be loaded into VRAM.
+        for key in self._keys_in_modules_that_do_not_support_autocast:
+            param = cur_state_dict[key]
+            if param.device.type == self._compute_device.type:
+                continue
+
+            param_size = calc_tensor_size(param)
+            cur_state_dict[key] = param.to(self._compute_device, copy=True)
+            vram_bytes_loaded += param_size
+
+        if vram_bytes_loaded > vram_bytes_to_load:
+            logger = InvokeAILogger.get_logger()
+            logger.warning(
+                f"Loaded {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were "
+                "requested. This is the minimum set of weights in VRAM required to run the model."
+            )
+
+        # Next, process the keys that can optionally be loaded into VRAM.
+        fully_loaded = True
+        for key, param in cur_state_dict.items():
+            if param.device.type == self._compute_device.type:
+                continue
+
+            param_size = calc_tensor_size(param)
+            if vram_bytes_loaded + param_size > vram_bytes_to_load:
+                # TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
+                # worth continuing to search for a smaller parameter that would fit?
+                fully_loaded = False
+                continue
+
+            cur_state_dict[key] = param.to(self._compute_device, copy=True)
+            vram_bytes_loaded += param_size
+
+        if vram_bytes_loaded > 0:
+            # We load the entire state dict, not just the parameters that changed, in case there are modules that
+            # override _load_from_state_dict() and do some funky stuff that requires the entire state dict.
+            # Alternatively, in the future, grouping parameters by module could probably solve this problem.
+            self._model.load_state_dict(cur_state_dict, assign=True)
+
+        if self._cur_vram_bytes is not None:
+            self._cur_vram_bytes += vram_bytes_loaded
+
+        if fully_loaded:
+            self._set_autocast_enabled_in_all_modules(False)
+            # TODO(ryand): Warn if the self.cur_vram_bytes() and self.total_bytes() are out of sync.
+        else:
+            self._set_autocast_enabled_in_all_modules(True)
+
+        # Move all non-persistent buffers to the compute device. These are a weird edge case and do not participate in
+        # the vram_bytes_loaded tracking.
+        self._move_non_persistent_buffers_to_device(self._compute_device)
+
+        return vram_bytes_loaded
+
+    @torch.no_grad()
+    def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
+        """Unload weights from VRAM until vram_bytes_to_free bytes are freed, or the entire model is unloaded.
+
+        Returns:
+            The number of bytes unloaded from VRAM.
+        """
+        vram_bytes_freed = 0
+
+        offload_device = "cpu"
+        cur_state_dict = self._model.state_dict()
+        for key, param in cur_state_dict.items():
+            if vram_bytes_freed >= vram_bytes_to_free:
+                break
+
+            if param.device.type == offload_device:
+                continue
+
+            cur_state_dict[key] = self._cpu_state_dict[key]
+            vram_bytes_freed += calc_tensor_size(param)
+
+        if vram_bytes_freed > 0:
+            self._model.load_state_dict(cur_state_dict, assign=True)
+
+        if self._cur_vram_bytes is not None:
+            self._cur_vram_bytes -= vram_bytes_freed
+
+        # We may have gone from a fully-loaded model to a partially-loaded model, so we need to reapply the custom
+        # layers.
+        self._set_autocast_enabled_in_all_modules(True)
+        return vram_bytes_freed
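A hedged sketch of the partial-load budget semantics implemented above (assumes `CachedModelWithPartialLoad` is importable and a CUDA device is available):

```python
import torch

cached = CachedModelWithPartialLoad(torch.nn.Linear(1024, 1024), torch.device("cuda"))

# Ask for at most half the weights on the GPU. Parameters inside modules that do
# not support autocast are loaded unconditionally (with a warning if that alone
# exceeds the budget); the rest are loaded only while the budget allows.
loaded = cached.partial_load_to_vram(cached.total_bytes() // 2)

# Free what we loaded; tensors are swapped back to the read-only CPU copies.
freed = cached.partial_unload_from_vram(loaded)
```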
@@ -1,11 +1,9 @@
 # Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
 # TODO: Add Stalker's proper name to copyright
 """ """
 
 import gc
 import math
 import time
-from contextlib import suppress
 from logging import Logger
 from typing import Dict, List, Optional
 
@@ -13,13 +11,11 @@ import torch
 
 from invokeai.backend.model_manager import AnyModel, SubModelType
 from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
-from invokeai.backend.model_manager.load.model_cache.model_cache_base import (
-    CacheRecord,
-    CacheStats,
-    ModelCacheBase,
-    ModelLockerBase,
+from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
+from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
+    apply_custom_layers_to_model,
 )
-from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLocker
 from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger
@@ -31,7 +27,15 @@ GB = 2**30
 MB = 2**20
 
 
-class ModelCache(ModelCacheBase[AnyModel]):
+# TODO(ryand): Where should this go? The ModelCache shouldn't be concerned with submodels.
+def get_model_cache_key(model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
+    if submodel_type:
+        return f"{model_key}:{submodel_type.value}"
+    else:
+        return model_key
+
+
+class ModelCache:
     """A cache for managing models in memory.
 
     The cache is based on two levels of model storage:
@@ -70,7 +74,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         max_vram_cache_size: float,
         execution_device: torch.device = torch.device("cuda"),
         storage_device: torch.device = torch.device("cpu"),
-        precision: torch.dtype = torch.float16,
         lazy_offloading: bool = True,
         log_memory_usage: bool = False,
         logger: Optional[Logger] = None,
@@ -82,7 +85,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         :param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
         :param execution_device: Torch device to load active model into [torch.device('cuda')]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
-        :param precision: Precision for loaded models [torch.float16]
         :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
         :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
             operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
@@ -100,29 +102,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._log_memory_usage = log_memory_usage
         self._stats: Optional[CacheStats] = None
 
-        self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
+        self._cached_models: Dict[str, CacheRecord] = {}
         self._cache_stack: List[str] = []
 
-    @property
-    def logger(self) -> Logger:
-        """Return the logger used by the cache."""
-        return self._logger
-
-    @property
-    def lazy_offloading(self) -> bool:
-        """Return true if the cache is configured to lazily offload models in VRAM."""
-        return self._lazy_offloading
-
-    @property
-    def storage_device(self) -> torch.device:
-        """Return the storage device (e.g. "CPU" for RAM)."""
-        return self._storage_device
-
-    @property
-    def execution_device(self) -> torch.device:
-        """Return the execution device (e.g. "cuda" for VRAM)."""
-        return self._execution_device
-
     @property
     def max_cache_size(self) -> float:
         """Return the cap on cache size."""
@@ -153,49 +135,39 @@ class ModelCache(ModelCacheBase[AnyModel]):
         """Set the CacheStats object for collecting cache statistics."""
         self._stats = stats
 
-    def cache_size(self) -> int:
-        """Get the total size of the models currently cached."""
-        total = 0
-        for cache_record in self._cached_models.values():
-            total += cache_record.size
-        return total
-
     def put(
         self,
         key: str,
         model: AnyModel,
-        submodel_type: Optional[SubModelType] = None,
     ) -> None:
-        """Store model under key and optional submodel_type."""
-        key = self._make_cache_key(key, submodel_type)
+        """Insert model into the cache."""
         if key in self._cached_models:
             return
-        size = calc_model_size_by_data(self.logger, model)
+        size = calc_model_size_by_data(self._logger, model)
         self.make_room(size)
 
-        running_on_cpu = self.execution_device == torch.device("cpu")
+        # Inject custom modules into the model.
+        if isinstance(model, torch.nn.Module):
+            apply_custom_layers_to_model(model)
+
+        running_on_cpu = self._execution_device == torch.device("cpu")
         state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
-        cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
+        cache_record = CacheRecord(key=key, model=model, device=self._storage_device, state_dict=state_dict, size=size)
         self._cached_models[key] = cache_record
         self._cache_stack.append(key)
 
     def get(
         self,
         key: str,
-        submodel_type: Optional[SubModelType] = None,
         stats_name: Optional[str] = None,
-    ) -> ModelLockerBase:
-        """
-        Retrieve model using key and optional submodel_type.
+    ) -> CacheRecord:
+        """Retrieve a model from the cache.
 
-        :param key: Opaque model key
-        :param submodel_type: Type of the submodel to fetch
-        :param stats_name: A human-readable id for the model for the purposes of
-            stats reporting.
+        :param key: Model key
+        :param stats_name: A human-readable id for the model for the purposes of stats reporting.
 
-        This may raise an IndexError if the model is not in the cache.
+        Raises IndexError if the model is not in the cache.
         """
-        key = self._make_cache_key(key, submodel_type)
         if key in self._cached_models:
             if self.stats:
                 self.stats.hits += 1
@@ -210,20 +182,52 @@ class ModelCache(ModelCacheBase[AnyModel]):
         if self.stats:
             stats_name = stats_name or key
             self.stats.cache_size = int(self._max_cache_size * GB)
-            self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
+            self.stats.high_watermark = max(self.stats.high_watermark, self._get_cache_size())
             self.stats.in_cache = len(self._cached_models)
             self.stats.loaded_model_sizes[stats_name] = max(
                 self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
             )
 
         # this moves the entry to the top (right end) of the stack
-        with suppress(Exception):
-            self._cache_stack.remove(key)
+        self._cache_stack = [k for k in self._cache_stack if k != key]
         self._cache_stack.append(key)
-        return ModelLocker(
-            cache=self,
-            cache_entry=cache_entry,
-        )
+
+        return cache_entry
+
+    def lock(self, key: str) -> None:
+        """Lock a model for use and move it into VRAM."""
+        cache_entry = self._cached_models[key]
+        cache_entry.lock()
+
+        try:
+            if self._lazy_offloading:
+                self._offload_unlocked_models(cache_entry.size)
+            self._move_model_to_device(cache_entry, self._execution_device)
+            cache_entry.loaded = True
+            self._logger.debug(f"Locking {cache_entry.key} in {self._execution_device}")
+            self._print_cuda_stats()
+        except torch.cuda.OutOfMemoryError:
+            self._logger.warning("Insufficient GPU memory to load model. Aborting")
+            cache_entry.unlock()
+            raise
+        except Exception:
+            cache_entry.unlock()
+            raise
+
+    def unlock(self, key: str) -> None:
+        """Unlock a model."""
+        cache_entry = self._cached_models[key]
+        cache_entry.unlock()
+        if not self._lazy_offloading:
+            self._offload_unlocked_models(0)
+        self._print_cuda_stats()
+
+    def _get_cache_size(self) -> int:
+        """Get the total size of the models currently cached."""
+        total = 0
+        for cache_record in self._cached_models.values():
+            total += cache_record.size
+        return total
 
     def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
         if self._log_memory_usage:
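Putting the new key-based API together, a hedged sketch of the caller's flow (`my_model` and `run_inference` are placeholders, not real InvokeAI symbols):

```python
cache = ModelCache(max_cache_size=4.0, max_vram_cache_size=2.0)

key = get_model_cache_key("model-key")   # no submodel -> the key itself
cache.put(key, model=my_model)           # inserts the model, injecting custom layers
record = cache.get(key=key, stats_name="main:my-model")

cache.lock(key)                          # moves weights onto the execution device
try:
    run_inference(record.model)          # placeholder for actual use of the model
finally:
    cache.unlock(key)                    # lets the cache offload it when space is needed
```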
@@ -236,30 +240,30 @@ class ModelCache(ModelCacheBase[AnyModel]):
         else:
             return model_key
 
-    def offload_unlocked_models(self, size_required: int) -> None:
+    def _offload_unlocked_models(self, size_required: int) -> None:
         """Offload models from the execution_device to make room for size_required.
 
         :param size_required: The amount of space to clear in the execution_device cache, in bytes.
         """
         reserved = self._max_vram_cache_size * GB
         vram_in_use = torch.cuda.memory_allocated() + size_required
-        self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
+        self._logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
         for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
             if vram_in_use <= reserved:
                 break
             if not cache_entry.loaded:
                 continue
             if not cache_entry.locked:
-                self.move_model_to_device(cache_entry, self.storage_device)
+                self._move_model_to_device(cache_entry, self._storage_device)
                 cache_entry.loaded = False
                 vram_in_use = torch.cuda.memory_allocated() + size_required
-                self.logger.debug(
+                self._logger.debug(
                     f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
                 )
 
         TorchDevice.empty_cache()
 
-    def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
+    def _move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None:
         """Move model into the indicated device.
 
         :param cache_entry: The CacheRecord for the model
@@ -267,7 +271,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
 
         May raise a torch.cuda.OutOfMemoryError
         """
-        self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
+        self._logger.debug(f"Called to move {cache_entry.key} to {target_device}")
         source_device = cache_entry.device
 
         # Note: We compare device types only so that 'cuda' == 'cuda:0'.
@@ -294,7 +298,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         try:
             if cache_entry.state_dict is not None:
                 assert hasattr(cache_entry.model, "load_state_dict")
-                if target_device == self.storage_device:
+                if target_device == self._storage_device:
                     cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
                 else:
                     new_dict: Dict[str, torch.Tensor] = {}
@@ -309,7 +313,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
 
         snapshot_after = self._capture_memory_snapshot()
         end_model_to_time = time.time()
-        self.logger.debug(
+        self._logger.debug(
             f"Moved model '{cache_entry.key}' from {source_device} to"
             f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
             f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
@@ -331,7 +335,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
             rel_tol=0.1,
             abs_tol=10 * MB,
         ):
-            self.logger.debug(
+            self._logger.debug(
                 f"Moving model '{cache_entry.key}' from {source_device} to"
                 f" {target_device} caused an unexpected change in VRAM usage. The model's"
                 " estimated size may be incorrect. Estimated model size:"
@@ -339,24 +343,24 @@ class ModelCache(ModelCacheBase[AnyModel]):
                 f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
             )
 
-    def print_cuda_stats(self) -> None:
+    def _print_cuda_stats(self) -> None:
         """Log CUDA diagnostics."""
         vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
-        ram = "%4.2fG" % (self.cache_size() / GB)
+        ram = "%4.2fG" % (self._get_cache_size() / GB)
 
         in_ram_models = 0
         in_vram_models = 0
         locked_in_vram_models = 0
         for cache_record in self._cached_models.values():
             if hasattr(cache_record.model, "device"):
-                if cache_record.model.device == self.storage_device:
+                if cache_record.model.device == self._storage_device:
                     in_ram_models += 1
                 else:
                     in_vram_models += 1
                 if cache_record.locked:
                     locked_in_vram_models += 1
 
-        self.logger.debug(
+        self._logger.debug(
             f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
             f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
         )
@@ -369,16 +373,16 @@ class ModelCache(ModelCacheBase[AnyModel]):
         garbage-collected.
         """
         bytes_needed = size
-        maximum_size = self.max_cache_size * GB  # stored in GB, convert to bytes
-        current_size = self.cache_size()
+        maximum_size = self._max_cache_size * GB  # stored in GB, convert to bytes
+        current_size = self._get_cache_size()
 
         if current_size + bytes_needed > maximum_size:
-            self.logger.debug(
+            self._logger.debug(
                 f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
                 f" {(bytes_needed/GB):.2f} GB"
             )
 
-        self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
+        self._logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
 
         pos = 0
         models_cleared = 0
@@ -386,12 +390,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
             model_key = self._cache_stack[pos]
             cache_entry = self._cached_models[model_key]
             device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
-            self.logger.debug(
+            self._logger.debug(
                 f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
             )
 
             if not cache_entry.locked:
-                self.logger.debug(
+                self._logger.debug(
                     f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
                 )
                 current_size -= cache_entry.size
@@ -419,8 +423,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
         gc.collect()
 
         TorchDevice.empty_cache()
-        self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
+        self._logger.debug(f"After making room: cached_models={len(self._cached_models)}")
 
-    def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
+    def _delete_cache_entry(self, cache_entry: CacheRecord) -> None:
         self._cache_stack.remove(cache_entry.key)
         del self._cached_models[cache_entry.key]
@@ -1,221 +0,0 @@
-# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
-# TODO: Add Stalker's proper name to copyright
-"""
-Manage a RAM cache of diffusion/transformer models for fast switching.
-They are moved between GPU VRAM and CPU RAM as necessary. If the cache
-grows larger than a preset maximum, then the least recently used
-model will be cleared and (re)loaded from disk when next needed.
-"""
-
-from abc import ABC, abstractmethod
-from dataclasses import dataclass, field
-from logging import Logger
-from typing import Dict, Generic, Optional, TypeVar
-
-import torch
-
-from invokeai.backend.model_manager.config import AnyModel, SubModelType
-
-
-class ModelLockerBase(ABC):
-    """Base class for the model locker used by the loader."""
-
-    @abstractmethod
-    def lock(self) -> AnyModel:
-        """Lock the contained model and move it into VRAM."""
-        pass
-
-    @abstractmethod
-    def unlock(self) -> None:
-        """Unlock the contained model, and remove it from VRAM."""
-        pass
-
-    @abstractmethod
-    def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
-        """Return the state dict (if any) for the cached model."""
-        pass
-
-    @property
-    @abstractmethod
-    def model(self) -> AnyModel:
-        """Return the model."""
-        pass
-
-
-T = TypeVar("T")
-
-
-@dataclass
-class CacheRecord(Generic[T]):
-    """
-    Elements of the cache:
-
-    key: Unique key for each model, same as used in the models database.
-    model: Model in memory.
-    state_dict: A read-only copy of the model's state dict in RAM. It will be
-        used as a template for creating a copy in the VRAM.
-    size: Size of the model
-    loaded: True if the model's state dict is currently in VRAM
-
-    Before a model is executed, the state_dict template is copied into VRAM,
-    and then injected into the model. When the model is finished, the VRAM
-    copy of the state dict is deleted, and the RAM version is reinjected
-    into the model.
-
-    The state_dict should be treated as a read-only attribute. Do not attempt
-    to patch or otherwise modify it. Instead, patch the copy of the state_dict
-    after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
-    context manager call `model_on_device()`.
-    """
-
-    key: str
-    model: T
-    device: torch.device
-    state_dict: Optional[Dict[str, torch.Tensor]]
-    size: int
-    loaded: bool = False
-    _locks: int = 0
-
-    def lock(self) -> None:
-        """Lock this record."""
-        self._locks += 1
-
-    def unlock(self) -> None:
-        """Unlock this record."""
-        self._locks -= 1
-        assert self._locks >= 0
-
-    @property
-    def locked(self) -> bool:
-        """Return true if record is locked."""
-        return self._locks > 0
-
-
-@dataclass
-class CacheStats(object):
-    """Collect statistics on cache performance."""
-
-    hits: int = 0  # cache hits
-    misses: int = 0  # cache misses
-    high_watermark: int = 0  # amount of cache used
-    in_cache: int = 0  # number of models in cache
-    cleared: int = 0  # number of models cleared to make space
-    cache_size: int = 0  # total size of cache
-    loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
-
-
-class ModelCacheBase(ABC, Generic[T]):
-    """Virtual base class for RAM model cache."""
-
-    @property
-    @abstractmethod
-    def storage_device(self) -> torch.device:
-        """Return the storage device (e.g. "CPU" for RAM)."""
-        pass
-
-    @property
-    @abstractmethod
-    def execution_device(self) -> torch.device:
-        """Return the execution device (e.g. "cuda" for VRAM)."""
-        pass
-
-    @property
-    @abstractmethod
-    def lazy_offloading(self) -> bool:
-        """Return true if the cache is configured to lazily offload models in VRAM."""
-        pass
-
-    @property
-    @abstractmethod
-    def max_cache_size(self) -> float:
-        """Return the maximum size the RAM cache can grow to."""
-        pass
-
-    @max_cache_size.setter
-    @abstractmethod
-    def max_cache_size(self, value: float) -> None:
-        """Set the cap on vram cache size."""
-
-    @property
-    @abstractmethod
-    def max_vram_cache_size(self) -> float:
-        """Return the maximum size the VRAM cache can grow to."""
-        pass
-
-    @max_vram_cache_size.setter
-    @abstractmethod
-    def max_vram_cache_size(self, value: float) -> float:
-        """Set the maximum size the VRAM cache can grow to."""
-        pass
-
-    @abstractmethod
-    def offload_unlocked_models(self, size_required: int) -> None:
-        """Offload from VRAM any models not actively in use."""
-        pass
-
-    @abstractmethod
-    def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
-        """Move model into the indicated device."""
-        pass
-
-    @property
-    @abstractmethod
-    def stats(self) -> Optional[CacheStats]:
-        """Return collected CacheStats object."""
-        pass
-
-    @stats.setter
-    @abstractmethod
-    def stats(self, stats: CacheStats) -> None:
-        """Set the CacheStats object for collecting cache statistics."""
-        pass
-
-    @property
-    @abstractmethod
-    def logger(self) -> Logger:
-        """Return the logger used by the cache."""
-        pass
-
-    @abstractmethod
-    def make_room(self, size: int) -> None:
-        """Make enough room in the cache to accommodate a new model of indicated size."""
-        pass
-
-    @abstractmethod
-    def put(
-        self,
-        key: str,
-        model: T,
-        submodel_type: Optional[SubModelType] = None,
-    ) -> None:
-        """Store model under key and optional submodel_type."""
-        pass
-
-    @abstractmethod
-    def get(
-        self,
-        key: str,
-        submodel_type: Optional[SubModelType] = None,
-        stats_name: Optional[str] = None,
-    ) -> ModelLockerBase:
-        """
-        Retrieve model using key and optional submodel_type.
-
-        :param key: Opaque model key
-        :param submodel_type: Type of the submodel to fetch
-        :param stats_name: A human-readable id for the model for the purposes of
-            stats reporting.
-
-        This may raise an IndexError if the model is not in the cache.
-        """
-        pass
-
-    @abstractmethod
-    def cache_size(self) -> int:
-        """Get the total size of the models currently cached."""
-        pass
-
-    @abstractmethod
-    def print_cuda_stats(self) -> None:
-        """Log debugging information on CUDA usage."""
-        pass
@@ -1,64 +0,0 @@
-"""
-Base class and implementation of a class that moves models in and out of VRAM.
-"""
-
-from typing import Dict, Optional
-
-import torch
-
-from invokeai.backend.model_manager import AnyModel
-from invokeai.backend.model_manager.load.model_cache.model_cache_base import (
-    CacheRecord,
-    ModelCacheBase,
-    ModelLockerBase,
-)
-
-
-class ModelLocker(ModelLockerBase):
-    """Internal class that mediates movement in and out of GPU."""
-
-    def __init__(self, cache: ModelCacheBase[AnyModel], cache_entry: CacheRecord[AnyModel]):
-        """
-        Initialize the model locker.
-
-        :param cache: The ModelCache object
-        :param cache_entry: The entry in the model cache
-        """
-        self._cache = cache
-        self._cache_entry = cache_entry
-
-    @property
-    def model(self) -> AnyModel:
-        """Return the model without moving it around."""
-        return self._cache_entry.model
-
-    def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
-        """Return the state dict (if any) for the cached model."""
-        return self._cache_entry.state_dict
-
-    def lock(self) -> AnyModel:
-        """Move the model into the execution device (GPU) and lock it."""
-        self._cache_entry.lock()
-        try:
-            if self._cache.lazy_offloading:
-                self._cache.offload_unlocked_models(self._cache_entry.size)
-            self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
-            self._cache_entry.loaded = True
-            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
-            self._cache.print_cuda_stats()
-        except torch.cuda.OutOfMemoryError:
-            self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
-            self._cache_entry.unlock()
-            raise
-        except Exception:
-            self._cache_entry.unlock()
-            raise
-
-        return self.model
-
-    def unlock(self) -> None:
-        """Call upon exit from context."""
-        self._cache_entry.unlock()
-        if not self._cache.lazy_offloading:
-            self._cache.offload_unlocked_models(0)
-        self._cache.print_cuda_stats()
@@ -0,0 +1,15 @@
+from typing import TypeVar
+
+import torch
+
+T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)
+
+
+def cast_to_device(t: T, to_device: torch.device) -> T:
+    """Helper function to cast an optional tensor to a target device."""
+    if t is None:
+        return t
+
+    if t.device.type != to_device.type:
+        return t.to(to_device)
+    return t
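Note that the comparison is on device *type*, so a tensor already on `cuda:0` is not copied when the target is plain `cuda`. A quick check of the two edge cases, assuming the function above is importable:

```python
import torch

t = torch.zeros(2)
assert cast_to_device(t, torch.device("cpu")) is t         # same device type: no copy
assert cast_to_device(None, torch.device("cpu")) is None   # None passes through
```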
@@ -0,0 +1,8 @@
+
+This directory contains custom implementations of common torch.nn.Module classes that add support for:
+- Streaming weights to the execution device
+- Applying sidecar patches at execution time (e.g. sidecar LoRA layers)
+
+Each custom class sub-classes the original module type that it is replacing, so the following properties are preserved:
+- `isinstance(m, torch.nn.OriginalModule)` should still work.
+- Patching the weights directly (e.g. for LoRA) should still work. (Of course, this is not possible for quantized layers, hence the sidecar support.)
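Assuming the classes defined below are importable from this package, the two preserved properties can be checked directly:

```python
import torch

conv = CustomConv1d(4, 4, kernel_size=1)
assert isinstance(conv, torch.nn.Conv1d)  # original module type is preserved
with torch.no_grad():
    conv.weight += 1.0                    # direct weight patching still works
```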
@@ -0,0 +1,43 @@
+import torch
+
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
+    add_nullable_tensors,
+)
+
+
+class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin):
+    def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
+        weight = cast_to_device(self.weight, input.device)
+        bias = cast_to_device(self.bias, input.device)
+
+        # Prepare the original parameters for the patch aggregation.
+        orig_params = {"weight": weight, "bias": bias}
+        # Filter out None values.
+        orig_params = {k: v for k, v in orig_params.items() if v is not None}
+
+        aggregated_param_residuals = self._aggregate_patch_parameters(
+            patches_and_weights=self._patches_and_weights,
+            orig_params=orig_params,
+            device=input.device,
+        )
+
+        weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None))
+        bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None))
+        return self._conv_forward(input, weight, bias)
+
+    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
+        weight = cast_to_device(self.weight, input.device)
+        bias = cast_to_device(self.bias, input.device)
+        return self._conv_forward(input, weight, bias)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if len(self._patches_and_weights) > 0:
+            return self._autocast_forward_with_patches(input)
+        elif self._device_autocasting_enabled:
+            return self._autocast_forward(input)
+        else:
+            return super().forward(input)
@@ -0,0 +1,43 @@
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
    CustomModuleMixin,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
    add_nullable_tensors,
)


class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin):
    def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)

        # Prepare the original parameters for the patch aggregation.
        orig_params = {"weight": weight, "bias": bias}
        # Filter out None values.
        orig_params = {k: v for k, v in orig_params.items() if v is not None}

        aggregated_param_residuals = self._aggregate_patch_parameters(
            patches_and_weights=self._patches_and_weights,
            orig_params=orig_params,
            device=input.device,
        )

        weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None))
        bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None))
        return self._conv_forward(input, weight, bias)

    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return self._conv_forward(input, weight, bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if len(self._patches_and_weights) > 0:
            return self._autocast_forward_with_patches(input)
        elif self._device_autocasting_enabled:
            return self._autocast_forward(input)
        else:
            return super().forward(input)
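To illustrate the forward() dispatch above, a hedged sketch of wrapping a stock Conv2d with the custom class (using wrap_custom_layer from the torch_module_autocast module later in this diff; the CUDA device is an assumption):

    import torch

    conv = torch.nn.Conv2d(3, 8, kernel_size=3)     # weights initialized on the CPU
    custom = wrap_custom_layer(conv, CustomConv2d)  # class swap with a shared __dict__
    custom.set_device_autocasting_enabled(True)

    x = torch.randn(1, 3, 32, 32, device="cuda")
    y = custom(x)  # weight/bias are streamed to CUDA inside _autocast_forward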
@@ -0,0 +1,29 @@
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
    CustomModuleMixin,
)


class CustomEmbedding(torch.nn.Embedding, CustomModuleMixin):
    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        return torch.nn.functional.embedding(
            input,
            weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if len(self._patches_and_weights) > 0:
            raise RuntimeError("Embedding layers do not support patches")

        if self._device_autocasting_enabled:
            return self._autocast_forward(input)
        else:
            return super().forward(input)
@@ -0,0 +1,36 @@
import torch

from invokeai.backend.flux.modules.layers import RMSNorm
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
    CustomModuleMixin,
)
from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer


class CustomFluxRMSNorm(RMSNorm, CustomModuleMixin):
    def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
        # Currently, CustomFluxRMSNorm layers only support patching with a single SetParameterLayer.
        assert len(self._patches_and_weights) == 1
        patch, _patch_weight = self._patches_and_weights[0]
        assert isinstance(patch, SetParameterLayer)
        assert patch.param_name == "scale"

        scale = cast_to_device(patch.weight, x.device)

        # Apply the patch.
        # NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should
        # be handled.
        return torch.nn.functional.rms_norm(x, scale.shape, scale, eps=1e-6)

    def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
        scale = cast_to_device(self.scale, x.device)
        return torch.nn.functional.rms_norm(x, scale.shape, scale, eps=1e-6)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if len(self._patches_and_weights) > 0:
            return self._autocast_forward_with_patches(x)
        elif self._device_autocasting_enabled:
            return self._autocast_forward(x)
        else:
            return super().forward(x)
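For reference, the rms_norm call above normalizes over the dimensions covered by scale.shape. A reference restatement of the math (a sketch assuming scale spans the last dimension; not the library implementation):

    import torch

    def rms_norm_reference(x: torch.Tensor, scale: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # x / sqrt(mean(x^2) + eps) over the last dimension, then apply the learned scale.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        return (x / rms) * scale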
@@ -0,0 +1,22 @@
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
    CustomModuleMixin,
)


class CustomGroupNorm(torch.nn.GroupNorm, CustomModuleMixin):
    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if len(self._patches_and_weights) > 0:
            raise RuntimeError("GroupNorm layers do not support patches")

        if self._device_autocasting_enabled:
            return self._autocast_forward(input)
        else:
            return super().forward(input)
@@ -0,0 +1,44 @@
import bitsandbytes as bnb
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import (
    autocast_linear_forward_sidecar_patches,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
    CustomModuleMixin,
)
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt


class CustomInvokeLinear8bitLt(InvokeLinear8bitLt, CustomModuleMixin):
    def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
        return autocast_linear_forward_sidecar_patches(self, x, self._patches_and_weights)

    def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
        matmul_state = bnb.MatmulLtState()
        matmul_state.threshold = self.state.threshold
        matmul_state.has_fp16_weights = self.state.has_fp16_weights
        matmul_state.use_pool = self.state.use_pool
        matmul_state.is_training = self.training
        # The underlying InvokeInt8Params weight must already be quantized.
        assert self.weight.CB is not None
        matmul_state.CB = cast_to_device(self.weight.CB, x.device)
        matmul_state.SCB = cast_to_device(self.weight.SCB, x.device)

        # Weights are cast automatically as Int8Params, but the bias has to be cast manually.
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        # NOTE(ryand): The second parameter should not be needed at all given our expected inference configuration, but
        # its dtype field must be accessible, even though it's not used. We pass in self.weight even though it could be
        # on the wrong device.
        return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if len(self._patches_and_weights) > 0:
            return self._autocast_forward_with_patches(x)
        elif self._device_autocasting_enabled:
            return self._autocast_forward(x)
        else:
            return super().forward(x)
@@ -0,0 +1,62 @@
import copy

import bitsandbytes as bnb
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import (
    autocast_linear_forward_sidecar_patches,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
    CustomModuleMixin,
)
from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4


class CustomInvokeLinearNF4(InvokeLinearNF4, CustomModuleMixin):
    def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
        return autocast_linear_forward_sidecar_patches(self, x, self._patches_and_weights)

    def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
        bnb.nn.modules.fix_4bit_weight_quant_state_from_module(self)

        # Weights are cast automatically as Int8Params, but the bias has to be cast manually.
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        if not self.compute_type_is_set:
            self.set_compute_type(x)
            self.compute_type_is_set = True

        inp_dtype = x.dtype
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)

        bias = None if self.bias is None else self.bias.to(self.compute_dtype)

        # HACK(ryand): Casting self.weight to the device also casts the self.weight.quant_state in-place (i.e. it
        # does not follow the tensor semantics of returning a new copy when converting to a different device). This
        # means that quant_state elements that started on the CPU would be left on the GPU, which we don't want. To
        # avoid this side effect we make a shallow copy of the original quant_state so that we can restore it. Fixing
        # this properly would require more invasive changes to the bitsandbytes library.

        # Make a shallow copy of the quant_state so that we can undo the in-place modification that occurs when casting
        # to a new device.
        old_quant_state = copy.copy(self.weight.quant_state)
        weight = cast_to_device(self.weight, x.device)
        self.weight.quant_state = old_quant_state

        # For some reason, the quant_state.to(...) implementation fails to cast the quant_state.code field. We do this
        # manually here.
        weight.quant_state.code = cast_to_device(weight.quant_state.code, x.device)

        bias = cast_to_device(self.bias, x.device)
        return bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state).to(inp_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if len(self._patches_and_weights) > 0:
            return self._autocast_forward_with_patches(x)
        elif self._device_autocasting_enabled:
            return self._autocast_forward(x)
        else:
            return super().forward(x)
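The shallow-copy/restore trick above can be shown in isolation. QuantState here is a toy stand-in for the bitsandbytes object, not the real class:

    import copy

    class QuantState:
        def __init__(self) -> None:
            self.device = "cpu"

    class Weight:
        def __init__(self) -> None:
            self.quant_state = QuantState()

    w = Weight()
    saved = copy.copy(w.quant_state)  # copies the attribute dict, cheap
    w.quant_state.device = "cuda"     # in-place side effect, like quant_state.to(...)
    w.quant_state = saved             # restore the pre-cast field values
    assert w.quant_state.device == "cpu"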
@@ -0,0 +1,106 @@
import copy

import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
    CustomModuleMixin,
)
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer


def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
    """An optimized implementation of the residual calculation for a sidecar linear LoRALayer."""
    x = torch.nn.functional.linear(input, lora_layer.down)
    if lora_layer.mid is not None:
        x = torch.nn.functional.linear(x, lora_layer.mid)
    x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias)
    x *= lora_weight * lora_layer.scale()
    return x


def concatenated_lora_forward(
    input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
) -> torch.Tensor:
    """An optimized implementation of the residual calculation for a sidecar ConcatenatedLoRALayer."""
    x_chunks: list[torch.Tensor] = []
    for lora_layer in concatenated_lora_layer.lora_layers:
        x_chunk = torch.nn.functional.linear(input, lora_layer.down)
        if lora_layer.mid is not None:
            x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
        x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
        x_chunk *= lora_weight * lora_layer.scale()
        x_chunks.append(x_chunk)

    # TODO(ryand): Generalize to support concat_axis != 0.
    assert concatenated_lora_layer.concat_axis == 0
    x = torch.cat(x_chunks, dim=-1)
    return x


def autocast_linear_forward_sidecar_patches(
    orig_module: torch.nn.Linear, input: torch.Tensor, patches_and_weights: list[tuple[BaseLayerPatch, float]]
) -> torch.Tensor:
    """Run a linear layer with sidecar patches applied. Compatible with both quantized and non-quantized Linear
    layers.
    """
    # First, apply the original linear layer.
    # NOTE: We slice the input to match the original weight shape in order to work with FluxControlLoRAs, which
    # change the linear layer's in_features.
    orig_input = input
    input = orig_input[..., : orig_module.in_features]
    output = orig_module._autocast_forward(input)

    # Then, apply layers for which we have optimized implementations.
    unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = []
    for patch, patch_weight in patches_and_weights:
        # Shallow copy the patch so that we can cast it to the target device without modifying the original patch.
        patch = copy.copy(patch)
        patch.to(input.device)

        if isinstance(patch, FluxControlLoRALayer):
            # Note that we use the original input here, not the sliced input.
            output += linear_lora_forward(orig_input, patch, patch_weight)
        elif isinstance(patch, LoRALayer):
            output += linear_lora_forward(input, patch, patch_weight)
        elif isinstance(patch, ConcatenatedLoRALayer):
            output += concatenated_lora_forward(input, patch, patch_weight)
        else:
            unprocessed_patches_and_weights.append((patch, patch_weight))

    # Finally, apply any remaining patches.
    if len(unprocessed_patches_and_weights) > 0:
        # Prepare the original parameters for the patch aggregation.
        orig_params = {"weight": orig_module.weight, "bias": orig_module.bias}
        # Filter out None values.
        orig_params = {k: v for k, v in orig_params.items() if v is not None}

        aggregated_param_residuals = orig_module._aggregate_patch_parameters(
            unprocessed_patches_and_weights, orig_params=orig_params, device=input.device
        )
        output += torch.nn.functional.linear(
            input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
        )

    return output


class CustomLinear(torch.nn.Linear, CustomModuleMixin):
    def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
        return autocast_linear_forward_sidecar_patches(self, input, self._patches_and_weights)

    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return torch.nn.functional.linear(input, weight, bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if len(self._patches_and_weights) > 0:
            return self._autocast_forward_with_patches(input)
        elif self._device_autocasting_enabled:
            return self._autocast_forward(input)
        else:
            return super().forward(input)
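As a sanity check on linear_lora_forward above: chaining the down/up projections is algebraically the same as applying the composed low-rank matrix, which is why the residual can be computed without materializing up @ down. A self-contained check:

    import torch

    x = torch.randn(4, 16)
    down = torch.randn(8, 16)  # rank-8 down-projection
    up = torch.randn(32, 8)    # up-projection

    chained = torch.nn.functional.linear(torch.nn.functional.linear(x, down), up)
    composed = torch.nn.functional.linear(x, up @ down)
    assert torch.allclose(chained, composed, atol=1e-4)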
@@ -0,0 +1,63 @@
import copy

import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch


class CustomModuleMixin:
    """A mixin class for custom modules that enables device autocasting of module parameters."""

    def __init__(self):
        self._device_autocasting_enabled = False
        self._patches_and_weights: list[tuple[BaseLayerPatch, float]] = []

    def set_device_autocasting_enabled(self, enabled: bool):
        """Pass True to enable autocasting of module parameters to the same device as the input tensor. Pass False to
        disable autocasting, which results in slightly faster execution speed when we know that device autocasting is
        not needed.
        """
        self._device_autocasting_enabled = enabled

    def is_device_autocasting_enabled(self) -> bool:
        """Check if device autocasting is enabled for the module."""
        return self._device_autocasting_enabled

    def add_patch(self, patch: BaseLayerPatch, patch_weight: float):
        """Add a patch to the module."""
        self._patches_and_weights.append((patch, patch_weight))

    def clear_patches(self):
        """Clear all patches from the module."""
        self._patches_and_weights = []

    def get_num_patches(self) -> int:
        """Get the number of patches in the module."""
        return len(self._patches_and_weights)

    def _aggregate_patch_parameters(
        self,
        patches_and_weights: list[tuple[BaseLayerPatch, float]],
        orig_params: dict[str, torch.Tensor],
        device: torch.device | None = None,
    ):
        """Helper function that aggregates the parameters from all patches into a single dict."""
        params: dict[str, torch.Tensor] = {}

        for patch, patch_weight in patches_and_weights:
            if device is not None:
                # Shallow copy the patch so that we can cast it to the target device without modifying the original patch.
                patch = copy.copy(patch)
                patch.to(device)

            # TODO(ryand): `self` could be a quantized module. Depending on what the patch is doing with the original
            # parameters, this might fail or return incorrect results.
            layer_params = patch.get_parameters(orig_params, weight=patch_weight)

            for param_name, param_weight in layer_params.items():
                if param_name not in params:
                    params[param_name] = param_weight
                else:
                    params[param_name] += param_weight

        return params
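Restating the accumulation loop in _aggregate_patch_parameters in isolation, with a toy patch type (FakePatch is illustrative only, not part of the codebase):

    import torch

    class FakePatch:
        def __init__(self, residual: torch.Tensor) -> None:
            self._residual = residual

        def get_parameters(self, orig_params: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
            return {"weight": self._residual * weight}

    r1, r2 = torch.ones(2, 2), torch.full((2, 2), 4.0)
    params: dict[str, torch.Tensor] = {}
    for patch, w in [(FakePatch(r1), 1.0), (FakePatch(r2), 0.5)]:
        for name, res in patch.get_parameters({}, weight=w).items():
            params[name] = res if name not in params else params[name] + res
    assert torch.equal(params["weight"], r1 + 0.5 * r2)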
@@ -0,0 +1,30 @@
from typing import overload

import torch


@overload
def add_nullable_tensors(a: None, b: None) -> None: ...


@overload
def add_nullable_tensors(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: ...


@overload
def add_nullable_tensors(a: torch.Tensor, b: None) -> torch.Tensor: ...


@overload
def add_nullable_tensors(a: None, b: torch.Tensor) -> torch.Tensor: ...


def add_nullable_tensors(a: torch.Tensor | None, b: torch.Tensor | None) -> torch.Tensor | None:
    if a is None and b is None:
        return None
    elif a is None:
        return b
    elif b is None:
        return a
    else:
        return a + b
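Usage of the helper above; None behaves as the additive identity:

    import torch

    a = torch.ones(2)
    assert add_nullable_tensors(a, None) is a
    assert add_nullable_tensors(None, None) is None
    assert torch.equal(add_nullable_tensors(a, a), a * 2)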
@@ -0,0 +1,105 @@
from typing import TypeVar

import torch

from invokeai.backend.flux.modules.layers import RMSNorm
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv1d import (
    CustomConv1d,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv2d import (
    CustomConv2d,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_embedding import (
    CustomEmbedding,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_flux_rms_norm import (
    CustomFluxRMSNorm,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_group_norm import (
    CustomGroupNorm,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import (
    CustomLinear,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
    CustomModuleMixin,
)

AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]] = {
    torch.nn.Linear: CustomLinear,
    torch.nn.Conv1d: CustomConv1d,
    torch.nn.Conv2d: CustomConv2d,
    torch.nn.GroupNorm: CustomGroupNorm,
    torch.nn.Embedding: CustomEmbedding,
    RMSNorm: CustomFluxRMSNorm,
}

try:
    # These dependencies are not expected to be present on MacOS.
    from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_8_bit_lt import (
        CustomInvokeLinear8bitLt,
    )
    from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_nf4 import (
        CustomInvokeLinearNF4,
    )
    from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
    from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4

    AUTOCAST_MODULE_TYPE_MAPPING[InvokeLinear8bitLt] = CustomInvokeLinear8bitLt
    AUTOCAST_MODULE_TYPE_MAPPING[InvokeLinearNF4] = CustomInvokeLinearNF4
except ImportError:
    pass


AUTOCAST_MODULE_TYPE_MAPPING_INVERSE = {v: k for k, v in AUTOCAST_MODULE_TYPE_MAPPING.items()}


T = TypeVar("T", bound=torch.nn.Module)


def wrap_custom_layer(module_to_wrap: torch.nn.Module, custom_layer_type: type[T]) -> T:
    # HACK(ryand): We use custom initialization logic so that we can initialize a new custom layer instance from an
    # existing layer instance without calling __init__() on the original layer class. We achieve this by copying
    # the attributes from the original layer instance to the new instance.
    custom_layer = custom_layer_type.__new__(custom_layer_type)
    # Note that we share the __dict__.
    # TODO(ryand): In the future, we may want to do a shallow copy of the __dict__.
    custom_layer.__dict__ = module_to_wrap.__dict__

    # Initialize the CustomModuleMixin fields.
    CustomModuleMixin.__init__(custom_layer)  # type: ignore
    return custom_layer


def unwrap_custom_layer(custom_layer: torch.nn.Module, original_layer_type: type[torch.nn.Module]):
    # HACK(ryand): We use custom initialization logic so that we can initialize a new custom layer instance from an
    # existing layer instance without calling __init__() on the original layer class. We achieve this by copying
    # the attributes from the original layer instance to the new instance.
    original_layer = original_layer_type.__new__(original_layer_type)
    # Note that we share the __dict__.
    # TODO(ryand): In the future, we may want to do a shallow copy of the __dict__ and strip out the CustomModuleMixin
    # fields.
    original_layer.__dict__ = custom_layer.__dict__
    return original_layer


def apply_custom_layers_to_model(module: torch.nn.Module, device_autocasting_enabled: bool = False):
    for name, submodule in module.named_children():
        override_type = AUTOCAST_MODULE_TYPE_MAPPING.get(type(submodule), None)
        if override_type is not None:
            custom_layer = wrap_custom_layer(submodule, override_type)
            # TODO(ryand): In the future, we should manage this flag on a per-module basis.
            custom_layer.set_device_autocasting_enabled(device_autocasting_enabled)
            setattr(module, name, custom_layer)
        else:
            # Recursively apply to submodules.
            apply_custom_layers_to_model(submodule, device_autocasting_enabled)


def remove_custom_layers_from_model(module: torch.nn.Module):
    for name, submodule in module.named_children():
        override_type = AUTOCAST_MODULE_TYPE_MAPPING_INVERSE.get(type(submodule), None)
        if override_type is not None:
            setattr(module, name, unwrap_custom_layer(submodule, override_type))
        else:
            remove_custom_layers_from_model(submodule)
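A round-trip sketch of the apply/remove helpers above:

    import torch

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    apply_custom_layers_to_model(model, device_autocasting_enabled=True)
    assert isinstance(model[0], torch.nn.Linear)  # original isinstance checks still hold
    assert model[0].is_device_autocasting_enabled()

    remove_custom_layers_from_model(model)
    assert type(model[0]) is torch.nn.Linear  # back to the plain layer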
@@ -9,14 +9,6 @@ import torch
from safetensors.torch import load_file

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
    lora_model_from_flux_diffusers_state_dict,
)
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
    lora_model_from_flux_kohya_state_dict,
)
from invokeai.backend.lora.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
from invokeai.backend.lora.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
from invokeai.backend.model_manager import (
    AnyModel,
    AnyModelConfig,
@@ -26,12 +18,27 @@ from invokeai.backend.model_manager import (
    SubModelType,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import (
    is_state_dict_likely_flux_control,
    lora_model_from_flux_control_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
    lora_model_from_flux_diffusers_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
    is_state_dict_likely_in_flux_kohya_format,
    lora_model_from_flux_kohya_state_dict,
)
from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlLoRa, format=ModelFormat.LyCORIS)
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlLoRa, format=ModelFormat.Diffusers)
class LoRALoader(ModelLoader):
    """Class to load LoRA models."""

@@ -40,7 +47,7 @@ class LoRALoader(ModelLoader):
        self,
        app_config: InvokeAIAppConfig,
        logger: Logger,
        ram_cache: ModelCacheBase[AnyModel],
        ram_cache: ModelCache,
    ):
        """Initialize the loader."""
        super().__init__(app_config, logger, ram_cache)
@@ -75,7 +82,10 @@ class LoRALoader(ModelLoader):
            # https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py#L1194
            model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
        elif config.format == ModelFormat.LyCORIS:
            model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
            if is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
                model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
            elif is_state_dict_likely_flux_control(state_dict=state_dict):
                model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
            else:
                raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
        elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
@@ -25,6 +25,7 @@ from invokeai.backend.model_manager.config import (
    DiffusersConfigBase,
    MainCheckpointConfig,
)
from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.silence_warnings import SilenceWarnings
@@ -132,5 +133,5 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
            if subtype == submodel_type:
                continue
            if submodel := getattr(pipeline, subtype.value, None):
                self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
                self._ram_cache.put(get_model_cache_key(config.key, subtype), model=submodel)
        return getattr(pipeline, submodel_type.value)
@@ -15,9 +15,9 @@ from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import D
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.textual_inversion import TextualInversionModelRaw
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
@@ -43,7 +43,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
        (
            TextualInversionModelRaw,
            IPAdapter,
            LoRAModelRaw,
            ModelPatchRaw,
            SpandrelImageToImageModel,
            GroundingDinoPipeline,
            SegmentAnythingPipeline,
@@ -15,10 +15,6 @@ from invokeai.backend.flux.controlnet.state_dict_utils import (
    is_state_dict_xlabs_controlnet,
)
from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
    is_state_dict_likely_in_flux_diffusers_format,
)
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.config import (
    AnyModelConfig,
@@ -43,6 +39,13 @@ from invokeai.backend.model_manager.util.model_util import (
    lora_token_vector_length,
    read_checkpoint_meta,
)
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
    is_state_dict_likely_in_flux_diffusers_format,
)
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
    is_state_dict_likely_in_flux_kohya_format,
)
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
@@ -199,8 +202,8 @@ class ModelProbe(object):
        fields["default_settings"] = fields.get("default_settings")

        if not fields["default_settings"]:
            if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter}:
                fields["default_settings"] = get_default_settings_controlnet_t2i_adapter(fields["name"])
            if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter, ModelType.ControlLoRa}:
                fields["default_settings"] = get_default_settings_control_adapters(fields["name"])
            elif fields["type"] is ModelType.Main:
                fields["default_settings"] = get_default_settings_main(fields["base"])

@@ -258,6 +261,9 @@ class ModelProbe(object):
        ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
        ckpt = ckpt.get("state_dict", ckpt)

        if isinstance(ckpt, dict) and is_state_dict_likely_flux_control(ckpt):
            return ModelType.ControlLoRa

        for key in [str(k) for k in ckpt.keys()]:
            if key.startswith(
                (
@@ -497,7 +503,7 @@ MODEL_NAME_TO_PREPROCESSOR = {
}


def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[ControlAdapterDefaultSettings]:
def get_default_settings_control_adapters(model_name: str) -> Optional[ControlAdapterDefaultSettings]:
    for k, v in MODEL_NAME_TO_PREPROCESSOR.items():
        model_name_lower = model_name.lower()
        if k in model_name_lower:
@@ -624,8 +630,10 @@ class LoRACheckpointProbe(CheckpointProbeBase):
        return ModelFormat.LyCORIS

    def get_base_type(self) -> BaseModelType:
        if is_state_dict_likely_in_flux_kohya_format(self.checkpoint) or is_state_dict_likely_in_flux_diffusers_format(
            self.checkpoint
        if (
            is_state_dict_likely_in_flux_kohya_format(self.checkpoint)
            or is_state_dict_likely_in_flux_diffusers_format(self.checkpoint)
            or is_state_dict_likely_flux_control(self.checkpoint)
        ):
            return BaseModelType.Flux

@@ -1034,6 +1042,7 @@ class T2IAdapterFolderProbe(FolderProbeBase):
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlLoRa, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T5Encoder, T5EncoderFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
@@ -1046,6 +1055,7 @@ ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelI
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlLoRa, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
@@ -488,6 +488,22 @@ union_cnet_flux = StarterModel(
    type=ModelType.ControlNet,
)
# endregion
# region Control LoRA
flux_canny_control_lora = StarterModel(
    name="Hard Edge Detection (Canny)",
    base=BaseModelType.Flux,
    source="black-forest-labs/FLUX.1-Canny-dev-lora::flux1-canny-dev-lora.safetensors",
    description="Uses detected edges in the image to control composition.",
    type=ModelType.ControlLoRa,
)
flux_depth_control_lora = StarterModel(
    name="Depth Map",
    base=BaseModelType.Flux,
    source="black-forest-labs/FLUX.1-Depth-dev-lora::flux1-depth-dev-lora.safetensors",
    description="Uses depth information in the image to control the depth in the generation.",
    type=ModelType.ControlLoRa,
)
# endregion
# region T2I Adapter
t2i_canny_sd1 = StarterModel(
    name="Hard Edge Detection (canny)",
@@ -630,6 +646,8 @@ STARTER_MODELS: list[StarterModel] = [
    tile_sdxl,
    union_cnet_sdxl,
    union_cnet_flux,
    flux_canny_control_lora,
    flux_depth_control_lora,
    t2i_canny_sd1,
    t2i_sketch_sd1,
    t2i_depth_sd1,
@@ -688,6 +706,8 @@ flux_bundle: list[StarterModel] = [
    clip_l_encoder,
    union_cnet_flux,
    ip_adapter_flux,
    flux_canny_control_lora,
    flux_depth_control_lora,
]

STARTER_BUNDLES: dict[str, list[StarterModel]] = {
@@ -5,17 +5,14 @@ from __future__ import annotations

import pickle
from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
from typing import Any, Iterator, List, Optional, Tuple, Type, Union

import numpy as np
import torch
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw


@@ -176,180 +173,3 @@ class ModelPatcher:
        assert hasattr(unet, "disable_freeu")  # mypy doesn't pick up this attribute?
        if did_apply_freeu:
            unet.disable_freeu()


class ONNXModelPatcher:
    # based on
    # https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
    @classmethod
    @contextmanager
    def apply_lora(
        cls,
        model: IAIOnnxRuntimeModel,
        loras: List[Tuple[LoRAModelRaw, float]],
        prefix: str,
    ) -> None:
        from invokeai.backend.models.base import IAIOnnxRuntimeModel

        if not isinstance(model, IAIOnnxRuntimeModel):
            raise Exception("Only IAIOnnxRuntimeModel models supported")

        orig_weights = {}

        try:
            blended_loras: Dict[str, torch.Tensor] = {}

            for lora, lora_weight in loras:
                for layer_key, layer in lora.layers.items():
                    if not layer_key.startswith(prefix):
                        continue

                    layer.to(dtype=torch.float32)
                    layer_key = layer_key.replace(prefix, "")
                    # TODO: rewrite to pass original tensor weight(required by ia3)
                    layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight
                    if layer_key in blended_loras:
                        blended_loras[layer_key] += layer_weight
                    else:
                        blended_loras[layer_key] = layer_weight

            node_names = {}
            for node in model.nodes.values():
                node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name

            for layer_key, lora_weight in blended_loras.items():
                conv_key = layer_key + "_Conv"
                gemm_key = layer_key + "_Gemm"
                matmul_key = layer_key + "_MatMul"

                if conv_key in node_names or gemm_key in node_names:
                    if conv_key in node_names:
                        conv_node = model.nodes[node_names[conv_key]]
                    else:
                        conv_node = model.nodes[node_names[gemm_key]]

                    weight_name = [n for n in conv_node.input if ".weight" in n][0]
                    orig_weight = model.tensors[weight_name]

                    if orig_weight.shape[-2:] == (1, 1):
                        if lora_weight.shape[-2:] == (1, 1):
                            new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2))
                        else:
                            new_weight = orig_weight.squeeze((3, 2)) + lora_weight

                        new_weight = np.expand_dims(new_weight, (2, 3))
                    else:
                        if orig_weight.shape != lora_weight.shape:
                            new_weight = orig_weight + lora_weight.reshape(orig_weight.shape)
                        else:
                            new_weight = orig_weight + lora_weight

                    orig_weights[weight_name] = orig_weight
                    model.tensors[weight_name] = new_weight.astype(orig_weight.dtype)

                elif matmul_key in node_names:
                    weight_node = model.nodes[node_names[matmul_key]]
                    matmul_name = [n for n in weight_node.input if "MatMul" in n][0]

                    orig_weight = model.tensors[matmul_name]
                    new_weight = orig_weight + lora_weight.transpose()

                    orig_weights[matmul_name] = orig_weight
                    model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype)

                else:
                    # warn? err?
                    pass

            yield

        finally:
            # restore original weights
            for name, orig_weight in orig_weights.items():
                model.tensors[name] = orig_weight

    @classmethod
    @contextmanager
    def apply_ti(
        cls,
        tokenizer: CLIPTokenizer,
        text_encoder: IAIOnnxRuntimeModel,
        ti_list: List[Tuple[str, Any]],
    ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
        from invokeai.backend.models.base import IAIOnnxRuntimeModel

        if not isinstance(text_encoder, IAIOnnxRuntimeModel):
            raise Exception("Only IAIOnnxRuntimeModel models supported")

        orig_embeddings = None

        try:
            # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
            # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
            # exiting this `apply_ti(...)` context manager.
            #
            # In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
            # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
            ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
            ti_manager = TextualInversionManager(ti_tokenizer)

            def _get_trigger(ti_name: str, index: int) -> str:
                trigger = ti_name
                if index > 0:
                    trigger += f"-!pad-{i}"
                return f"<{trigger}>"

            # modify text_encoder
            orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]

            # modify tokenizer
            new_tokens_added = 0
            for ti_name, ti in ti_list:
                if ti.embedding_2 is not None:
                    ti_embedding = (
                        ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[0] else ti.embedding
                    )
                else:
                    ti_embedding = ti.embedding

                for i in range(ti_embedding.shape[0]):
                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))

            embeddings = np.concatenate(
                (np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))),
                axis=0,
            )

            for ti_name, _ in ti_list:
                ti_tokens = []
                for i in range(ti_embedding.shape[0]):
                    embedding = ti_embedding[i].detach().numpy()
                    trigger = _get_trigger(ti_name, i)

                    token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
                    if token_id == ti_tokenizer.unk_token_id:
                        raise RuntimeError(f"Unable to find token id for token '{trigger}'")

                    if embeddings[token_id].shape != embedding.shape:
                        raise ValueError(
                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
                            f" {embedding.shape[0]}, but the current model has token dimension"
                            f" {embeddings[token_id].shape[0]}."
                        )

                    embeddings[token_id] = embedding
                    ti_tokens.append(token_id)

                if len(ti_tokens) > 1:
                    ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]

            text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype(
                orig_embeddings.dtype
            )

            yield ti_tokenizer, ti_manager

        finally:
            # restore
            if orig_embeddings is not None:
                text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings
invokeai/backend/patches/layer_patcher.py (new file, 277 lines)
@@ -0,0 +1,277 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
||||
|
||||
|
||||
class LayerPatcher:
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
@contextmanager
|
||||
def apply_smart_model_patches(
|
||||
model: torch.nn.Module,
|
||||
patches: Iterable[Tuple[ModelPatchRaw, float]],
|
||||
prefix: str,
|
||||
dtype: torch.dtype,
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
force_direct_patching: bool = False,
|
||||
force_sidecar_patching: bool = False,
|
||||
):
|
||||
"""Apply 'smart' model patching that chooses whether to use direct patching or a sidecar wrapper for each
|
||||
module.
|
||||
"""
|
||||
|
||||
# original_weights are stored for unpatching layers that are directly patched.
|
||||
original_weights = OriginalWeightsStorage(cached_weights)
|
||||
# original_modules are stored for unpatching layers that are wrapped.
|
||||
original_modules: dict[str, torch.nn.Module] = {}
|
||||
try:
|
||||
for patch, patch_weight in patches:
|
||||
LayerPatcher.apply_smart_model_patch(
|
||||
model=model,
|
||||
prefix=prefix,
|
||||
patch=patch,
|
||||
patch_weight=patch_weight,
|
||||
original_weights=original_weights,
|
||||
original_modules=original_modules,
|
||||
dtype=dtype,
|
||||
force_direct_patching=force_direct_patching,
|
||||
force_sidecar_patching=force_sidecar_patching,
|
||||
)
|
||||
|
||||
yield
|
||||
finally:
|
||||
# Restore directly patched layers.
|
||||
for param_key, weight in original_weights.get_changed_weights():
|
||||
cur_param = model.get_parameter(param_key)
|
||||
cur_param.data = weight.to(dtype=cur_param.dtype, device=cur_param.device, copy=True)
|
||||
|
||||
# Clear patches from all patched modules.
|
||||
# Note: This logic assumes no nested modules in original_modules.
|
||||
for orig_module in original_modules.values():
|
||||
orig_module.clear_patches()
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def apply_smart_model_patch(
|
||||
model: torch.nn.Module,
|
||||
prefix: str,
|
||||
patch: ModelPatchRaw,
|
||||
patch_weight: float,
|
||||
original_weights: OriginalWeightsStorage,
|
||||
original_modules: dict[str, torch.nn.Module],
|
||||
dtype: torch.dtype,
|
||||
force_direct_patching: bool,
|
||||
force_sidecar_patching: bool,
|
||||
):
|
||||
"""Apply a single LoRA patch to a model using the 'smart' patching strategy that chooses whether to use direct
|
||||
patching or a sidecar wrapper for each module.
|
||||
"""
|
||||
if patch_weight == 0:
|
||||
return
|
||||
|
||||
# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
|
||||
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
|
||||
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
|
||||
# without searching, but some legacy code still uses flattened keys.
|
||||
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
|
||||
|
||||
prefix_len = len(prefix)
|
||||
|
||||
for layer_key, layer in patch.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
module_key, module = LayerPatcher._get_submodule(
|
||||
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
|
||||
)
|
||||
|
||||
# Decide whether to use direct patching or a sidecar patch.
|
||||
# Direct patching is preferred, because it results in better runtime speed.
|
||||
# Reasons to use sidecar patching:
|
||||
# - The module is quantized, so the caller passed force_sidecar_patching=True.
|
||||
# - The module already has sidecar patches.
|
||||
# - The module is on the CPU (and we don't want to store a second full copy of the original weights on the
|
||||
# CPU, since this would double the RAM usage)
|
||||
# NOTE: For now, we don't check if the layer is quantized here. We assume that this is checked in the caller
|
||||
# and that the caller will set force_sidecar_patching=True if the layer is quantized.
|
||||
# TODO(ryand): Handle the case where we are running without a GPU. Should we set a config flag that allows
|
||||
# forcing full patching even on the CPU?
|
||||
use_sidecar_patching = False
|
||||
if force_direct_patching and force_sidecar_patching:
|
||||
raise ValueError("Cannot force both direct and sidecar patching.")
|
||||
elif force_direct_patching:
|
||||
use_sidecar_patching = False
|
||||
elif force_sidecar_patching:
|
||||
use_sidecar_patching = True
|
||||
elif module.get_num_patches() > 0:
|
||||
use_sidecar_patching = True
|
||||
elif LayerPatcher._is_any_part_of_layer_on_cpu(module):
|
||||
use_sidecar_patching = True
|
||||
|
||||
if use_sidecar_patching:
|
||||
LayerPatcher._apply_model_layer_wrapper_patch(
|
||||
module_to_patch=module,
|
||||
module_to_patch_key=module_key,
|
||||
patch=layer,
|
||||
patch_weight=patch_weight,
|
||||
original_modules=original_modules,
|
||||
dtype=dtype,
|
||||
)
|
||||
else:
|
||||
LayerPatcher._apply_model_layer_patch(
|
||||
module_to_patch=module,
|
||||
module_to_patch_key=module_key,
|
||||
patch=layer,
|
||||
patch_weight=patch_weight,
|
||||
original_weights=original_weights,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_any_part_of_layer_on_cpu(layer: torch.nn.Module) -> bool:
|
||||
return any(p.device.type == "cpu" for p in layer.parameters())
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def _apply_model_layer_patch(
|
||||
module_to_patch: torch.nn.Module,
|
||||
module_to_patch_key: str,
|
||||
patch: BaseLayerPatch,
|
||||
patch_weight: float,
|
||||
original_weights: OriginalWeightsStorage,
|
||||
):
|
||||
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
||||
# (Performance will be best if this is a CUDA device.)
|
||||
first_param = next(module_to_patch.parameters())
|
||||
device = first_param.device
|
||||
dtype = first_param.dtype
|
||||
|
||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||
# same thing in a single call to '.to(...)'.
|
||||
patch.to(device=device)
|
||||
patch.to(dtype=torch.float32)
|
||||
|
||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||
for param_name, param_weight in patch.get_parameters(
|
||||
dict(module_to_patch.named_parameters(recurse=False)), weight=patch_weight
|
||||
).items():
|
||||
param_key = module_to_patch_key + "." + param_name
|
||||
module_param = module_to_patch.get_parameter(param_name)
|
||||
|
||||
# Save original weight
|
||||
original_weights.save(param_key, module_param)
|
||||
|
||||
# HACK(ryand): This condition is only necessary to handle layers in FLUX control LoRAs that change the
|
||||
# shape of the original layer.
|
||||
if module_param.nelement() != param_weight.nelement():
|
||||
assert isinstance(patch, FluxControlLoRALayer)
|
||||
expanded_weight = pad_with_zeros(module_param, param_weight.shape)
|
||||
setattr(
|
||||
module_to_patch,
|
||||
param_name,
|
||||
torch.nn.Parameter(expanded_weight, requires_grad=module_param.requires_grad),
|
||||
)
|
||||
module_param = expanded_weight
|
||||
|
||||
module_param += param_weight.to(dtype=dtype)
|
||||
|
||||
patch.to(device=TorchDevice.CPU_DEVICE)
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def _apply_model_layer_wrapper_patch(
|
||||
module_to_patch: torch.nn.Module,
|
||||
module_to_patch_key: str,
|
||||
patch: BaseLayerPatch,
|
||||
patch_weight: float,
|
||||
original_modules: dict[str, torch.nn.Module],
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""Apply a single LoRA wrapper patch to a module."""
|
||||
# Move the LoRA layer to the same device/dtype as the orig module.
|
||||
first_param = next(module_to_patch.parameters())
|
||||
device = first_param.device
|
||||
patch.to(device=device, dtype=dtype)
|
||||
|
||||
if module_to_patch_key not in original_modules:
|
||||
original_modules[module_to_patch_key] = module_to_patch
|
||||
|
||||
module_to_patch.add_patch(patch, patch_weight)
|
||||
|
||||
@staticmethod
|
||||
def _split_parent_key(module_key: str) -> tuple[str, str]:
|
||||
"""Split a module key into its parent key and module name.
|
||||
|
||||
Args:
|
||||
module_key (str): The module key to split.
|
||||
|
||||
Returns:
|
||||
tuple[str, str]: A tuple containing the parent key and module name.
|
||||
"""
|
||||
split_key = module_key.rsplit(".", 1)
|
||||
if len(split_key) == 2:
|
||||
return tuple(split_key)
|
||||
elif len(split_key) == 1:
|
||||
return "", split_key[0]
|
||||
else:
|
||||
raise ValueError(f"Invalid module key: {module_key}")
|
||||
|
||||
@staticmethod
|
||||
def _set_submodule(parent_module: torch.nn.Module, module_name: str, submodule: torch.nn.Module):
|
||||
try:
|
||||
submodule_index = int(module_name)
|
||||
# If the module name is an integer, then we use the __setitem__ method to set the submodule.
|
||||
parent_module[submodule_index] = submodule # type: ignore
|
||||
except ValueError:
|
||||
# If the module name is not an integer, then we use the setattr method to set the submodule.
|
||||
setattr(parent_module, module_name, submodule)
|
||||
|
||||
    @staticmethod
    def _get_submodule(
        model: torch.nn.Module, layer_key: str, layer_key_is_flattened: bool
    ) -> tuple[str, torch.nn.Module]:
        """Get the submodule corresponding to the given layer key.

        Args:
            model (torch.nn.Module): The model to search.
            layer_key (str): The layer key to search for.
            layer_key_is_flattened (bool): Whether the layer key is flattened. If flattened, then all '.' have been
                replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed
                directly without searching, but some legacy code still uses flattened keys.

        Returns:
            tuple[str, torch.nn.Module]: A tuple containing the module key and the submodule.
        """
        if not layer_key_is_flattened:
            return layer_key, model.get_submodule(layer_key)

        # Handle flattened keys.
        assert "." not in layer_key

        module = model
        module_key = ""
        key_parts = layer_key.split("_")

        submodule_name = key_parts.pop(0)

        while len(key_parts) > 0:
            try:
                module = module.get_submodule(submodule_name)
                module_key += "." + submodule_name
                submodule_name = key_parts.pop(0)
            except Exception:
                submodule_name += "_" + key_parts.pop(0)

        module = module.get_submodule(submodule_name)
        module_key = (module_key + "." + submodule_name).lstrip(".")

        return module_key, module
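(Worked example of the flattened-key search, using an illustrative key: for layer_key = "double_blocks_0_img_attn_qkv", the loop tries "double" (lookup fails), extends it to "double_blocks" (succeeds, descend), then tries "0" (succeeds, descend), then "img" (fails), extends to "img_attn" (succeeds, descend), and finally resolves "qkv" outside the loop, returning ("double_blocks.0.img_attn.qkv", <the qkv module>).)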
invokeai/backend/patches/layers/base_layer_patch.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from abc import ABC, abstractmethod

import torch


class BaseLayerPatch(ABC):
    @abstractmethod
    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
        """Get the parameter residual updates that should be applied to the original parameters. Parameters omitted
        from the returned dict are not updated.
        """
        ...

    @abstractmethod
    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        """Move all internal tensors to the specified device and dtype."""
        ...

    @abstractmethod
    def calc_size(self) -> int:
        """Calculate the total size of all internal tensors in bytes."""
        ...
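(Editorial sketch of the contract: a minimal concrete BaseLayerPatch that returns a residual scaling of the existing weight. Hypothetical, for illustration only.)

class ScaleWeightPatchSketch(BaseLayerPatch):
    def __init__(self, factor: float):
        self.factor = factor

    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
        # Residual update: the caller adds this to the original weight, so the
        # effective result is orig * (1 + weight * factor).
        return {"weight": orig_parameters["weight"] * (weight * self.factor)}

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        pass  # No internal tensors to move.

    def calc_size(self) -> int:
        return 0  # No internal tensors.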
@@ -2,8 +2,8 @@ from typing import Optional, Sequence
 
 import torch
 
-from invokeai.backend.lora.layers.lora_layer import LoRALayer
-from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.patches.layers.lora_layer import LoRALayer
+from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
 
 
 class ConcatenatedLoRALayer(LoRALayerBase):
@@ -20,7 +20,7 @@ class ConcatenatedLoRALayer(LoRALayerBase):
         self.lora_layers = lora_layers
         self.concat_axis = concat_axis
 
-    def rank(self) -> int | None:
+    def _rank(self) -> int | None:
         return None
 
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
@@ -30,7 +30,7 @@ class ConcatenatedLoRALayer(LoRALayerBase):
         layer_weights = [lora_layer.get_weight(None) * lora_layer.scale() for lora_layer in self.lora_layers]  # pyright: ignore[reportArgumentType]
         return torch.cat(layer_weights, dim=self.concat_axis)
 
-    def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
+    def get_bias(self, orig_bias: torch.Tensor | None) -> Optional[torch.Tensor]:
         # TODO(ryand): Currently, we pass orig_bias=None to the sub-layers. If we want to support sub-layers that
         # require this value, we will need to implement chunking of the original bias tensor here.
         # Note that we must apply the sub-layer scales here.
invokeai/backend/patches/layers/flux_control_lora_layer.py (new file, 19 lines)
@@ -0,0 +1,19 @@
import torch

from invokeai.backend.patches.layers.lora_layer import LoRALayer


class FluxControlLoRALayer(LoRALayer):
    """A special case of LoRALayer for use with FLUX Control LoRAs that pads the target parameter with zeros if the
    shapes don't match.
    """

    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
        """This overrides the base class behavior to skip the reshaping step."""
        scale = self.scale()
        params = {"weight": self.get_weight(orig_parameters["weight"]) * (weight * scale)}
        bias = self.get_bias(orig_parameters.get("bias", None))
        if bias is not None:
            params["bias"] = bias * (weight * scale)

        return params
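(Illustrative shape walk-through, consistent with the detector later in this diff: FLUX Control LoRAs widen img_in's input from 64 to 128 channels, so lora_A.weight is [rank, 128] and lora_B.weight is [3072, rank]. The delta lora_B.weight @ lora_A.weight is therefore [3072, 128], and the original [3072, 64] img_in weight is zero-padded to [3072, 128] before the residual is added, which is why the reshape step must be skipped here.)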
@@ -2,7 +2,7 @@ from typing import Dict, Optional
 
 import torch
 
-from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
 from invokeai.backend.util.calc_tensor_size import calc_tensor_size
 
 
@@ -20,7 +20,7 @@ class FullLayer(LoRALayerBase):
         cls.warn_on_unhandled_keys(values=values, handled_keys={"diff", "diff_b"})
         return layer
 
-    def rank(self) -> int | None:
+    def _rank(self) -> int | None:
         return None
 
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
@@ -2,7 +2,7 @@ from typing import Dict, Optional
 
 import torch
 
-from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
 
 
 class IA3Layer(LoRALayerBase):
@@ -16,7 +16,7 @@ class IA3Layer(LoRALayerBase):
         self.weight = weight
         self.on_input = on_input
 
-    def rank(self) -> int | None:
+    def _rank(self) -> int | None:
         return None
 
     @classmethod
@@ -2,7 +2,7 @@ from typing import Dict
 
 import torch
 
-from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
 from invokeai.backend.util.calc_tensor_size import calc_tensors_size
 
 
@@ -32,7 +32,7 @@ class LoHALayer(LoRALayerBase):
         self.t2 = t2
         assert (self.t1 is None) == (self.t2 is None)
 
-    def rank(self) -> int | None:
+    def _rank(self) -> int | None:
         return self.w1_b.shape[0]
 
     @classmethod
@@ -2,7 +2,7 @@ from typing import Dict
 
 import torch
 
-from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
 from invokeai.backend.util.calc_tensor_size import calc_tensors_size
 
 
@@ -39,7 +39,7 @@ class LoKRLayer(LoRALayerBase):
         assert (self.w2 is None) != (self.w2_a is None)
         assert (self.w2_a is None) == (self.w2_b is None)
 
-    def rank(self) -> int | None:
+    def _rank(self) -> int | None:
         if self.w1_b is not None:
             return self.w1_b.shape[0]
         elif self.w2_b is not None:
@@ -2,7 +2,7 @@ from typing import Dict, Optional
 
 import torch
 
-from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
 from invokeai.backend.util.calc_tensor_size import calc_tensors_size
 
 
@@ -55,7 +55,7 @@ class LoRALayer(LoRALayerBase):
 
         return layer
 
-    def rank(self) -> int:
+    def _rank(self) -> int:
         return self.down.shape[0]
 
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
@@ -1,12 +1,13 @@
-from typing import Dict, Optional, Set
+from typing import Optional
 
 import torch
 
 import invokeai.backend.util.logging as logger
+from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
 from invokeai.backend.util.calc_tensor_size import calc_tensors_size
 
 
-class LoRALayerBase:
+class LoRALayerBase(BaseLayerPatch):
     """Base class for all LoRA-like patching layers."""
 
     # Note: It is tempting to make this a torch.nn.Module sub-class and make all tensors 'torch.nn.Parameter's. Then we
@@ -23,6 +24,7 @@ class LoRALayerBase:
     def _parse_bias(
         cls, bias_indices: torch.Tensor | None, bias_values: torch.Tensor | None, bias_size: torch.Tensor | None
     ) -> torch.Tensor | None:
+        """Helper function to parse a bias tensor from a state dict in LyCORIS format."""
        assert (bias_indices is None) == (bias_values is None) == (bias_size is None)
 
         bias = None
@@ -37,11 +39,14 @@
     ) -> float | None:
         return alpha.item() if alpha is not None else None
 
-    def rank(self) -> int | None:
+    def _rank(self) -> int | None:
+        """Return the rank of the LoRA-like layer. Or None if the layer does not have a rank. This value is used to
+        calculate the scale.
+        """
         raise NotImplementedError()
 
     def scale(self) -> float:
-        rank = self.rank()
+        rank = self._rank()
         if self._alpha is None or rank is None:
             return 1.0
         return self._alpha / rank
@@ -49,18 +54,26 @@
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError()
 
-    def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
+    def get_bias(self, orig_bias: torch.Tensor | None) -> Optional[torch.Tensor]:
         return self.bias
 
-    def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
-        params = {"weight": self.get_weight(orig_module.weight)}
-        bias = self.get_bias(orig_module.bias)
+    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
+        scale = self.scale()
+        params = {"weight": self.get_weight(orig_parameters["weight"]) * (weight * scale)}
+        bias = self.get_bias(orig_parameters.get("bias", None))
         if bias is not None:
-            params["bias"] = bias
+            params["bias"] = bias * (weight * scale)
+
+        # Reshape all params to match the original module's shape.
+        for param_name, param_weight in params.items():
+            orig_param = orig_parameters[param_name]
+            if param_weight.shape != orig_param.shape:
+                params[param_name] = param_weight.reshape(orig_param.shape)
 
         return params
 
     @classmethod
-    def warn_on_unhandled_keys(cls, values: Dict[str, torch.Tensor], handled_keys: Set[str]):
+    def warn_on_unhandled_keys(cls, values: dict[str, torch.Tensor], handled_keys: set[str]):
         """Log a warning if values contains unhandled keys."""
         unknown_keys = set(values.keys()) - handled_keys
         if unknown_keys:
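(Editorial sketch of how a caller consumes these residuals, mirroring _apply_model_layer_patch earlier in this diff; `module` and `lora_layer` are stand-ins. Note the scale: for a rank-16 layer with alpha 8, scale() is 8 / 16 = 0.5.)

params = lora_layer.get_parameters(dict(module.named_parameters(recurse=False)), weight=0.75)
for name, delta in params.items():
    # In-place residual update: new_param = orig_param + delta.
    module.get_parameter(name).data += delta.to(dtype=module.weight.dtype)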
@@ -2,7 +2,7 @@ from typing import Dict
 
 import torch
 
-from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
 from invokeai.backend.util.calc_tensor_size import calc_tensor_size
 
 
@@ -20,7 +20,7 @@ class NormLayer(LoRALayerBase):
         cls.warn_on_unhandled_keys(values, {"w_norm", "b_norm"})
         return layer
 
-    def rank(self) -> int | None:
+    def _rank(self) -> int | None:
         return None
 
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
invokeai/backend/patches/layers/set_parameter_layer.py (new file, 27 lines)
@@ -0,0 +1,27 @@
import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.util.calc_tensor_size import calc_tensor_size


class SetParameterLayer(BaseLayerPatch):
    """A layer that sets a single parameter to a new target value.
    (The diff between the target value and current value is calculated internally.)
    """

    def __init__(self, param_name: str, weight: torch.Tensor):
        super().__init__()
        self.weight = weight
        self.param_name = param_name

    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
        # Note: We intentionally ignore the weight parameter here. This matches the behavior in the official FLUX
        # Control LoRA implementation.
        diff = self.weight - orig_parameters[self.param_name]
        return {self.param_name: diff}

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        self.weight = self.weight.to(device=device, dtype=dtype)

    def calc_size(self) -> int:
        return calc_tensor_size(self.weight)
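(Usage sketch: the returned residual moves the original parameter exactly onto the stored target, independent of the patch weight, e.g. for the key_norm.scale keys handled by the FLUX Control loader below. Values here are illustrative.)

target = torch.ones(128)  # illustrative target value for a 'scale' parameter
layer = SetParameterLayer("scale", target)
residual = layer.get_parameters({"scale": torch.zeros(128)}, weight=1.0)["scale"]
# torch.zeros(128) + residual == target, for any value of `weight`.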
@@ -2,16 +2,16 @@ from typing import Dict
 
 import torch
 
-from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
-from invokeai.backend.lora.layers.full_layer import FullLayer
-from invokeai.backend.lora.layers.ia3_layer import IA3Layer
-from invokeai.backend.lora.layers.loha_layer import LoHALayer
-from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
-from invokeai.backend.lora.layers.lora_layer import LoRALayer
-from invokeai.backend.lora.layers.norm_layer import NormLayer
+from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.full_layer import FullLayer
+from invokeai.backend.patches.layers.ia3_layer import IA3Layer
+from invokeai.backend.patches.layers.loha_layer import LoHALayer
+from invokeai.backend.patches.layers.lokr_layer import LoKRLayer
+from invokeai.backend.patches.layers.lora_layer import LoRALayer
+from invokeai.backend.patches.layers.norm_layer import NormLayer
 
 
-def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
+def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> BaseLayerPatch:
     # Detect layers according to LyCORIS detection logic(`weight_list_det`)
     # https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
@@ -0,0 +1,84 @@
import re
from typing import Any, Dict

import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw

# A regex pattern that matches all of the keys in the Flux Dev/Canny LoRA format.
# Example keys:
#   guidance_in.in_layer.lora_B.bias
#   single_blocks.0.linear1.lora_A.weight
#   double_blocks.0.img_attn.norm.key_norm.scale
FLUX_CONTROL_TRANSFORMER_KEY_REGEX = r"(\w+\.)+(lora_A\.weight|lora_B\.weight|lora_B\.bias|scale)"


def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
    """Checks if the provided state dict is likely in the FLUX Control LoRA format.

    This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
    perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
    """

    all_keys_match = all(re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, str(k)) for k in state_dict.keys())

    # Check the shape of the img_in weight, because this layer shape is modified by FLUX control LoRAs.
    lora_a_weight = state_dict.get("img_in.lora_A.weight", None)
    lora_b_bias = state_dict.get("img_in.lora_B.bias", None)
    lora_b_weight = state_dict.get("img_in.lora_B.weight", None)

    return (
        all_keys_match
        and lora_a_weight is not None
        and lora_b_bias is not None
        and lora_b_weight is not None
        and lora_a_weight.shape[1] == 128
        and lora_b_weight.shape[0] == 3072
        and lora_b_bias.shape[0] == 3072
    )


def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
    # Group keys by layer.
    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
    for key, value in state_dict.items():
        key_props = key.split(".")
        layer_prop_size = -2 if any(prop in key for prop in ["lora_B", "lora_A"]) else -1
        layer_name = ".".join(key_props[:layer_prop_size])
        param_name = ".".join(key_props[layer_prop_size:])
        if layer_name not in grouped_state_dict:
            grouped_state_dict[layer_name] = {}
        grouped_state_dict[layer_name][param_name] = value

    # Create LoRA layers.
    layers: dict[str, BaseLayerPatch] = {}
    for layer_key, layer_state_dict in grouped_state_dict.items():
        prefixed_key = f"{FLUX_LORA_TRANSFORMER_PREFIX}{layer_key}"
        if layer_key == "img_in":
            # img_in is a special case because it changes the shape of the original weight.
            layers[prefixed_key] = FluxControlLoRALayer(
                layer_state_dict["lora_B.weight"],
                None,
                layer_state_dict["lora_A.weight"],
                None,
                layer_state_dict["lora_B.bias"],
            )
        elif all(k in layer_state_dict for k in ["lora_A.weight", "lora_B.bias", "lora_B.weight"]):
            layers[prefixed_key] = LoRALayer(
                layer_state_dict["lora_B.weight"],
                None,
                layer_state_dict["lora_A.weight"],
                None,
                layer_state_dict["lora_B.bias"],
            )
        elif "scale" in layer_state_dict:
            layers[prefixed_key] = SetParameterLayer("scale", layer_state_dict["scale"])
        else:
            raise ValueError(f"{layer_key} not expected")

    return ModelPatchRaw(layers=layers)
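(Worked illustration of the grouping step, using the example keys from the regex comment above: "guidance_in.in_layer.lora_B.bias" splits into layer "guidance_in.in_layer" and param "lora_B.bias", because lora_A/lora_B keys keep the last two props as the param name (layer_prop_size = -2); "double_blocks.0.img_attn.norm.key_norm.scale" splits into layer "double_blocks.0.img_attn.norm.key_norm" and param "scale" (layer_prop_size = -1). The first kind of group becomes a LoRALayer, or a FluxControlLoRALayer for img_in, and the second becomes a SetParameterLayer.)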
@@ -2,11 +2,11 @@ from typing import Dict
 
 import torch
 
-from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
-from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
-from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
-from invokeai.backend.lora.layers.lora_layer import LoRALayer
-from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
+from invokeai.backend.patches.layers.lora_layer import LoRALayer
+from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 
 
 def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
@@ -30,7 +30,9 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
     return all_keys_in_peft_format and all_expected_keys_present
 
 
-def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float | None) -> LoRAModelRaw:
+def lora_model_from_flux_diffusers_state_dict(
+    state_dict: Dict[str, torch.Tensor], alpha: float | None
+) -> ModelPatchRaw:
     """Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.
 
     This function is based on:
@@ -49,7 +51,7 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
     mlp_ratio = 4.0
     mlp_hidden_dim = int(hidden_size * mlp_ratio)
 
-    layers: dict[str, AnyLoRALayer] = {}
+    layers: dict[str, BaseLayerPatch] = {}
 
     def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
         if src_key in grouped_state_dict:
@@ -215,7 +217,7 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
 
     layers_with_prefix = {f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}": v for k, v in layers.items()}
 
-    return LoRAModelRaw(layers=layers_with_prefix)
+    return ModelPatchRaw(layers=layers_with_prefix)
 
 
 def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
Some files were not shown because too many files have changed in this diff.