Compare commits

..

3 Commits

Author SHA1 Message Date
Brandon Rising
bbb48c5475 Gitignore aider 2024-05-28 15:42:34 -04:00
Brandon Rising
e68c49167a aider: Fixed CPU memory issues by deleting large objects and forcing garbage collection.
# Aider chat conversation:

USER: `invokeai/backend/model_manager/convert_ckpt_to_diffusers.py` I'm having issues with CPU memory, not GPU memory

ASSISTANT: To address CPU memory issues, we can make the following changes:

1. Use `del` to explicitly delete large objects when they are no longer needed.
2. Use `gc.collect()` to force garbage collection after deleting large objects.
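
In isolation, the pattern looks like this (a minimal, self-contained sketch using a stand-in model, not code from this repository):

```python
import gc

import torch

def load_weights(state_dict: dict) -> torch.nn.Module:
    model = torch.nn.Linear(1024, 1024)  # stand-in for a large model
    with torch.no_grad():
        model.load_state_dict(state_dict)
    # Drop the reference and force a collection so the CPU-side copy of the
    # weights is reclaimed promptly. `del` only helps if this was the last
    # remaining reference to the object.
    del state_dict
    gc.collect()
    return model

model = load_weights(torch.nn.Linear(1024, 1024).state_dict())
```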

Here are the *SEARCH/REPLACE* blocks:

invokeai/backend/model_manager/convert_ckpt_to_diffusers.py
```python
<<<<<<< SEARCH
    vae = AutoencoderKL(**vae_config)
    with torch.no_grad():
        vae.load_state_dict(converted_vae_checkpoint)
        vae.to(precision)
        torch.cuda.empty_cache()
=======
    vae = AutoencoderKL(**vae_config)
    with torch.no_grad():
        vae.load_state_dict(converted_vae_checkpoint)
        del converted_vae_checkpoint  # Free memory
        import gc
        gc.collect()
        vae.to(precision)
>>>>>>> REPLACE
```

invokeai/backend/model_manager/convert_ckpt_to_diffusers.py
```python
<<<<<<< SEARCH
    pipe = download_from_original_stable_diffusion_ckpt(Path(checkpoint_path).as_posix(), **kwargs)
    with torch.no_grad():
        pipe = pipe.to(precision)
        torch.cuda.empty_cache()
=======
    pipe = download_from_original_stable_diffusion_ckpt(Path(checkpoint_path).as_posix(), **kwargs)
    with torch.no_grad():
        del kwargs  # Free memory
        import gc
        gc.collect()
        pipe = pipe.to(precision)
>>>>>>> REPLACE
```

invokeai/backend/model_manager/convert_ckpt_to_diffusers.py
```python
<<<<<<< SEARCH
    pipe = download_controlnet_from_original_ckpt(checkpoint_path.as_posix(), **kwargs)
    with torch.no_grad():
        pipe = pipe.to(precision)
        torch.cuda.empty_cache()
=======
    pipe = download_controlnet_from_original_ckpt(checkpoint_path.as_posix(), **kwargs)
    with torch.no_grad():
        del kwargs  # Free memory
        import gc
        gc.collect()
        pipe = pipe.to(precision)
>>>>>>> REPLACE
```
2024-05-28 15:05:36 -04:00
Brandon Rising
46950a9bd0 aider: Improved memory usage efficiency by avoiding intermediate value storage and clearing the cache.
# Aider chat conversation:

USER: `invokeai/backend/model_manager/convert_ckpt_to_diffusers.py` Could you improve the memory usage efficiency of this code?

ASSISTANT: To improve the memory usage efficiency of the code, we can make the following changes:

1. Use `with torch.no_grad()` to avoid storing intermediate values for backpropagation.
2. Move the model to the desired device only when necessary.
3. Clear the cache after moving the model to the desired device.
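
As a standalone illustration (a self-contained sketch with a stand-in model, not code from this repository):

```python
import torch

def cast_model(model: torch.nn.Module, precision: torch.dtype) -> torch.nn.Module:
    # no_grad avoids recording any state for backpropagation during the cast
    with torch.no_grad():
        model = model.to(precision)
    # empty_cache returns unused cached blocks to the CUDA driver;
    # it is a no-op on CPU-only systems
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return model

model = cast_model(torch.nn.Linear(16, 16), torch.float16)
```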

Here are the *SEARCH/REPLACE* blocks:

invokeai/backend/model_manager/convert_ckpt_to_diffusers.py
```python
<<<<<<< SEARCH
    vae.load_state_dict(converted_vae_checkpoint)
    vae.to(precision)
=======
    with torch.no_grad():
        vae.load_state_dict(converted_vae_checkpoint)
        vae.to(precision)
        torch.cuda.empty_cache()
>>>>>>> REPLACE
```

invokeai/backend/model_manager/convert_ckpt_to_diffusers.py
```python
<<<<<<< SEARCH
    pipe = pipe.to(precision)
=======
    with torch.no_grad():
        pipe = pipe.to(precision)
        torch.cuda.empty_cache()
>>>>>>> REPLACE
```

invokeai/backend/model_manager/convert_ckpt_to_diffusers.py
```python
<<<<<<< SEARCH
    pipe = pipe.to(precision)
=======
    with torch.no_grad():
        pipe = pipe.to(precision)
        torch.cuda.empty_cache()
>>>>>>> REPLACE
```
2024-05-28 15:03:43 -04:00
343 changed files with 10309 additions and 17643 deletions

View File

@@ -9,9 +9,9 @@ runs:
         node-version: '18'
     - name: setup pnpm
-      uses: pnpm/action-setup@v4
+      uses: pnpm/action-setup@v2
       with:
-        version: 8.15.6
+        version: 8
         run_install: false
    - name: get pnpm store directory

View File

@@ -8,7 +8,7 @@
 ## QA Instructions

-<!--WHEN APPLICABLE: Describe how you have tested the changes in this PR. Provide enough detail that a reviewer can reproduce your tests.-->
+<!--WHEN APPLICABLE: Describe how we can test the changes in this PR.-->

 ## Merge Plan

.gitignore
View File

@@ -188,3 +188,4 @@ installer/install.sh
 installer/update.bat
 installer/update.sh
 installer/InvokeAI-Installer/
+.aider*

View File

@@ -18,7 +18,6 @@ help:
     @echo "frontend-typegen  Generate types for the frontend from the OpenAPI schema"
     @echo "installer-zip     Build the installer .zip file for the current version"
     @echo "tag-release       Tag the GitHub repository with the current version (use at release time only!)"
-    @echo "openapi           Generate the OpenAPI schema for the app, outputting to stdout"

 # Runs ruff, fixing any safely-fixable errors and formatting
 ruff:
@@ -71,6 +70,3 @@ installer-zip:
 tag-release:
     cd installer && ./tag_release.sh
-
-# Generate the OpenAPI Schema for the app
-openapi:
-    python scripts/generate_openapi_schema.py

View File

@@ -12,24 +12,12 @@
 Invoke is a leading creative engine built to empower professionals and enthusiasts alike. Generate and create stunning visual media using the latest AI-driven technologies. Invoke offers an industry leading web-based UI, and serves as the foundation for multiple commercial products.

-Invoke is available in two editions:
-
-| **Community Edition** | **Professional Edition** |
-|-----------------------|--------------------------|
-| **For users looking for a locally installed, self-hosted and self-managed service** | **For users or teams looking for a cloud-hosted, fully managed service** |
-| - Free to use under a commercially-friendly license | - Monthly subscription fee with three different plan levels |
-| - Download and install on compatible hardware | - Offers additional benefits, including multi-user support, improved model training, and more |
-| - Includes all core studio features: generate, refine, iterate on images, and build workflows | - Hosted in the cloud for easy, secure model access and scalability |
-| Quick Start -> [Installation and Updates][installation docs] | More Information -> [www.invoke.com/pricing](https://www.invoke.com/pricing) |
+[Installation and Updates][installation docs] - [Documentation and Tutorials][docs home] - [Bug Reports][github issues] - [Contributing][contributing docs]
+
+<div align="center">

 ![Highlighted Features - Canvas and Workflows](https://github.com/invoke-ai/InvokeAI/assets/31807370/708f7a82-084f-4860-bfbe-e2588c53548d)

-# Documentation
-| **Quick Links** |
-|-----------------|
-| [Installation and Updates][installation docs] - [Documentation and Tutorials][docs home] - [Bug Reports][github issues] - [Contributing][contributing docs] |
-
 </div>

 ## Quick Start

@@ -49,33 +37,6 @@ Invoke is available in two editions:
 More detail, including hardware requirements and manual install instructions, are available in the [installation documentation][installation docs].

-## Docker Container
-
-We publish official container images in Github Container Registry: https://github.com/invoke-ai/InvokeAI/pkgs/container/invokeai. Both CUDA and ROCm images are available. Check the above link for relevant tags.
-
-> [!IMPORTANT]
-> Ensure that Docker is set up to use the GPU. Refer to [NVIDIA][nvidia docker docs] or [AMD][amd docker docs] documentation.
-
-### Generate!
-
-Run the container, modifying the command as necessary:
-
-```bash
-docker run --runtime=nvidia --gpus=all --publish 9090:9090 ghcr.io/invoke-ai/invokeai
-```
-
-Then open `http://localhost:9090` and install some models using the Model Manager tab to begin generating.
-
-For ROCm, add `--device /dev/kfd --device /dev/dri` to the `docker run` command.
-
-### Persist your data
-
-You will likely want to persist your workspace outside of the container. Use the `--volume /home/myuser/invokeai:/invokeai` flag to mount some local directory (using its **absolute** path) to the `/invokeai` path inside the container. Your generated images and models will reside there. You can use this directory with other InvokeAI installations, or switch between runtime directories as needed.
-
-### DIY
-
-Build your own image and customize the environment to match your needs using our `docker-compose` stack. See [README.md](./docker/README.md) in the [docker](./docker) directory.
-
 ## Troubleshooting, FAQ and Support

 Please review our [FAQ][faq] for solutions to common installation problems and other issues.

@@ -153,5 +114,3 @@ Original portions of the software are Copyright © 2024 by respective contributors
 [latest release link]: https://github.com/invoke-ai/InvokeAI/releases/latest
 [translation status badge]: https://hosted.weblate.org/widgets/invokeai/-/svg-badge.svg
 [translation status link]: https://hosted.weblate.org/engage/invokeai/
-[nvidia docker docs]: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
-[amd docker docs]: https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/docker.html

View File

@@ -19,9 +19,8 @@
 ## INVOKEAI_PORT is the port on which the InvokeAI web interface will be available
 # INVOKEAI_PORT=9090

-## GPU_DRIVER can be set to either `cuda` or `rocm` to enable GPU support in the container accordingly.
-# GPU_DRIVER=cuda #| rocm
+## GPU_DRIVER can be set to either `nvidia` or `rocm` to enable GPU support in the container accordingly.
+# GPU_DRIVER=nvidia #| rocm

 ## CONTAINER_UID can be set to the UID of the user on the host system that should own the files in the container.
-## It is usually not necessary to change this. Use `id -u` on the host system to find the UID.
 # CONTAINER_UID=1000

View File

@@ -1,75 +1,41 @@
-# Invoke in Docker
+# InvokeAI Containerized

-- Ensure that Docker can use the GPU on your system
-- This documentation assumes Linux, but should work similarly under Windows with WSL2
-- We don't recommend running Invoke in Docker on macOS at this time. It works, but very slowly.
+All commands should be run within the `docker` directory: `cd docker`

-## Quickstart :lightning:
+## Quickstart :rocket:

-No `docker compose`, no persistence, just a simple one-liner using the official images:
-
-**CUDA:**
-
-```bash
-docker run --runtime=nvidia --gpus=all --publish 9090:9090 ghcr.io/invoke-ai/invokeai
-```
-
-**ROCm:**
-
-```bash
-docker run --device /dev/kfd --device /dev/dri --publish 9090:9090 ghcr.io/invoke-ai/invokeai:main-rocm
-```
-
-Open `http://localhost:9090` in your browser once the container finishes booting, install some models, and generate away!
-
-> [!TIP]
-> To persist your data (including downloaded models) outside of the container, add a `--volume/-v` flag to the above command, e.g.: `docker run --volume /some/local/path:/invokeai <...the rest of the command>`
-
-## Customize the container
-
-We ship the `run.sh` script, which is a convenient wrapper around `docker compose` for cases where custom image build args are needed. Alternatively, the familiar `docker compose` commands work just as well.
-
-```bash
-cd docker
-cp .env.sample .env
-# edit .env to your liking if you need to; it is well commented.
-./run.sh
-```
-
-It will take a few minutes to build the image the first time. Once the application starts up, open `http://localhost:9090` in your browser to invoke!
-
-## Docker setup in detail
+On a known working Linux+Docker+CUDA (Nvidia) system, execute `./run.sh` in this directory. It will take a few minutes - depending on your internet speed - to install the core models. Once the application starts up, open `http://localhost:9090` in your browser to Invoke!
+
+For more configuration options (using an AMD GPU, custom root directory location, etc): read on.
+
+## Detailed setup

 #### Linux

 1. Ensure builkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
 2. Install the `docker compose` plugin using your package manager, or follow a [tutorial](https://docs.docker.com/compose/install/linux/#install-using-the-repository).
-    - The deprecated `docker-compose` (hyphenated) CLI probably won't work. Update to a recent version.
+    - The deprecated `docker-compose` (hyphenated) CLI continues to work for now.
 3. Ensure docker daemon is able to access the GPU.
-    - [NVIDIA docs](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
-    - [AMD docs](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/docker.html)
+    - You may need to install [nvidia-container-toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)

 #### macOS

-> [!TIP]
-> You'll be better off installing Invoke directly on your system, because Docker can not use the GPU on macOS.
-
-If you are still reading:
-
 1. Ensure Docker has at least 16GB RAM
 2. Enable VirtioFS for file sharing
 3. Enable `docker compose` V2 support

-This is done via Docker Desktop preferences.
+This is done via Docker Desktop preferences

-### Configure the Invoke Environment
+### Configure Invoke environment

-1. Make a copy of `.env.sample` and name it `.env` (`cp .env.sample .env` (Mac/Linux) or `copy example.env .env` (Windows)). Make changes as necessary. Set `INVOKEAI_ROOT` to an absolute path to the desired location of the InvokeAI runtime directory. It may be an existing directory from a previous installation (post 4.0.0).
+1. Make a copy of `.env.sample` and name it `.env` (`cp .env.sample .env` (Mac/Linux) or `copy example.env .env` (Windows)). Make changes as necessary. Set `INVOKEAI_ROOT` to an absolute path to:
+    a. the desired location of the InvokeAI runtime directory, or
+    b. an existing, v3.0.0 compatible runtime directory.
 1. Execute `run.sh`

 The image will be built automatically if needed.

-The runtime directory (holding models and outputs) will be created in the location specified by `INVOKEAI_ROOT`. The default location is `~/invokeai`. Navigate to the Model Manager tab and install some models before generating.
+The runtime directory (holding models and outputs) will be created in the location specified by `INVOKEAI_ROOT`. The default location is `~/invokeai`. The runtime directory will be populated with the base configs and models necessary to start generating.

 ### Use a GPU

@@ -77,9 +43,9 @@ The runtime directory (holding models and outputs) will be created in the location
 - WSL2 is *required* for Windows.
 - only `x86_64` architecture is supported.

-The Docker daemon on the system must be already set up to use the GPU. In case of Linux, this involves installing `nvidia-docker-runtime` and configuring the `nvidia` runtime as default. Steps will be different for AMD. Please see Docker/NVIDIA/AMD documentation for the most up-to-date instructions for using your GPU with Docker.
+The Docker daemon on the system must be already set up to use the GPU. In case of Linux, this involves installing `nvidia-docker-runtime` and configuring the `nvidia` runtime as default. Steps will be different for AMD. Please see Docker documentation for the most up-to-date instructions for using your GPU with Docker.

-To use an AMD GPU, set `GPU_DRIVER=rocm` in your `.env` file before running `./run.sh`.
+To use an AMD GPU, set `GPU_DRIVER=rocm` in your `.env` file.

 ## Customize

@@ -93,10 +59,10 @@ Values are optional, but setting `INVOKEAI_ROOT` is highly recommended. The default
 INVOKEAI_ROOT=/Volumes/WorkDrive/invokeai
 HUGGINGFACE_TOKEN=the_actual_token
 CONTAINER_UID=1000
-GPU_DRIVER=cuda
+GPU_DRIVER=nvidia
 ```

-Any environment variables supported by InvokeAI can be set here. See the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.
+Any environment variables supported by InvokeAI can be set here - please see the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.

 ## Even More Customizing!

View File

@@ -1,5 +1,7 @@
 # Copyright (c) 2023 Eugene Brodsky https://github.com/ebr

+version: '3.8'
+
 x-invokeai: &invokeai
     image: "local/invokeai:latest"
     build:
@@ -30,7 +32,7 @@ x-invokeai: &invokeai
 services:
-  invokeai-cuda:
+  invokeai-nvidia:
     <<: *invokeai
     deploy:
       resources:

View File

@@ -23,18 +23,18 @@ usermod -u ${USER_ID} ${USER} 1>/dev/null
 # but it is useful to have the full SSH server e.g. on Runpod.
 # (use SCP to copy files to/from the image, etc)
 if [[ -v "PUBLIC_KEY" ]] && [[ ! -d "${HOME}/.ssh" ]]; then
     apt-get update
    apt-get install -y openssh-server
     pushd "$HOME"
     mkdir -p .ssh
-    echo "${PUBLIC_KEY}" >.ssh/authorized_keys
+    echo "${PUBLIC_KEY}" > .ssh/authorized_keys
     chmod -R 700 .ssh
     popd
     service ssh start
 fi

 mkdir -p "${INVOKEAI_ROOT}"
-chown --recursive ${USER} "${INVOKEAI_ROOT}" || true
+chown --recursive ${USER} "${INVOKEAI_ROOT}"
 cd "${INVOKEAI_ROOT}"

 # Run the CMD as the Container User (not root).

View File

@@ -8,15 +8,11 @@ run() {
     local build_args=""
     local profile=""

-    # create .env file if it doesn't exist, otherwise docker compose will fail
     touch .env

-    # parse .env file for build args
     build_args=$(awk '$1 ~ /=[^$]/ && $0 !~ /^#/ {print "--build-arg " $0 " "}' .env) &&
         profile="$(awk -F '=' '/GPU_DRIVER/ {print $2}' .env)"

-    # default to 'cuda' profile
-    [[ -z "$profile" ]] && profile="cuda"
+    [[ -z "$profile" ]] && profile="nvidia"

     local service_name="invokeai-$profile"

View File

@@ -128,8 +128,7 @@ The queue operates on a series of download job objects. These objects
 specify the source and destination of the download, and keep track of
 the progress of the download.

-Two job types are defined. `DownloadJob` and
-`MultiFileDownloadJob`. The former is a pydantic object with the
+The only job type currently implemented is `DownloadJob`, a pydantic object with the
 following fields:

 | **Field** | **Type** | **Default** | **Description** |
@@ -139,7 +138,7 @@ following fields:
 | `dest` | Path | | Where to download to |
 | `access_token` | str | | [optional] string containing authentication token for access |
 | `on_start` | Callable | | [optional] callback when the download starts |
 | `on_progress` | Callable | | [optional] callback called at intervals during download progress |
 | `on_complete` | Callable | | [optional] callback called after successful download completion |
 | `on_error` | Callable | | [optional] callback called after an error occurs |
 | `id` | int | auto assigned | Job ID, an integer >= 0 |
@@ -191,33 +190,6 @@ A cancelled job will have status `DownloadJobStatus.ERROR` and an
 `error_type` field of "DownloadJobCancelledException". In addition,
 the job's `cancelled` property will be set to True.

-The `MultiFileDownloadJob` is used for diffusers model downloads,
-which contain multiple files and directories under a common root:
-
-| **Field** | **Type** | **Default** | **Description** |
-|----------------|-----------------|---------------|-----------------|
-| _Fields passed in at job creation time_ |
-| `download_parts` | Set[DownloadJob]| | Component download jobs |
-| `dest` | Path | | Where to download to |
-| `on_start` | Callable | | [optional] callback when the download starts |
-| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
-| `on_complete` | Callable | | [optional] callback called after successful download completion |
-| `on_error` | Callable | | [optional] callback called after an error occurs |
-| `id` | int | auto assigned | Job ID, an integer >= 0 |
-| _Fields updated over the course of the download task_ |
-| `status` | DownloadJobStatus| | Status code |
-| `download_path` | Path | | Path to the root of the downloaded files |
-| `bytes` | int | 0 | Bytes downloaded so far |
-| `total_bytes` | int | 0 | Total size of the file at the remote site |
-| `error_type` | str | | String version of the exception that caused an error during download |
-| `error` | str | | String version of the traceback associated with an error |
-| `cancelled` | bool | False | Set to true if the job was cancelled by the caller |
-
-Note that the MultiFileDownloadJob does not support the `priority`,
-`job_started`, `job_ended` or `content_type` attributes. You can get
-these from the individual download jobs in `download_parts`.
-
 ### Callbacks

 Download jobs can be associated with a series of callbacks, each with
@@ -279,40 +251,11 @@ jobs using `list_jobs()`, fetch a single job by its id with
 running jobs with `cancel_all_jobs()`, and wait for all jobs to finish
 with `join()`.

-#### job = queue.download(source, dest, priority, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
+#### job = queue.download(source, dest, priority, access_token)

 Create a new download job and put it on the queue, returning the
 DownloadJob object.
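
Putting the pieces together, a hypothetical end-to-end sketch (the URL and destination are placeholders, and the keyword names follow the signature above rather than a verified API):

```
from invokeai.app.services.download import DownloadQueueService

queue = DownloadQueueService()
queue.start()

job = queue.download(
    source='https://www.example.com/models/sd_model.safetensors',  # placeholder URL
    dest='/tmp/downloads',
)

queue.join()  # block until every job on the queue reaches a terminal state
print(job.status, job.download_path)
```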
-#### multifile_job = queue.multifile_download(parts, dest, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
-
-This is similar to download(), but instead of taking a single source,
-it accepts a `parts` argument consisting of a list of
-`RemoteModelFile` objects. Each part corresponds to a URL/Path pair,
-where the URL is the location of the remote file, and the Path is the
-destination.
-
-`RemoteModelFile` can be imported from `invokeai.backend.model_manager.metadata`, and
-consists of a url/path pair. Note that the path *must* be relative.
-
-The method returns a `MultiFileDownloadJob`.
-
-```
-from invokeai.backend.model_manager.metadata import RemoteModelFile
-remote_file_1 = RemoteModelFile(url='http://www.foo.bar/my/pytorch_model.safetensors',
-                                path='my_model/textencoder/pytorch_model.safetensors'
-                                )
-remote_file_2 = RemoteModelFile(url='http://www.bar.baz/vae.ckpt',
-                                path='my_model/vae/diffusers_model.safetensors'
-                                )
-job = queue.multifile_download(parts=[remote_file_1, remote_file_2],
-                               dest='/tmp/downloads',
-                               on_progress=TqdmProgress().update)
-queue.wait_for_job(job)
-print(f"The files were downloaded to {job.download_path}")
-```
-
 #### jobs = queue.list_jobs()

 Return a list of all active and inactive `DownloadJob`s.

View File

@@ -397,25 +397,26 @@ In the event you wish to create a new installer, you may use the
 following initialization pattern:

 ```
-from invokeai.app.services.config import get_config
+from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.app.services.model_records import ModelRecordServiceSQL
 from invokeai.app.services.model_install import ModelInstallService
 from invokeai.app.services.download import DownloadQueueService
-from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
+from invokeai.app.services.shared.sqlite import SqliteDatabase
 from invokeai.backend.util.logging import InvokeAILogger

-config = get_config()
+config = InvokeAIAppConfig.get_config()
+config.parse_args()

 logger = InvokeAILogger.get_logger(config=config)
-db = SqliteDatabase(config.db_path, logger)
-record_store = ModelRecordServiceSQL(db, logger)
+db = SqliteDatabase(config, logger)
+record_store = ModelRecordServiceSQL(db)
 queue = DownloadQueueService()
 queue.start()

 installer = ModelInstallService(app_config=config,
                                 record_store=record_store,
                                 download_queue=queue
                                 )
 installer.start()
 ```
@@ -1366,20 +1367,12 @@ the in-memory loaded model:
 | `model` | AnyModel | The instantiated model (details below) |
 | `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |

-### get_model_by_key(key, [submodel]) -> LoadedModel
-
-The `get_model_by_key()` method will retrieve the model using its
-unique database key. For example:
-
-loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
-
-`get_model_by_key()` may raise any of the following exceptions:
-
-* `UnknownModelException` -- key not in database
-* `ModelNotFoundException` -- key in database but model not found at path
-* `NotImplementedException` -- the loader doesn't know how to load this type of model
-
-### Using the Loaded Model in Inference
+Because the loader can return multiple model types, it is typed to
+return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
+`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
+`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
+models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
+models. The others are obvious.

 `LoadedModel` acts as a context manager. The context loads the model
 into the execution device (e.g. VRAM on CUDA systems), locks the model
@@ -1387,33 +1380,17 @@ in the execution device for the duration of the context, and returns
 the model. Use it like this:

 ```
-loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
-with loaded_model as vae:
+model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+with model_info as vae:
   image = vae.decode(latents)[0]
 ```

-The object returned by the LoadedModel context manager is an
-`AnyModel`, which is a Union of `ModelMixin`, `torch.nn.Module`,
-`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
-`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
-models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
-models. The others are obvious.
-
-In addition, you may call `LoadedModel.model_on_device()`, a context
-manager that returns a tuple of the model's state dict in CPU and the
-model itself in VRAM. It is used to optimize the LoRA patching and
-unpatching process:
-
-```
-loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
-with loaded_model.model_on_device() as (state_dict, vae):
-  image = vae.decode(latents)[0]
-```
-
-Since not all models have state dicts, the `state_dict` return value
-can be None.
+`get_model_by_key()` may raise any of the following exceptions:
+
+* `UnknownModelException` -- key not in database
+* `ModelNotFoundException` -- key in database but model not found at path
+* `NotImplementedException` -- the loader doesn't know how to load this type of model

 ### Emitting model loading events

 When the `context` argument is passed to `load_model_*()`, it will
@@ -1601,59 +1578,3 @@ This method takes a model key, looks it up using the
 `ModelRecordServiceBase` object in `mm.store`, and passes the returned
 model configuration to `load_model_by_config()`. It may raise a
 `NotImplementedException`.

-## Invocation Context Model Manager API
-
-Within invocations, the following methods are available from the
-`InvocationContext` object:
-
-### context.download_and_cache_model(source) -> Path
-
-This method accepts a `source` of a remote model, downloads and caches
-it locally, and then returns a Path to the local model. The source can
-be a direct download URL or a HuggingFace repo_id.
-
-In the case of HuggingFace repo_id, the following variants are
-recognized:
-
-* stabilityai/stable-diffusion-v4 -- default model
-* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
-* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
-* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
-
-You can also point at an arbitrary individual file within a repo_id
-directory using this syntax:
-
-* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
-
-### context.load_local_model(model_path, [loader]) -> LoadedModel
-
-This method loads a local model from the indicated path, returning a
-`LoadedModel`. The optional loader is a Callable that accepts a Path
-to the object, and returns a `AnyModel` object. If no loader is
-provided, then the method will use `torch.load()` for a .ckpt or .bin
-checkpoint file, `safetensors.torch.load_file()` for a safetensors
-checkpoint file, or `cls.from_pretrained()` for a directory that looks
-like a diffusers directory.
-
-### context.load_remote_model(source, [loader]) -> LoadedModel
-
-This method accepts a `source` of a remote model, downloads and caches
-it locally, loads it, and returns a `LoadedModel`. The source can be a
-direct download URL or a HuggingFace repo_id.
-
-In the case of HuggingFace repo_id, the following variants are
-recognized:
-
-* stabilityai/stable-diffusion-v4 -- default model
-* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
-* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
-* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
-
-You can also point at an arbitrary individual file within a repo_id
-directory using this syntax:
-
-* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
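
For orientation, a hypothetical sketch combining the two loading calls described in the deleted section above; the `context` object, `latents` argument, and model source are placeholders rather than verified API usage:

```
def decode_with_remote_vae(context, latents):
    # Resolve the remote source to a local path without loading it
    model_path = context.download_and_cache_model('stabilityai/stable-diffusion-v4:fp16:vae')
    print(f'cached at {model_path}')

    # Or download, cache, and load in one step; the returned LoadedModel
    # is a context manager, as described earlier in this document
    with context.load_remote_model('stabilityai/stable-diffusion-v4:fp16:vae') as vae:
        return vae.decode(latents)[0]
```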

View File

@@ -154,18 +154,6 @@ This is caused by an invalid setting in the `invokeai.yaml` configuration file.
 Check the [configuration docs] for more detail about the settings and how to specify them.

-## `ModuleNotFoundError: No module named 'controlnet_aux'`
-
-`controlnet_aux` is a dependency of Invoke and appears to have been packaged or distributed strangely. Sometimes, it doesn't install correctly. This is outside our control.
-
-If you encounter this error, the solution is to remove the package from the `pip` cache and re-run the Invoke installer so a fresh, working version of `controlnet_aux` can be downloaded and installed:
-
-- Run the Invoke launcher
-- Choose the developer console option
-- Run this command: `pip cache remove controlnet_aux`
-- Close the terminal window
-- Download and run the [installer](https://github.com/invoke-ai/InvokeAI/releases/latest), selecting your current install location
-
 ## Out of Memory Issues

 The models are large, VRAM is expensive, and you may find yourself

View File

@@ -4,37 +4,50 @@ title: Installing with Docker

 # :fontawesome-brands-docker: Docker

-!!! warning "macOS users"
+!!! warning "macOS and AMD GPU Users"

-    Docker can not access the GPU on macOS, so your generation speeds will be slow. [Install InvokeAI](INSTALLATION.md) instead.
+    We highly recommend to Install InvokeAI locally using [these instructions](INSTALLATION.md),
+    because Docker containers can not access the GPU on macOS.
+
+!!! warning "AMD GPU Users"
+
+    Container support for AMD GPUs has been reported to work by the community, but has not received
+    extensive testing. Please make sure to set the `GPU_DRIVER=rocm` environment variable (see below), and
+    use the `build.sh` script to build the image for this to take effect at build time.

 !!! tip "Linux and Windows Users"

-    Configure Docker to access your machine's GPU.
+    For optimal performance, configure your Docker daemon to access your machine's GPU.
     Docker Desktop on Windows [includes GPU support](https://www.docker.com/blog/wsl-2-gpu-support-for-docker-desktop-on-nvidia-gpus/).
-    Linux users should follow the [NVIDIA](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) or [AMD](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/docker.html) documentation.
+    Linux users should install and configure the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)

+## Why containers?
+
+They provide a flexible, reliable way to build and deploy InvokeAI.
+See [Processes](https://12factor.net/processes) under the Twelve-Factor App
+methodology for details on why running applications in such a stateless fashion is important.
+
+The container is configured for CUDA by default, but can be built to support AMD GPUs
+by setting the `GPU_DRIVER=rocm` environment variable at Docker image build time.
+
+Developers on Apple silicon (M1/M2/M3): You
+[can't access your GPU cores from Docker containers](https://github.com/pytorch/pytorch/issues/81224)
+and performance is reduced compared with running it directly on macOS but for
+development purposes it's fine. Once you're done with development tasks on your
+laptop you can build for the target platform and architecture and deploy to
+another environment with NVIDIA GPUs on-premises or in the cloud.
+
 ## TL;DR

-Ensure your Docker setup is able to use your GPU. Then:
-
-```bash
-docker run --runtime=nvidia --gpus=all --publish 9090:9090 ghcr.io/invoke-ai/invokeai
-```
-
-Once the container starts up, open http://localhost:9090 in your browser, install some models, and start generating.
-
-## Build-It-Yourself
-
-All the docker materials are located inside the [docker](https://github.com/invoke-ai/InvokeAI/tree/main/docker) directory in the Git repo.
+This assumes properly configured Docker on Linux or Windows/WSL2. Read on for detailed customization options.

 ```bash
+# docker compose commands should be run from the `docker` directory
 cd docker
-cp .env.sample .env
 docker compose up
 ```

-We also ship the `run.sh` convenience script. See the `docker/README.md` file for detailed instructions on how to customize the docker setup to your needs.
+## Installation in a Linux container (desktop)

 ### Prerequisites

@@ -45,9 +58,18 @@ Preferences, Resources, Advanced. Increase the CPUs and Memory to avoid this
 [Issue](https://github.com/invoke-ai/InvokeAI/issues/342). You may need to
 increase Swap and Disk image size too.

+#### Get a Huggingface-Token
+
+Besides the Docker Agent you will need an Account on
+[huggingface.co](https://huggingface.co/join).
+
+After you succesfully registered your account, go to
+[huggingface.co/settings/tokens](https://huggingface.co/settings/tokens), create
+a token and copy it, since you will need in for the next step.
+
 ### Setup

-Set up your environment variables. In the `docker` directory, make a copy of `.env.sample` and name it `.env`. Make changes as necessary.
+Set up your environmnent variables. In the `docker` directory, make a copy of `.env.sample` and name it `.env`. Make changes as necessary.

 Any environment variables supported by InvokeAI can be set here - please see the [CONFIGURATION](../features/CONFIGURATION.md) for further detail.

@@ -81,9 +103,10 @@ Once the container starts up (and configures the InvokeAI root directory if this

 ## Troubleshooting / FAQ

 - Q: I am running on Windows under WSL2, and am seeing a "no such file or directory" error.
-- A: Your `docker-entrypoint.sh` might have Windows (CRLF) line endings, depending how you cloned the repository.
-  To solve this, change the line endings in the `docker-entrypoint.sh` file to `LF`. You can do this in VSCode
+- A: Your `docker-entrypoint.sh` file likely has Windows (CRLF) as opposed to Unix (LF) line endings,
+  and you may have cloned this repository before the issue was fixed. To solve this, please change
+  the line endings in the `docker-entrypoint.sh` file to `LF`. You can do this in VSCode
   (`Ctrl+P` and search for "line endings"), or by using the `dos2unix` utility in WSL.
   Finally, you may delete `docker-entrypoint.sh` followed by `git pull; git checkout docker/docker-entrypoint.sh`
   to reset the file to its most recent version.
-For more information on this issue, see [Docker Desktop documentation](https://docs.docker.com/desktop/troubleshoot/topics/#avoid-unexpected-syntax-errors-use-unix-style-line-endings-for-files-in-containers)
+For more information on this issue, please see the [Docker Desktop documentation](https://docs.docker.com/desktop/troubleshoot/topics/#avoid-unexpected-syntax-errors-use-unix-style-line-endings-for-files-in-containers)

View File

@@ -13,7 +13,7 @@ echo 2. Open the developer console
 echo 3. Command-line help
 echo Q - Quit
 echo.
-echo To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest
+echo To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest.
 echo.
 set /P choice="Please enter 1-4, Q: [1] "
 if not defined choice set choice=1

View File

@@ -4,39 +4,37 @@ from logging import Logger

 import torch

-from invokeai.app.services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
-from invokeai.app.services.board_images.board_images_default import BoardImagesService
-from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
-from invokeai.app.services.boards.boards_default import BoardService
-from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
-from invokeai.app.services.config.config_default import InvokeAIAppConfig
-from invokeai.app.services.download.download_default import DownloadQueueService
-from invokeai.app.services.events.events_fastapievents import FastAPIEventService
-from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
-from invokeai.app.services.image_records.image_records_sqlite import SqliteImageRecordStorage
-from invokeai.app.services.images.images_default import ImageService
-from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
-from invokeai.app.services.invocation_services import InvocationServices
-from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
-from invokeai.app.services.invoker import Invoker
-from invokeai.app.services.model_images.model_images_default import ModelImageFileStorageDisk
-from invokeai.app.services.model_manager.model_manager_default import ModelManagerService
-from invokeai.app.services.model_records.model_records_sql import ModelRecordServiceSQL
-from invokeai.app.services.names.names_default import SimpleNameService
 from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
 from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
-from invokeai.app.services.session_processor.session_processor_default import (
-    DefaultSessionProcessor,
-    DefaultSessionRunner,
-)
-from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
 from invokeai.app.services.shared.sqlite.sqlite_util import init_db
-from invokeai.app.services.urls.urls_default import LocalUrlService
-from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
 from invokeai.backend.util.logging import InvokeAILogger
 from invokeai.version.invokeai_version import __version__

+from ..services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
+from ..services.board_images.board_images_default import BoardImagesService
+from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage
+from ..services.boards.boards_default import BoardService
+from ..services.bulk_download.bulk_download_default import BulkDownloadService
+from ..services.config import InvokeAIAppConfig
+from ..services.download import DownloadQueueService
+from ..services.events.events_fastapievents import FastAPIEventService
+from ..services.image_files.image_files_disk import DiskImageFileStorage
+from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
+from ..services.images.images_default import ImageService
+from ..services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
+from ..services.invocation_services import InvocationServices
+from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
+from ..services.invoker import Invoker
+from ..services.model_images.model_images_default import ModelImageFileStorageDisk
+from ..services.model_manager.model_manager_default import ModelManagerService
+from ..services.model_records import ModelRecordServiceSQL
+from ..services.names.names_default import SimpleNameService
+from ..services.session_processor.session_processor_default import DefaultSessionProcessor, DefaultSessionRunner
+from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
+from ..services.urls.urls_default import LocalUrlService
+from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage

 # TODO: is there a better way to achieve this?
 def check_internet() -> bool:
@@ -95,11 +93,11 @@ class ApiDependencies:
         conditioning = ObjectSerializerForwardCache(
             ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
         )
-        download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
+        download_queue_service = DownloadQueueService(event_bus=events)
         model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
         model_manager = ModelManagerService.build_model_manager(
             app_config=configuration,
-            model_record_service=ModelRecordServiceSQL(db=db, logger=logger),
+            model_record_service=ModelRecordServiceSQL(db=db),
             download_queue=download_queue_service,
             events=events,
         )

View File

@@ -10,13 +10,14 @@ from fastapi import Body
 from fastapi.routing import APIRouter
 from pydantic import BaseModel, Field

-from invokeai.app.api.dependencies import ApiDependencies
 from invokeai.app.invocations.upscale import ESRGAN_MODELS
 from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
 from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
 from invokeai.backend.util.logging import logging
 from invokeai.version import __version__

+from ..dependencies import ApiDependencies
+
 class LogLevel(int, Enum):
     NotSet = logging.NOTSET

View File

@@ -2,7 +2,7 @@ from fastapi import Body, HTTPException
 from fastapi.routing import APIRouter
 from pydantic import BaseModel, Field

-from invokeai.app.api.dependencies import ApiDependencies
+from ..dependencies import ApiDependencies

 board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])

View File

@@ -4,11 +4,12 @@ from fastapi import Body, HTTPException, Path, Query
 from fastapi.routing import APIRouter
 from pydantic import BaseModel, Field

-from invokeai.app.api.dependencies import ApiDependencies
 from invokeai.app.services.board_records.board_records_common import BoardChanges
 from invokeai.app.services.boards.boards_common import BoardDTO
 from invokeai.app.services.shared.pagination import OffsetPaginatedResults

+from ..dependencies import ApiDependencies
+
 boards_router = APIRouter(prefix="/v1/boards", tags=["boards"])
@@ -31,7 +32,6 @@ class DeleteBoardResult(BaseModel):
 )
 async def create_board(
     board_name: str = Query(description="The name of the board to create"),
-    is_private: bool = Query(default=False, description="Whether the board is private"),
 ) -> BoardDTO:
     """Creates a board"""
     try:
@@ -118,13 +118,15 @@ async def list_boards(
     all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
     offset: Optional[int] = Query(default=None, description="The page offset"),
     limit: Optional[int] = Query(default=None, description="The number of boards per page"),
-    include_archived: bool = Query(default=False, description="Whether or not to include archived boards in list"),
 ) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
     """Gets a list of boards"""
     if all:
-        return ApiDependencies.invoker.services.boards.get_all(include_archived)
+        return ApiDependencies.invoker.services.boards.get_all()
     elif offset is not None and limit is not None:
-        return ApiDependencies.invoker.services.boards.get_many(offset, limit, include_archived)
+        return ApiDependencies.invoker.services.boards.get_many(
+            offset,
+            limit,
+        )
     else:
         raise HTTPException(
             status_code=400,

View File

@@ -8,12 +8,13 @@ from fastapi.routing import APIRouter
 from pydantic.networks import AnyHttpUrl
 from starlette.exceptions import HTTPException

-from invokeai.app.api.dependencies import ApiDependencies
 from invokeai.app.services.download import (
     DownloadJob,
     UnknownJobIDException,
 )

+from ..dependencies import ApiDependencies
+
 download_queue_router = APIRouter(prefix="/v1/download_queue", tags=["download_queue"])

View File

@@ -8,16 +8,12 @@ from fastapi.routing import APIRouter
from PIL import Image from PIL import Image
from pydantic import BaseModel, Field, JsonValue from pydantic import BaseModel, Field, JsonValue
from invokeai.app.api.dependencies import ApiDependencies
```diff
 from invokeai.app.invocations.fields import MetadataField
-from invokeai.app.services.image_records.image_records_common import (
-    ImageCategory,
-    ImageRecordChanges,
-    ResourceOrigin,
-)
+from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
 from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
 from invokeai.app.services.shared.pagination import OffsetPaginatedResults
-from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
+
+from ..dependencies import ApiDependencies
 
 images_router = APIRouter(prefix="/v1/images", tags=["images"])
@@ -233,14 +229,21 @@ async def get_image_workflow(
 )
 async def get_image_full(
     image_name: str = Path(description="The name of full-resolution image file to get"),
-) -> Response:
+) -> FileResponse:
     """Gets a full-resolution image file"""
     try:
         path = ApiDependencies.invoker.services.images.get_path(image_name)
 
-        with open(path, "rb") as f:
-            content = f.read()
-        response = Response(content, media_type="image/png")
+        if not ApiDependencies.invoker.services.images.validate_path(path):
+            raise HTTPException(status_code=404)
+
+        response = FileResponse(
+            path,
+            media_type="image/png",
+            filename=image_name,
+            content_disposition_type="inline",
+        )
         response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
         return response
     except Exception:
@@ -261,14 +264,15 @@ async def get_image_full(
 )
 async def get_image_thumbnail(
     image_name: str = Path(description="The name of thumbnail image file to get"),
-) -> Response:
+) -> FileResponse:
     """Gets a thumbnail image file"""
     try:
         path = ApiDependencies.invoker.services.images.get_path(image_name, thumbnail=True)
 
-        with open(path, "rb") as f:
-            content = f.read()
-        response = Response(content, media_type="image/webp")
+        if not ApiDependencies.invoker.services.images.validate_path(path):
+            raise HTTPException(status_code=404)
+
+        response = FileResponse(path, media_type="image/webp", content_disposition_type="inline")
         response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
         return response
     except Exception:
@@ -312,14 +316,16 @@ async def list_image_dtos(
     ),
     offset: int = Query(default=0, description="The page offset"),
     limit: int = Query(default=10, description="The number of images per page"),
-    order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
-    starred_first: bool = Query(default=True, description="Whether to sort by starred images first"),
-    search_term: Optional[str] = Query(default=None, description="The term to search for"),
 ) -> OffsetPaginatedResults[ImageDTO]:
     """Gets a list of image DTOs"""
     image_dtos = ApiDependencies.invoker.services.images.get_many(
-        offset, limit, starred_first, order_dir, image_origin, categories, is_intermediate, board_id, search_term
+        offset,
+        limit,
+        image_origin,
+        categories,
+        is_intermediate,
+        board_id,
    )
    return image_dtos
```
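The image hunks above switch from buffering the PNG into a `Response` to streaming it with `FileResponse`, and 404 early when `validate_path()` rejects the name. A minimal sketch of that pattern, assuming a bare FastAPI app; `IMAGE_ROOT`, `is_safe_path()`, and the route path are illustrative stand-ins, not the real service API:

```python
from pathlib import Path

from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse

app = FastAPI()
IMAGE_MAX_AGE = 31536000  # one year; generated images never change
IMAGE_ROOT = Path("outputs/images")


def is_safe_path(path: Path) -> bool:
    # stand-in for images.validate_path(): the file must exist and resolve
    # inside the images root (rejects traversal like "../../etc/passwd")
    resolved = path.resolve()
    return resolved.is_file() and resolved.is_relative_to(IMAGE_ROOT.resolve())


@app.get("/v1/images/i/{image_name}/full")
async def get_image_full(image_name: str) -> FileResponse:
    path = IMAGE_ROOT / image_name
    if not is_safe_path(path):
        raise HTTPException(status_code=404)
    # FileResponse streams from disk instead of reading all bytes into memory;
    # content_disposition_type="inline" displays the image rather than downloading it
    response = FileResponse(path, media_type="image/png", filename=image_name, content_disposition_type="inline")
    response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
    return response
```

The win over the old `open(path).read()` plus `Response(content)` version is that the file is sent in chunks, so large images never pass through memory whole.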

View File

```diff
@@ -3,23 +3,23 @@
 import io
 import pathlib
+import shutil
 import traceback
 from copy import deepcopy
-from tempfile import TemporaryDirectory
 from typing import Any, Dict, List, Optional, Type
 
 from fastapi import Body, Path, Query, Response, UploadFile
-from fastapi.responses import FileResponse, HTMLResponse
+from fastapi.responses import FileResponse
 from fastapi.routing import APIRouter
 from PIL import Image
 from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
 from starlette.exceptions import HTTPException
 from typing_extensions import Annotated
 
-from invokeai.app.api.dependencies import ApiDependencies
 from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
 from invokeai.app.services.model_install.model_install_common import ModelInstallJob
 from invokeai.app.services.model_records import (
+    DuplicateModelException,
     InvalidModelException,
     ModelRecordChanges,
     UnknownModelException,
@@ -30,12 +30,15 @@ from invokeai.backend.model_manager.config import (
     MainCheckpointConfig,
     ModelFormat,
     ModelType,
+    SubModelType,
 )
 from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
 from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
 from invokeai.backend.model_manager.search import ModelSearch
 from invokeai.backend.model_manager.starter_models import STARTER_MODELS, StarterModel, StarterModelWithoutDependencies
 
+from ..dependencies import ApiDependencies
+
 model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
 
 # images are immutable; set a high max-age
@@ -171,6 +174,18 @@ async def get_model_record(
         raise HTTPException(status_code=404, detail=str(e))
 
 
+# @model_manager_router.get("/summary", operation_id="list_model_summary")
+# async def list_model_summary(
+#     page: int = Query(default=0, description="The page to get"),
+#     per_page: int = Query(default=10, description="The number of models per page"),
+#     order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
+# ) -> PaginatedResults[ModelSummary]:
+#     """Gets a page of model summary data."""
+#     record_store = ApiDependencies.invoker.services.model_manager.store
+#     results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
+#     return results
+
+
 class FoundModel(BaseModel):
     path: str = Field(description="Path to the model")
     is_installed: bool = Field(description="Whether or not the model is already installed")
@@ -487,133 +502,6 @@ async def install_model(
     return result
 
 
-@model_manager_router.get(
-    "/install/huggingface",
-    operation_id="install_hugging_face_model",
-    responses={
-        201: {"description": "The model is being installed"},
-        400: {"description": "Bad request"},
-        409: {"description": "There is already a model corresponding to this path or repo_id"},
-    },
-    status_code=201,
-    response_class=HTMLResponse,
-)
-async def install_hugging_face_model(
-    source: str = Query(description="HuggingFace repo_id to install"),
-) -> HTMLResponse:
-    """Install a Hugging Face model using a string identifier."""
-
-    def generate_html(title: str, heading: str, repo_id: str, is_error: bool, message: str | None = "") -> str:
-        if message:
-            message = f"<p>{message}</p>"
-        title_class = "error" if is_error else "success"
-        return f"""
-        <html>
-            <head>
-                <title>{title}</title>
-                <style>
-                    body {{
-                        text-align: center;
-                        background-color: hsl(220 12% 10% / 1);
-                        font-family: Helvetica, sans-serif;
-                        color: hsl(220 12% 86% / 1);
-                    }}
-                    .repo-id {{
-                        color: hsl(220 12% 68% / 1);
-                    }}
-                    .error {{
-                        color: hsl(0 42% 68% / 1)
-                    }}
-                    .message-box {{
-                        display: inline-block;
-                        border-radius: 5px;
-                        background-color: hsl(220 12% 20% / 1);
-                        padding-inline-end: 30px;
-                        padding: 20px;
-                        padding-inline-start: 30px;
-                        padding-inline-end: 30px;
-                    }}
-                    .container {{
-                        display: flex;
-                        width: 100%;
-                        height: 100%;
-                        align-items: center;
-                        justify-content: center;
-                    }}
-                    a {{
-                        color: inherit
-                    }}
-                    a:visited {{
-                        color: inherit
-                    }}
-                    a:active {{
-                        color: inherit
-                    }}
-                </style>
-            </head>
-            <body style="background-color: hsl(220 12% 10% / 1);">
-                <div class="container">
-                    <div class="message-box">
-                        <h2 class="{title_class}">{heading}</h2>
-                        {message}
-                        <p class="repo-id">Repo ID: {repo_id}</p>
-                    </div>
-                </div>
-            </body>
-        </html>
-        """
-
-    try:
-        metadata = HuggingFaceMetadataFetch().from_id(source)
-        assert isinstance(metadata, ModelMetadataWithFiles)
-    except UnknownMetadataException:
-        title = "Unable to Install Model"
-        heading = "No HuggingFace repository found with that repo ID."
-        message = "Ensure the repo ID is correct and try again."
-        return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=400)
-
-    logger = ApiDependencies.invoker.services.logger
-
-    try:
-        installer = ApiDependencies.invoker.services.model_manager.install
-        if metadata.is_diffusers:
-            installer.heuristic_import(
-                source=source,
-                inplace=False,
-            )
-        elif metadata.ckpt_urls is not None and len(metadata.ckpt_urls) == 1:
-            installer.heuristic_import(
-                source=str(metadata.ckpt_urls[0]),
-                inplace=False,
-            )
-        else:
-            title = "Unable to Install Model"
-            heading = "This HuggingFace repo has multiple models."
-            message = "Please use the Model Manager to install this model."
-            return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=200)
-
-        title = "Model Install Started"
-        heading = "Your HuggingFace model is installing now."
-        message = "You can close this tab and check the Model Manager for installation progress."
-        return HTMLResponse(content=generate_html(title, heading, source, False, message), status_code=201)
-    except Exception as e:
-        logger.error(str(e))
-        title = "Unable to Install Model"
-        heading = "There was a problem installing this model."
-        message = 'Please use the Model Manager directly to install this model. If the issue persists, ask for help on <a href="https://discord.gg/ZmtBAhwWhy">discord</a>.'
-        return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=500)
-
-
 @model_manager_router.get(
     "/install",
     operation_id="list_model_installs",
@@ -731,36 +619,39 @@ async def convert_model(
         logger.error(f"The model with key {key} is not a main checkpoint model.")
         raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
 
-    with TemporaryDirectory(dir=ApiDependencies.invoker.services.configuration.models_path) as tmpdir:
-        convert_path = pathlib.Path(tmpdir) / pathlib.Path(model_config.path).stem
-        converted_model = loader.load_model(model_config)
-        # write the converted file to the convert path
-        raw_model = converted_model.model
-        assert hasattr(raw_model, "save_pretrained")
-        raw_model.save_pretrained(convert_path)
-        assert convert_path.exists()
-
-        # temporarily rename the original safetensors file so that there is no naming conflict
-        original_name = model_config.name
-        model_config.name = f"{original_name}.DELETE"
-        changes = ModelRecordChanges(name=model_config.name)
-        store.update_model(key, changes=changes)
-
-        # install the diffusers
-        try:
-            new_key = installer.install_path(
-                convert_path,
-                config={
-                    "name": original_name,
-                    "description": model_config.description,
-                    "hash": model_config.hash,
-                    "source": model_config.source,
-                },
-            )
-        except Exception as e:
-            logger.error(str(e))
-            store.update_model(key, changes=ModelRecordChanges(name=original_name))
-            raise HTTPException(status_code=409, detail=str(e))
+    # loading the model will convert it into a cached diffusers file
+    try:
+        cc_size = loader.convert_cache.max_size
+        if cc_size == 0:  # temporarily set the convert cache to a positive number so that the cached model is written
+            loader._convert_cache.max_size = 1.0
+        loader.load_model(model_config, submodel_type=SubModelType.Scheduler)
+    finally:
+        loader._convert_cache.max_size = cc_size
+
+    # Get the path of the converted model from the loader
+    cache_path = loader.convert_cache.cache_path(key)
+    assert cache_path.exists()
+
+    # temporarily rename the original safetensors file so that there is no naming conflict
+    original_name = model_config.name
+    model_config.name = f"{original_name}.DELETE"
+    changes = ModelRecordChanges(name=model_config.name)
+    store.update_model(key, changes=changes)
+
+    # install the diffusers
+    try:
+        new_key = installer.install_path(
+            cache_path,
+            config={
+                "name": original_name,
+                "description": model_config.description,
+                "hash": model_config.hash,
+                "source": model_config.source,
+            },
+        )
+    except DuplicateModelException as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=409, detail=str(e))
 
     # Update the model image if the model had one
     try:
@@ -773,8 +664,8 @@ async def convert_model(
     # delete the original safetensors file
     installer.delete(key)
 
-    # delete the temporary directory
-    # shutil.rmtree(cache_path)
+    # delete the cached version
+    shutil.rmtree(cache_path)
 
     # return the config record for the new diffusers directory
     new_config = store.get_model(new_key)
```
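The subtle part of the new convert path is the try/finally that temporarily bumps the convert cache's `max_size` so the converted diffusers model is actually written to disk, then restores the configured size no matter what. A toy sketch of that control flow (`Loader` and `ConvertCache` here are invented stand-ins, not the real model-manager classes):

```python
class ConvertCache:
    def __init__(self, max_size: float) -> None:
        self.max_size = max_size


class Loader:
    def __init__(self) -> None:
        self._convert_cache = ConvertCache(max_size=0.0)

    @property
    def convert_cache(self) -> ConvertCache:
        return self._convert_cache

    def load_model(self, config: str) -> None:
        # the real loader converts the checkpoint and, only if max_size > 0,
        # persists the converted copy in the cache directory
        print(f"converting {config}; persisted={self.convert_cache.max_size > 0}")


loader = Loader()
cc_size = loader.convert_cache.max_size
try:
    if cc_size == 0:  # a zero-sized cache would discard the converted file
        loader._convert_cache.max_size = 1.0
    loader.load_model("my-checkpoint")
finally:
    # restore the configured size even if conversion raised
    loader._convert_cache.max_size = cc_size
```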

View File

```diff
@@ -4,7 +4,6 @@ from fastapi import Body, Path, Query
 from fastapi.routing import APIRouter
 from pydantic import BaseModel
 
-from invokeai.app.api.dependencies import ApiDependencies
 from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
 from invokeai.app.services.session_queue.session_queue_common import (
     QUEUE_ITEM_STATUS,
@@ -20,6 +19,8 @@ from invokeai.app.services.session_queue.session_queue_common import (
 )
 from invokeai.app.services.shared.pagination import CursorPaginatedResults
 
+from ..dependencies import ApiDependencies
+
 session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
```

View File

```diff
@@ -3,7 +3,9 @@ import logging
 import mimetypes
 import socket
 from contextlib import asynccontextmanager
+from inspect import signature
 from pathlib import Path
+from typing import Any
 
 import torch
 import uvicorn
@@ -11,18 +13,27 @@ from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
 from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
+from fastapi.openapi.utils import get_openapi
 from fastapi.responses import HTMLResponse
 from fastapi_events.handlers.local import local_handler
 from fastapi_events.middleware import EventHandlerASGIMiddleware
+from pydantic.json_schema import models_json_schema
 from torch.backends.mps import is_available as is_mps_available
 
 # for PyCharm:
 # noinspection PyUnresolvedReferences
 import invokeai.backend.util.hotfixes  # noqa: F401 (monkeypatching on import)
 import invokeai.frontend.web as web_dir
-from invokeai.app.api.dependencies import ApiDependencies
 from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
-from invokeai.app.api.routers import (
+from invokeai.app.invocations.model import ModelIdentifierField
+from invokeai.app.services.config.config_default import get_config
+from invokeai.app.services.events.events_common import EventBase
+from invokeai.app.services.session_processor.session_processor_common import ProgressImage
+from invokeai.backend.util.devices import TorchDevice
+
+from ..backend.util.logging import InvokeAILogger
+from .api.dependencies import ApiDependencies
+from .api.routers import (
     app_info,
     board_images,
     boards,
@@ -33,11 +44,12 @@ from invokeai.app.api.routers import (
     utilities,
     workflows,
 )
-from invokeai.app.api.sockets import SocketIO
-from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.custom_openapi import get_openapi_func
-from invokeai.backend.util.devices import TorchDevice
-from invokeai.backend.util.logging import InvokeAILogger
+from .api.sockets import SocketIO
+from .invocations.baseinvocation import (
+    BaseInvocation,
+    UIConfigBase,
+)
+from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
 
 app_config = get_config()
@@ -107,7 +119,84 @@ app.include_router(app_info.app_router, prefix="/api")
 app.include_router(session_queue.session_queue_router, prefix="/api")
 app.include_router(workflows.workflows_router, prefix="/api")
 
-app.openapi = get_openapi_func(app)
+
+# Build a custom OpenAPI to include all outputs
+# TODO: can outputs be included on metadata of invocation schemas somehow?
+def custom_openapi() -> dict[str, Any]:
+    if app.openapi_schema:
+        return app.openapi_schema
+    openapi_schema = get_openapi(
+        title=app.title,
+        description="An API for invoking AI image operations",
+        version="1.0.0",
+        routes=app.routes,
+        separate_input_output_schemas=False,  # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
+    )
+
+    # Add all outputs
+    all_invocations = BaseInvocation.get_invocations()
+    output_types = set()
+    output_type_titles = {}
+    for invoker in all_invocations:
+        output_type = signature(invoker.invoke).return_annotation
+        output_types.add(output_type)
+
+    output_schemas = models_json_schema(
+        models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
+    )
+    for schema_key, output_schema in output_schemas[1]["$defs"].items():
+        # TODO: note that we assume the schema_key here is the TYPE.__name__
+        # This could break in some cases, figure out a better way to do it
+        output_type_titles[schema_key] = output_schema["title"]
+        openapi_schema["components"]["schemas"][schema_key] = output_schema
+        openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
+
+    # Some models don't end up in the schemas as standalone definitions
+    additional_schemas = models_json_schema(
+        [
+            (UIConfigBase, "serialization"),
+            (InputFieldJSONSchemaExtra, "serialization"),
+            (OutputFieldJSONSchemaExtra, "serialization"),
+            (ModelIdentifierField, "serialization"),
+            (ProgressImage, "serialization"),
+        ],
+        ref_template="#/components/schemas/{model}",
+    )
+    for schema_key, schema_json in additional_schemas[1]["$defs"].items():
+        openapi_schema["components"]["schemas"][schema_key] = schema_json
+
+    openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
+        "type": "object",
+        "properties": {},
+        "required": [],
+    }
+
+    # Add a reference to the output type to additionalProperties of the invoker schema
+    for invoker in all_invocations:
+        invoker_name = invoker.__name__  # type: ignore [attr-defined] # this is a valid attribute
+        output_type = signature(obj=invoker.invoke).return_annotation
+        output_type_title = output_type_titles[output_type.__name__]
+        invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
+        outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
+        invoker_schema["output"] = outputs_ref
+        openapi_schema["components"]["schemas"]["InvocationOutputMap"]["properties"][invoker.get_type()] = outputs_ref
+        openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
+        invoker_schema["class"] = "invocation"
+
+    # Add all event schemas
+    for event in sorted(EventBase.get_events(), key=lambda e: e.__name__):
+        json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
+
+        if "$defs" in json_schema:
+            for schema_key, schema in json_schema["$defs"].items():
+                openapi_schema["components"]["schemas"][schema_key] = schema
+            del json_schema["$defs"]
+
+        openapi_schema["components"]["schemas"][event.__name__] = json_schema
+
+    app.openapi_schema = openapi_schema
+    return app.openapi_schema
+
+
+app.openapi = custom_openapi  # type: ignore [method-assign] # this is a valid assignment
+
 
 @app.get("/docs", include_in_schema=False)
@@ -161,7 +250,6 @@ def invoke_api() -> None:
         # Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
         # https://github.com/WaylonWalker
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-            s.settimeout(1)
             if s.connect_ex(("localhost", port)) == 0:
                 return find_port(port=port + 1)
             else:
```
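The `custom_openapi()` added above rests on two introspection moves: the output type of each invocation is read off `invoke()`'s return annotation, and pydantic's `models_json_schema()` batch-renders schemas into a shared `$defs` table. A self-contained sketch with toy models (the model names here are invented):

```python
from inspect import signature

from pydantic import BaseModel
from pydantic.json_schema import models_json_schema


class WidthOutput(BaseModel):
    type: str = "width_output"
    width: int


class ResizeInvocation(BaseModel):
    type: str = "resize"

    def invoke(self) -> WidthOutput:
        return WidthOutput(width=512)


# read the output model from the return annotation, as custom_openapi() does
output_type = signature(ResizeInvocation.invoke).return_annotation

# render it with $ref targets shaped like OpenAPI component paths
_, defs = models_json_schema(
    [(output_type, "serialization")], ref_template="#/components/schemas/{model}"
)
print(defs["$defs"]["WidthOutput"]["title"])  # -> WidthOutput
```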

View File

```diff
@@ -40,7 +40,7 @@ from invokeai.app.util.misc import uuid_string
 from invokeai.backend.util.logging import InvokeAILogger
 
 if TYPE_CHECKING:
-    from invokeai.app.services.invocation_services import InvocationServices
+    from ..services.invocation_services import InvocationServices
 
 logger = InvokeAILogger.get_logger()
@@ -98,13 +98,11 @@ class BaseInvocationOutput(BaseModel):
     _output_classes: ClassVar[set[BaseInvocationOutput]] = set()
     _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
-    _typeadapter_needs_update: ClassVar[bool] = False
 
     @classmethod
     def register_output(cls, output: BaseInvocationOutput) -> None:
         """Registers an invocation output."""
         cls._output_classes.add(output)
-        cls._typeadapter_needs_update = True
 
     @classmethod
     def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
@@ -114,12 +112,11 @@ class BaseInvocationOutput(BaseModel):
     @classmethod
     def get_typeadapter(cls) -> TypeAdapter[Any]:
         """Gets a pydantic TypeAdapter for the union of all invocation output types."""
-        if not cls._typeadapter or cls._typeadapter_needs_update:
-            AnyInvocationOutput = TypeAliasType(
-                "AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
+        if not cls._typeadapter:
+            InvocationOutputsUnion = TypeAliasType(
+                "InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
             )
-            cls._typeadapter = TypeAdapter(AnyInvocationOutput)
-            cls._typeadapter_needs_update = False
+            cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
         return cls._typeadapter
 
     @classmethod
@@ -128,13 +125,12 @@ class BaseInvocationOutput(BaseModel):
         return (i.get_type() for i in BaseInvocationOutput.get_outputs())
 
     @staticmethod
-    def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
+    def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
         """Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
         # Because we use a pydantic Literal field with default value for the invocation type,
         # it will be typed as optional in the OpenAPI schema. Make it required manually.
         if "required" not in schema or not isinstance(schema["required"], list):
             schema["required"] = []
-        schema["class"] = "output"
         schema["required"].extend(["type"])
 
     @classmethod
@@ -171,7 +167,6 @@ class BaseInvocation(ABC, BaseModel):
     _invocation_classes: ClassVar[set[BaseInvocation]] = set()
     _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
-    _typeadapter_needs_update: ClassVar[bool] = False
 
     @classmethod
     def get_type(cls) -> str:
@@ -182,17 +177,15 @@ class BaseInvocation(ABC, BaseModel):
     def register_invocation(cls, invocation: BaseInvocation) -> None:
         """Registers an invocation."""
         cls._invocation_classes.add(invocation)
-        cls._typeadapter_needs_update = True
 
     @classmethod
     def get_typeadapter(cls) -> TypeAdapter[Any]:
         """Gets a pydantic TypeAdapter for the union of all invocation types."""
-        if not cls._typeadapter or cls._typeadapter_needs_update:
-            AnyInvocation = TypeAliasType(
-                "AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
+        if not cls._typeadapter:
+            InvocationsUnion = TypeAliasType(
+                "InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
             )
-            cls._typeadapter = TypeAdapter(AnyInvocation)
-            cls._typeadapter_needs_update = False
+            cls._typeadapter = TypeAdapter(InvocationsUnion)
         return cls._typeadapter
 
     @classmethod
@@ -228,7 +221,7 @@ class BaseInvocation(ABC, BaseModel):
         return signature(cls.invoke).return_annotation
 
     @staticmethod
-    def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
+    def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
         """Adds various UI-facing attributes to the invocation's OpenAPI schema."""
         uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
         if uiconfig is not None:
@@ -244,7 +237,6 @@ class BaseInvocation(ABC, BaseModel):
             schema["version"] = uiconfig.version
         if "required" not in schema or not isinstance(schema["required"], list):
             schema["required"] = []
-        schema["class"] = "invocation"
         schema["required"].extend(["type", "id"])
 
     @abstractmethod
@@ -318,7 +310,7 @@ class BaseInvocation(ABC, BaseModel):
         protected_namespaces=(),
         validate_assignment=True,
         json_schema_extra=json_schema_extra,
-        json_schema_serialization_defaults_required=False,
+        json_schema_serialization_defaults_required=True,
         coerce_numbers_to_str=True,
     )
```
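The `get_typeadapter()` methods above build one discriminated union over every registered class, keyed on the `type` field, and cache the resulting `TypeAdapter`. A small self-contained sketch with two toy outputs:

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter
from typing_extensions import TypeAliasType


class ImageOutput(BaseModel):
    type: Literal["image_output"] = "image_output"
    image_name: str


class LatentsOutput(BaseModel):
    type: Literal["latents_output"] = "latents_output"
    latents_name: str


_output_classes = {ImageOutput, LatentsOutput}

# same shape as the code above: a named alias over a discriminated union
InvocationOutputsUnion = TypeAliasType(
    "InvocationOutputsUnion", Annotated[Union[tuple(_output_classes)], Field(discriminator="type")]
)
adapter = TypeAdapter(InvocationOutputsUnion)

# the "type" discriminator routes parsing straight to the right model
parsed = adapter.validate_python({"type": "latents_output", "latents_name": "abc"})
assert isinstance(parsed, LatentsOutput)
```

The `_typeadapter_needs_update` flag on the other side of this hunk exists precisely to invalidate that cache when a class registers after the first `get_typeadapter()` call; without it, late registrations are invisible to the cached adapter.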

View File

```diff
@@ -1,98 +0,0 @@
-from typing import Any, Union
-
-import numpy as np
-import numpy.typing as npt
-import torch
-
-from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField
-from invokeai.app.invocations.primitives import LatentsOutput
-from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.util.devices import TorchDevice
-
-
-@invocation(
-    "lblend",
-    title="Blend Latents",
-    tags=["latents", "blend"],
-    category="latents",
-    version="1.0.3",
-)
-class BlendLatentsInvocation(BaseInvocation):
-    """Blend two latents using a given alpha. Latents must have same size."""
-
-    latents_a: LatentsField = InputField(
-        description=FieldDescriptions.latents,
-        input=Input.Connection,
-    )
-    latents_b: LatentsField = InputField(
-        description=FieldDescriptions.latents,
-        input=Input.Connection,
-    )
-    alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
-
-    def invoke(self, context: InvocationContext) -> LatentsOutput:
-        latents_a = context.tensors.load(self.latents_a.latents_name)
-        latents_b = context.tensors.load(self.latents_b.latents_name)
-        if latents_a.shape != latents_b.shape:
-            raise Exception("Latents to blend must be the same size.")
-
-        device = TorchDevice.choose_torch_device()
-
-        def slerp(
-            t: Union[float, npt.NDArray[Any]],  # FIXME: maybe use np.float32 here?
-            v0: Union[torch.Tensor, npt.NDArray[Any]],
-            v1: Union[torch.Tensor, npt.NDArray[Any]],
-            DOT_THRESHOLD: float = 0.9995,
-        ) -> Union[torch.Tensor, npt.NDArray[Any]]:
-            """
-            Spherical linear interpolation
-            Args:
-                t (float/np.ndarray): Float value between 0.0 and 1.0
-                v0 (np.ndarray): Starting vector
-                v1 (np.ndarray): Final vector
-                DOT_THRESHOLD (float): Threshold for considering the two vectors as
-                                       colinear. Not recommended to alter this.
-            Returns:
-                v2 (np.ndarray): Interpolation vector between v0 and v1
-            """
-            inputs_are_torch = False
-            if not isinstance(v0, np.ndarray):
-                inputs_are_torch = True
-                v0 = v0.detach().cpu().numpy()
-            if not isinstance(v1, np.ndarray):
-                inputs_are_torch = True
-                v1 = v1.detach().cpu().numpy()
-
-            dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
-            if np.abs(dot) > DOT_THRESHOLD:
-                v2 = (1 - t) * v0 + t * v1
-            else:
-                theta_0 = np.arccos(dot)
-                sin_theta_0 = np.sin(theta_0)
-                theta_t = theta_0 * t
-                sin_theta_t = np.sin(theta_t)
-                s0 = np.sin(theta_0 - theta_t) / sin_theta_0
-                s1 = sin_theta_t / sin_theta_0
-                v2 = s0 * v0 + s1 * v1
-
-            if inputs_are_torch:
-                v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
-                return v2_torch
-            else:
-                assert isinstance(v2, np.ndarray)
-                return v2
-
-        # blend
-        bl = slerp(self.alpha, latents_a, latents_b)
-        assert isinstance(bl, torch.Tensor)
-        blended_latents: torch.Tensor = bl  # for type checking convenience
-
-        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-        blended_latents = blended_latents.to("cpu")
-
-        TorchDevice.empty_cache()
-
-        name = context.tensors.save(tensor=blended_latents)
-        return LatentsOutput.build(latents_name=name, latents=blended_latents, seed=self.latents_a.seed)
```
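The `slerp()` in the deleted node falls back to linear interpolation only for near-colinear vectors; otherwise it interpolates along the arc. A quick numeric check of the trig branch for two orthogonal unit vectors at t=0.5:

```python
import numpy as np

v0 = np.array([1.0, 0.0])
v1 = np.array([0.0, 1.0])
t = 0.5

dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))  # 0.0 -> not colinear
theta_0 = np.arccos(dot)  # pi/2
theta_t = theta_0 * t  # pi/4
s0 = np.sin(theta_0 - theta_t) / np.sin(theta_0)
s1 = np.sin(theta_t) / np.sin(theta_0)
v2 = s0 * v0 + s1 * v1

print(v2)  # approx [0.7071 0.7071] -- stays on the unit circle
print(np.linalg.norm(v2))  # 1.0; plain lerp at t=0.5 would give norm ~0.707
```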

View File

```diff
@@ -4,12 +4,13 @@
 import numpy as np
 from pydantic import ValidationInfo, field_validator
 
-from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import InputField
 from invokeai.app.invocations.primitives import IntegerCollectionOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.misc import SEED_MAX
 
+from .baseinvocation import BaseInvocation, invocation
+from .fields import InputField
+
 @invocation(
     "range", title="Integer Range", tags=["collection", "integer", "range"], category="collections", version="1.0.0"
```

View File

```diff
@@ -5,7 +5,6 @@ from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 
-from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from invokeai.app.invocations.fields import (
     ConditioningField,
     FieldDescriptions,
@@ -15,7 +14,6 @@ from invokeai.app.invocations.fields import (
     TensorField,
     UIComponent,
 )
-from invokeai.app.invocations.model import CLIPField
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
@@ -28,6 +26,9 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
 )
 from invokeai.backend.util.devices import TorchDevice
 
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
+from .model import CLIPField
+
 # unconditioned: Optional[torch.Tensor]
@@ -80,13 +81,9 @@ class CompelInvocation(BaseInvocation):
         with (
             # apply all patches while the model is on the target device
-            text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
+            text_encoder_info as text_encoder,
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora_text_encoder(
-                text_encoder,
-                loras=_lora_loader(),
-                model_state_dict=model_state_dict,
-            ),
+            ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
             ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
@@ -175,14 +172,9 @@ class SDXLPromptInvocationBase:
         with (
             # apply all patches while the model is on the target device
-            text_encoder_info.model_on_device() as (state_dict, text_encoder),
+            text_encoder_info as text_encoder,
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora(
-                text_encoder,
-                loras=_lora_loader(),
-                prefix=lora_prefix,
-                model_state_dict=state_dict,
-            ),
+            ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
             ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
```
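Both hunks above rework context managers stacked in a parenthesized `with` (Python 3.10+), which enters managers in order and unwinds them in reverse; that LIFO unwind is what guarantees the LoRA/CLIP-skip/TI patches are reverted even if encoding raises. A toy sketch of the mechanism:

```python
from contextlib import contextmanager


@contextmanager
def patch(name: str):
    print(f"apply {name}")
    try:
        yield name
    finally:
        print(f"revert {name}")


with (
    patch("lora") as lora,
    patch("clip_skip") as clip_skip,
    patch("textual_inversion") as ti,
):
    print("run text encoder with", lora, clip_skip, ti)
# the reverts print in reverse order: textual_inversion, clip_skip, lora
```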

View File

```diff
@@ -1,6 +1,6 @@
 from typing import Literal
 
-from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
 
 LATENT_SCALE_FACTOR = 8
 """
@@ -10,7 +10,8 @@ factor is hard-coded to a literal '8' rather than using this constant.
 The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
 """
 
+SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
+"""A literal type representing the valid scheduler names."""
+
 IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
 """A literal type for PIL image modes supported by Invoke"""
-
-DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()
```
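`SCHEDULER_NAME_VALUES` relies on `Literal` accepting a tuple subscript at runtime, so the legal scheduler names always track `SCHEDULER_MAP`'s keys. A tiny demonstration with a toy map standing in for the real one:

```python
from typing import Literal, get_args

SCHEDULER_MAP = {"ddim": object(), "euler": object(), "lms": object()}  # toy stand-in

# Literal[("a", "b")] is the same type as Literal["a", "b"]
SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]

print(get_args(SCHEDULER_NAME_VALUES))  # ('ddim', 'euler', 'lms')
```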

View File

```diff
@@ -2,7 +2,6 @@
 # initial implementation by Gregg Helt, 2023
 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
 from builtins import bool, float
-from pathlib import Path
 from typing import Dict, List, Literal, Union
 
 import cv2
@@ -22,13 +21,6 @@ from controlnet_aux.util import HWC3, ade_palette
 from PIL import Image
 from pydantic import BaseModel, Field, field_validator, model_validator
 
-from invokeai.app.invocations.baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    Classification,
-    invocation,
-    invocation_output,
-)
 from invokeai.app.invocations.fields import (
     FieldDescriptions,
     ImageField,
@@ -44,13 +36,14 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
 from invokeai.backend.image_util.canny import get_canny_edges
-from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
-from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
+from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
+from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
 from invokeai.backend.image_util.hed import HEDProcessor
 from invokeai.backend.image_util.lineart import LineartProcessor
 from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
 from invokeai.backend.image_util.util import np_to_pil, pil_to_np
-from invokeai.backend.util.devices import TorchDevice
+
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output
 
 
 class ControlField(BaseModel):
@@ -146,7 +139,6 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
         return context.images.get_pil(self.image.image_name, "RGB")
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        self._context = context
         raw_image = self.load_image(context)
         # image type should be PIL.PngImagePlugin.PngImageFile ?
         processed_image = self.run_processor(raw_image)
@@ -292,8 +284,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
     # depth_and_normal not supported in controlnet_aux v0.0.3
     # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
 
-    def run_processor(self, image: Image.Image) -> Image.Image:
-        # TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
+    def run_processor(self, image):
         midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = midas_processor(
             image,
@@ -320,7 +311,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
 
-    def run_processor(self, image: Image.Image) -> Image.Image:
+    def run_processor(self, image):
         normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = normalbae_processor(
             image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
@@ -339,7 +330,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
     thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
     thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
 
-    def run_processor(self, image: Image.Image) -> Image.Image:
+    def run_processor(self, image):
         mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
         processed_image = mlsd_processor(
             image,
@@ -362,7 +353,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
     safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
     scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
 
-    def run_processor(self, image: Image.Image) -> Image.Image:
+    def run_processor(self, image):
         pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = pidi_processor(
             image,
@@ -390,7 +381,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
     w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
     f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
 
-    def run_processor(self, image: Image.Image) -> Image.Image:
+    def run_processor(self, image):
         content_shuffle_processor = ContentShuffleDetector()
         processed_image = content_shuffle_processor(
             image,
@@ -414,7 +405,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
 class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
     """Applies Zoe depth processing to image"""
 
-    def run_processor(self, image: Image.Image) -> Image.Image:
+    def run_processor(self, image):
         zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = zoe_depth_processor(image)
         return processed_image
@@ -435,7 +426,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
 
-    def run_processor(self, image: Image.Image) -> Image.Image:
+    def run_processor(self, image):
         mediapipe_face_processor = MediapipeFaceDetector()
         processed_image = mediapipe_face_processor(
             image,
@@ -463,7 +454,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
 
-    def run_processor(self, image: Image.Image) -> Image.Image:
+    def run_processor(self, image):
         leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = leres_processor(
             image,
@@ -505,8 +496,8 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
         np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
         return np_img
 
-    def run_processor(self, image: Image.Image) -> Image.Image:
-        np_img = np.array(image, dtype=np.uint8)
+    def run_processor(self, img):
+        np_img = np.array(img, dtype=np.uint8)
         processed_np_image = self.tile_resample(
             np_img,
             # res=self.tile_size,
@@ -529,7 +520,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
 
-    def run_processor(self, image: Image.Image) -> Image.Image:
+    def run_processor(self, image):
         # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
         segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
             "ybelkada/segment-anything", subfolder="checkpoints"
@@ -575,7 +566,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
     color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)
 
-    def run_processor(self, image: Image.Image) -> Image.Image:
+    def run_processor(self, image: Image.Image):
         np_image = np.array(image, dtype=np.uint8)
         height, width = np_image.shape[:2]
@@ -610,18 +601,12 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
     )
     resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
 
-    def run_processor(self, image: Image.Image) -> Image.Image:
-        def loader(model_path: Path):
-            return DepthAnythingDetector.load_model(
-                model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
-            )
-
-        with self._context.models.load_remote_model(
-            source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
-        ) as model:
-            depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
-            processed_image = depth_anything_detector(image=image, resolution=self.resolution)
-            return processed_image
+    def run_processor(self, image: Image.Image):
+        depth_anything_detector = DepthAnythingDetector()
+        depth_anything_detector.load_model(model_size=self.model_size)
+
+        processed_image = depth_anything_detector(image=image, resolution=self.resolution)
+        return processed_image
 
 
 @invocation(
@@ -639,11 +624,8 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
     draw_hands: bool = InputField(default=False)
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
 
-    def run_processor(self, image: Image.Image) -> Image.Image:
-        onnx_det = self._context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
-        onnx_pose = self._context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
-
-        dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
+    def run_processor(self, image: Image.Image):
+        dw_openpose = DWOpenposeDetector()
         processed_image = dw_openpose(
             image,
             draw_face=self.draw_face,
```
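Every hunk in this file touches the same extension point: `ImageProcessorInvocation` owns image loading and saving, and subclasses override only `run_processor()`. A toy sketch of that template-method shape (`InvertProcessor` is invented for illustration):

```python
from PIL import Image


class ImageProcessor:
    """Base class: callers hand in a PIL image, subclasses transform it."""

    def run_processor(self, image):
        raise NotImplementedError


class InvertProcessor(ImageProcessor):
    def run_processor(self, image):
        # PIL's point() maps each 8-bit channel value; 255 - v inverts it
        return image.convert("RGB").point(lambda v: 255 - v)


processed = InvertProcessor().run_processor(Image.new("RGB", (64, 64), "white"))
print(processed.getpixel((0, 0)))  # (0, 0, 0)
```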

View File

```diff
@@ -1,80 +0,0 @@
-from typing import Optional
-
-import torch
-import torchvision.transforms as T
-from PIL import Image
-from torchvision.transforms.functional import resize as tv_resize
-
-from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.constants import DEFAULT_PRECISION
-from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField
-from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
-from invokeai.app.invocations.model import VAEField
-from invokeai.app.invocations.primitives import DenoiseMaskOutput
-from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
-
-
-@invocation(
-    "create_denoise_mask",
-    title="Create Denoise Mask",
-    tags=["mask", "denoise"],
-    category="latents",
-    version="1.0.2",
-)
-class CreateDenoiseMaskInvocation(BaseInvocation):
-    """Creates mask for denoising model run."""
-
-    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
-    image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
-    mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
-    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
-    fp32: bool = InputField(
-        default=DEFAULT_PRECISION == torch.float32,
-        description=FieldDescriptions.fp32,
-        ui_order=4,
-    )
-
-    def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
-        if mask_image.mode != "L":
-            mask_image = mask_image.convert("L")
-        mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
-        if mask_tensor.dim() == 3:
-            mask_tensor = mask_tensor.unsqueeze(0)
-        # if shape is not None:
-        #    mask_tensor = tv_resize(mask_tensor, shape, T.InterpolationMode.BILINEAR)
-        return mask_tensor
-
-    @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
-        if self.image is not None:
-            image = context.images.get_pil(self.image.image_name)
-            image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
-            if image_tensor.dim() == 3:
-                image_tensor = image_tensor.unsqueeze(0)
-        else:
-            image_tensor = None
-
-        mask = self.prep_mask_tensor(
-            context.images.get_pil(self.mask.image_name),
-        )
-
-        if image_tensor is not None:
-            vae_info = context.models.load(self.vae.vae)
-
-            img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
-            masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
-            # TODO:
-            masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
-
-            masked_latents_name = context.tensors.save(tensor=masked_latents)
-        else:
-            masked_latents_name = None
-
-        mask_name = context.tensors.save(tensor=mask)
-
-        return DenoiseMaskOutput.build(
-            mask_name=mask_name,
-            masked_latents_name=masked_latents_name,
-            gradient=False,
-        )
```
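The core of the deleted invocation is the mask application: `torch.where()` binarizes the resized soft mask at 0.5 and the multiply zeroes out masked pixels. A small numeric check:

```python
import torch

image_tensor = torch.ones(1, 3, 4, 4)  # stand-in for an RGB image batch
img_mask = torch.tensor([[0.2, 0.8], [0.4, 0.6]]).repeat(2, 2)  # tiled 4x4 soft mask

binary = torch.where(img_mask < 0.5, 0.0, 1.0)  # <0.5 -> masked out (0), else kept (1)
masked_image = image_tensor * binary  # broadcasts over batch and channel dims

print(binary.mean().item())  # 0.5 -> half the pixels survive
print(masked_image[0, 0])  # each row reads 0., 1., 0., 1.
```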

View File

@@ -1,138 +0,0 @@
from typing import Literal, Optional
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image, ImageFilter
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
ImageField,
Input,
InputField,
OutputField,
)
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import UNetField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
@invocation_output("gradient_mask_output")
class GradientMaskOutput(BaseInvocationOutput):
"""Outputs a denoise mask and an image representing the total gradient of the mask."""
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
expanded_mask_area: ImageField = OutputField(
description="Image representing the total gradient area of the mask. For paste-back purposes."
)
@invocation(
"create_gradient_mask",
title="Create Gradient Mask",
tags=["mask", "denoise"],
category="latents",
version="1.1.0",
)
class CreateGradientMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
edge_radius: int = InputField(
default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
)
coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
minimum_denoise: float = InputField(
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
)
image: Optional[ImageField] = InputField(
default=None,
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
title="[OPTIONAL] Image",
ui_order=6,
)
unet: Optional[UNetField] = InputField(
description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE",
default=None,
input=Input.Connection,
title="[OPTIONAL] UNet",
ui_order=5,
)
vae: Optional[VAEField] = InputField(
default=None,
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
title="[OPTIONAL] VAE",
input=Input.Connection,
ui_order=7,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
fp32: bool = InputField(
default=DEFAULT_PRECISION == torch.float32,
description=FieldDescriptions.fp32,
ui_order=9,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
if self.edge_radius > 0:
if self.coherence_mode == "Box Blur":
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
else: # Gaussian Blur OR Staged
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
# redistribute blur so that the original edges are 0 and blur outwards to 1
blur_tensor = (blur_tensor - 0.5) * 2
threshold = 1 - self.minimum_denoise
if self.coherence_mode == "Staged":
# wherever the blur_tensor is less than fully masked, convert it to threshold
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
else:
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
else:
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
# compute a [0, 1] mask from the blur_tensor
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
expanded_image_dto = context.images.save(expanded_mask_image)
masked_latents_name = None
if self.unet is not None and self.vae is not None and self.image is not None:
# all three fields must be present at the same time
main_model_config = context.models.get_config(self.unet.unet.key)
assert isinstance(main_model_config, MainConfigBase)
if main_model_config.variant is ModelVariantType.Inpaint:
mask = blur_tensor
vae_info: LoadedModel = context.models.load(self.vae.vae)
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = image_tensor.unsqueeze(0)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
masked_latents = ImageToLatentsInvocation.vae_encode(
vae_info, self.fp32, self.tiled, masked_image.clone()
)
masked_latents_name = context.tensors.save(tensor=masked_latents)
return GradientMaskOutput(
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
)
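
A minimal sketch (illustrative values, not part of the node) of how the mask redistribution above behaves: the blurred mask is remapped so the original mask edge sits at 0, then the partially-blurred band is capped at `1 - minimum_denoise`.

```python
import torch

blur = torch.tensor([0.5, 0.75, 0.9, 0.95, 1.0])  # blurred mask values in [0, 1]
blur = (blur - 0.5) * 2  # -> [0.0, 0.5, 0.8, 0.9, 1.0]: edge at 0, blur rising to 1

minimum_denoise = 0.2
threshold = 1 - minimum_denoise  # 0.8
# Blur modes: partially-masked values above the threshold are capped at the threshold
print(torch.where((blur > threshold) & (blur < 1), threshold, blur))
# tensor([0.0000, 0.5000, 0.8000, 0.8000, 1.0000])
```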

View File

@@ -1,61 +0,0 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
# The Crop Latents node was copied from @skunkworxdark's implementation here:
# https://github.com/skunkworxdark/XYGrid_nodes/blob/74647fa9c1fa57d317a94bd43ca689af7f0aae5e/images_to_grids.py#L1117C1-L1167C80
@invocation(
"crop_latents",
title="Crop Latents",
tags=["latents", "crop"],
category="latents",
version="1.0.2",
)
# TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`.
# Currently, if the class names conflict then 'GET /openapi.json' fails.
class CropLatentsCoreInvocation(BaseInvocation):
"""Crops a latent-space tensor to a box specified in image-space. The box dimensions and coordinates must be
divisible by the latent scale factor of 8.
"""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
x: int = InputField(
ge=0,
multiple_of=LATENT_SCALE_FACTOR,
description="The left x coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
y: int = InputField(
ge=0,
multiple_of=LATENT_SCALE_FACTOR,
description="The top y coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
width: int = InputField(
ge=1,
multiple_of=LATENT_SCALE_FACTOR,
description="The width (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
height: int = InputField(
ge=1,
multiple_of=LATENT_SCALE_FACTOR,
description="The height (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)
x1 = self.x // LATENT_SCALE_FACTOR
y1 = self.y // LATENT_SCALE_FACTOR
x2 = x1 + (self.width // LATENT_SCALE_FACTOR)
y2 = y1 + (self.height // LATENT_SCALE_FACTOR)
cropped_latents = latents[..., y1:y2, x1:x2]
name = context.tensors.save(tensor=cropped_latents)
return LatentsOutput.build(latents_name=name, latents=cropped_latents)
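
The crop arithmetic is a straight floor-division into latent space; a quick sketch with illustrative numbers:

```python
LATENT_SCALE_FACTOR = 8  # as in invokeai.app.invocations.constants

x, y, width, height = 64, 32, 512, 256  # image-space box, all multiples of 8

x1, y1 = x // LATENT_SCALE_FACTOR, y // LATENT_SCALE_FACTOR  # (8, 4)
x2 = x1 + width // LATENT_SCALE_FACTOR   # 72
y2 = y1 + height // LATENT_SCALE_FACTOR  # 36
# latents[..., y1:y2, x1:x2] is a 32x64 latent tile, i.e. 256x512 pixels
```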

View File

@@ -5,11 +5,13 @@ import cv2 as cv
 import numpy
 from PIL import Image, ImageOps
 
-from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
+from invokeai.app.invocations.fields import ImageField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 
+from .baseinvocation import BaseInvocation, invocation
+from .fields import InputField, WithBoard, WithMetadata
+
 @invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.3.1")
 class CvInpaintInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -1,849 +0,0 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import inspect
from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import torch
import torchvision
import torchvision.transforms as T
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.adapter import T2IAdapter
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
from diffusers.schedulers.scheduling_tcd import TCDScheduler
from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler
from pydantic import field_validator
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPVisionModelWithProjection
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.fields import (
ConditioningField,
DenoiseMaskField,
FieldDescriptions,
Input,
InputField,
LatentsField,
UIType,
)
from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.model import ModelIdentifierField, UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
StableDiffusionGeneratorPipeline,
T2IAdapterData,
)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
IPAdapterConditioningInfo,
IPAdapterData,
Range,
SDXLConditioningInfo,
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
from invokeai.backend.util.mask import to_standard_float_mask
from invokeai.backend.util.silence_warnings import SilenceWarnings
def get_scheduler(
context: InvocationContext,
scheduler_info: ModelIdentifierField,
scheduler_name: str,
seed: int,
) -> Scheduler:
"""Load a scheduler and apply some scheduler-specific overrides."""
# TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
# possible.
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.models.load(scheduler_info)
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {
**scheduler_config,
**scheduler_extra_config, # FIXME
"_backup": scheduler_config,
}
# make dpmpp_sde reproducible (the seed can only be passed in the initializer)
if scheduler_class is DPMSolverSDEScheduler:
scheduler_config["noise_sampler_seed"] = seed
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py
if not hasattr(scheduler, "uses_inpainting_model"):
scheduler.uses_inpainting_model = lambda: False
assert isinstance(scheduler, Scheduler)
return scheduler
@invocation(
"denoise_latents",
title="Denoise Latents",
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents",
version="1.5.3",
)
class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images"""
positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
)
negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
)
noise: Optional[LatentsField] = InputField(
default=None,
description=FieldDescriptions.noise,
input=Input.Connection,
ui_order=3,
)
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
cfg_scale: Union[float, List[float]] = InputField(
default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale"
)
denoising_start: float = InputField(
default=0.0,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
unet: UNetField = InputField(
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
ui_order=2,
)
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
default=None,
input=Input.Connection,
ui_order=5,
)
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
description=FieldDescriptions.ip_adapter,
title="IP-Adapter",
default=None,
input=Input.Connection,
ui_order=6,
)
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]] = InputField(
description=FieldDescriptions.t2i_adapter,
title="T2I-Adapter",
default=None,
input=Input.Connection,
ui_order=7,
)
cfg_rescale_multiplier: float = InputField(
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
)
latents: Optional[LatentsField] = InputField(
default=None,
description=FieldDescriptions.latents,
input=Input.Connection,
ui_order=4,
)
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None,
description=FieldDescriptions.mask,
input=Input.Connection,
ui_order=8,
)
@field_validator("cfg_scale")
def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]:
"""validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError("cfg_scale must be greater than or equal to 1")
else:
if v < 1:
raise ValueError("cfg_scale must be greater than or equal to 1")
return v
@staticmethod
def _get_text_embeddings_and_masks(
cond_list: list[ConditioningField],
context: InvocationContext,
device: torch.device,
dtype: torch.dtype,
) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
"""Get the text embeddings and masks from the input conditioning fields."""
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
text_embeddings_masks: list[Optional[torch.Tensor]] = []
for cond in cond_list:
cond_data = context.conditioning.load(cond.conditioning_name)
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
mask = cond.mask
if mask is not None:
mask = context.tensors.load(mask.tensor_name)
text_embeddings_masks.append(mask)
return text_embeddings, text_embeddings_masks
@staticmethod
def _preprocess_regional_prompt_mask(
mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target height and width.
If mask is None, returns a mask of all ones with the target height and width.
If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
Returns:
torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width).
"""
if mask is None:
return torch.ones((1, 1, target_height, target_width), dtype=dtype)
mask = to_standard_float_mask(mask, out_dtype=dtype)
tf = torchvision.transforms.Resize(
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
resized_mask = tf(mask)
return resized_mask
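# Illustrative note (not in the original file): for a 1024x1024 SD image the latent
# grid is 128x128, so every regional prompt mask is nearest-resized to
# (1, 1, 128, 128), and a missing mask becomes torch.ones((1, 1, 128, 128)).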
@staticmethod
def _concat_regional_text_embeddings(
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
masks: Optional[list[Optional[torch.Tensor]]],
latent_height: int,
latent_width: int,
dtype: torch.dtype,
) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
"""Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
if masks is None:
masks = [None] * len(text_conditionings)
assert len(text_conditionings) == len(masks)
is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
all_masks_are_none = all(mask is None for mask in masks)
text_embedding = []
pooled_embedding = None
add_time_ids = None
cur_text_embedding_len = 0
processed_masks = []
embedding_ranges = []
for prompt_idx, text_embedding_info in enumerate(text_conditionings):
mask = masks[prompt_idx]
if is_sdxl:
# We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for
# prompts without a mask. We prefer prompts without a mask, because they are more likely to contain
# global prompt information. In an ideal case, there should be exactly one global prompt without a
# mask, but we don't enforce this.
# HACK(ryand): The fact that we have to choose a single pooled_embedding and add_time_ids here is a
# fundamental interface issue. The SDXL Compel nodes are not designed to be used in the way that we use
# them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
# pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
# pretty major breaking change to a popular node, so for now we use this hack.
if pooled_embedding is None or mask is None:
pooled_embedding = text_embedding_info.pooled_embeds
if add_time_ids is None or mask is None:
add_time_ids = text_embedding_info.add_time_ids
text_embedding.append(text_embedding_info.embeds)
if not all_masks_are_none:
embedding_ranges.append(
Range(
start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
)
)
processed_masks.append(
DenoiseLatentsInvocation._preprocess_regional_prompt_mask(
mask, latent_height, latent_width, dtype=dtype
)
)
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
text_embedding = torch.cat(text_embedding, dim=1)
assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len
regions = None
if not all_masks_are_none:
regions = TextConditioningRegions(
masks=torch.cat(processed_masks, dim=1),
ranges=embedding_ranges,
)
if is_sdxl:
return (
SDXLConditioningInfo(embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids),
regions,
)
return BasicConditioningInfo(embeds=text_embedding), regions
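# Worked example (illustrative): two SD1.5 prompts with embeds of shape (1, 77, 768)
# concatenate to a single (1, 154, 768) embedding, with
# embedding_ranges == [Range(start=0, end=77), Range(start=77, end=154)]
# recording which slice of the joint embedding belongs to each region mask.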
@staticmethod
def get_conditioning_data(
context: InvocationContext,
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
unet: UNet2DConditionModel,
latent_height: int,
latent_width: int,
cfg_scale: float | list[float],
steps: int,
cfg_rescale_multiplier: float,
) -> TextConditioningData:
# Normalize positive_conditioning_field and negative_conditioning_field to lists.
cond_list = positive_conditioning_field
if not isinstance(cond_list, list):
cond_list = [cond_list]
uncond_list = negative_conditioning_field
if not isinstance(uncond_list, list):
uncond_list = [uncond_list]
cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
cond_list, context, unet.device, unet.dtype
)
uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
uncond_list, context, unet.device, unet.dtype
)
cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=cond_text_embeddings,
masks=cond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
)
uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
)
if isinstance(cfg_scale, list):
assert len(cfg_scale) == steps, "cfg_scale (list) must have the same length as the number of steps"
conditioning_data = TextConditioningData(
uncond_text=uncond_text_embedding,
cond_text=cond_text_embedding,
uncond_regions=uncond_regions,
cond_regions=cond_regions,
guidance_scale=cfg_scale,
guidance_rescale_multiplier=cfg_rescale_multiplier,
)
return conditioning_data
@staticmethod
def create_pipeline(
unet: UNet2DConditionModel,
scheduler: Scheduler,
) -> StableDiffusionGeneratorPipeline:
class FakeVae:
class FakeVaeConfig:
def __init__(self) -> None:
self.block_out_channels = [0]
def __init__(self) -> None:
self.config = FakeVae.FakeVaeConfig()
return StableDiffusionGeneratorPipeline(
vae=FakeVae(), # TODO: oh...
text_encoder=None,
tokenizer=None,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
@staticmethod
def prep_control_data(
context: InvocationContext,
control_input: ControlField | list[ControlField] | None,
latents_shape: List[int],
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True,
) -> list[ControlNetData] | None:
# Normalize control_input to a list.
control_list: list[ControlField]
if isinstance(control_input, ControlField):
control_list = [control_input]
elif isinstance(control_input, list):
control_list = control_input
elif control_input is None:
control_list = []
else:
raise ValueError(f"Unexpected control_input type: {type(control_input)}")
if len(control_list) == 0:
return None
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
_, _, latent_height, latent_width = latents_shape
control_height_resize = latent_height * LATENT_SCALE_FACTOR
control_width_resize = latent_width * LATENT_SCALE_FACTOR
controlnet_data: list[ControlNetData] = []
for control_info in control_list:
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
assert isinstance(control_model, ControlNetModel)
control_image_field = control_info.image
input_image = context.images.get_pil(control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance?
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
control_image = prepare_control_image(
image=input_image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=control_width_resize,
height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=control_model.device,
dtype=control_model.dtype,
control_mode=control_info.control_mode,
resize_mode=control_info.resize_mode,
)
control_item = ControlNetData(
model=control_model,
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
# any resizing needed should currently be happening in prepare_control_image(),
# but adding resize_mode to ControlNetData in case needed in the future
resize_mode=control_info.resize_mode,
)
controlnet_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return controlnet_data
def prep_ip_adapter_image_prompts(
self,
context: InvocationContext,
ip_adapters: List[IPAdapterField],
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
"""Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
image_prompts = []
for single_ip_adapter in ip_adapters:
with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
assert isinstance(ip_adapter_model, IPAdapter)
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list):
single_ipa_image_fields = [single_ipa_image_fields]
single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
with image_encoder_model_info as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
single_ipa_images, image_encoder_model
)
image_prompts.append((image_prompt_embeds, uncond_image_prompt_embeds))
return image_prompts
def prep_ip_adapter_data(
self,
context: InvocationContext,
ip_adapters: List[IPAdapterField],
image_prompts: List[Tuple[torch.Tensor, torch.Tensor]],
exit_stack: ExitStack,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
) -> Optional[List[IPAdapterData]]:
"""If IP-Adapter is enabled, then this function loads the requisite models and adds the image prompt conditioning data."""
ip_adapter_data_list = []
for single_ip_adapter, (image_prompt_embeds, uncond_image_prompt_embeds) in zip(
ip_adapters, image_prompts, strict=True
):
ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model))
mask_field = single_ip_adapter.mask
mask = context.tensors.load(mask_field.tensor_name) if mask_field is not None else None
mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
ip_adapter_data_list.append(
IPAdapterData(
ip_adapter_model=ip_adapter_model,
weight=single_ip_adapter.weight,
target_blocks=single_ip_adapter.target_blocks,
begin_step_percent=single_ip_adapter.begin_step_percent,
end_step_percent=single_ip_adapter.end_step_percent,
ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
mask=mask,
)
)
return ip_adapter_data_list if len(ip_adapter_data_list) > 0 else None
def run_t2i_adapters(
self,
context: InvocationContext,
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
latents_shape: list[int],
do_classifier_free_guidance: bool,
) -> Optional[list[T2IAdapterData]]:
if t2i_adapter is None:
return None
# Handle the possibility that t2i_adapter could be a list or a single T2IAdapterField.
if isinstance(t2i_adapter, T2IAdapterField):
t2i_adapter = [t2i_adapter]
if len(t2i_adapter) == 0:
return None
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
image = context.images.get_pil(t2i_adapter_field.image.image_name)
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
if t2i_adapter_model_config.base == BaseModelType.StableDiffusion1:
max_unet_downscale = 8
elif t2i_adapter_model_config.base == BaseModelType.StableDiffusionXL:
max_unet_downscale = 4
else:
raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_config.base}'.")
t2i_adapter_model: T2IAdapter
with t2i_adapter_loaded_model as t2i_adapter_model:
total_downscale_factor = t2i_adapter_model.total_downscale_factor
# Resize the T2I-Adapter input image.
# We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
# result will match the latent image's dimensions after max_unet_downscale is applied.
t2i_input_height = latents_shape[2] // max_unet_downscale * total_downscale_factor
t2i_input_width = latents_shape[3] // max_unet_downscale * total_downscale_factor
# Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
# a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
# T2I-Adapter model.
#
# Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
# of the same requirements (e.g. preserving binary masks during resize).
t2i_image = prepare_control_image(
image=image,
do_classifier_free_guidance=False,
width=t2i_input_width,
height=t2i_input_height,
num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict
device=t2i_adapter_model.device,
dtype=t2i_adapter_model.dtype,
resize_mode=t2i_adapter_field.resize_mode,
)
adapter_state = t2i_adapter_model(t2i_image)
if do_classifier_free_guidance:
for idx, value in enumerate(adapter_state):
adapter_state[idx] = torch.cat([value] * 2, dim=0)
t2i_adapter_data.append(
T2IAdapterData(
adapter_state=adapter_state,
weight=t2i_adapter_field.weight,
begin_step_percent=t2i_adapter_field.begin_step_percent,
end_step_percent=t2i_adapter_field.end_step_percent,
)
)
return t2i_adapter_data
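# Worked example (illustrative, assuming a hypothetical total_downscale_factor of 16):
# an SDXL latent of shape (1, 4, 128, 128) with max_unet_downscale == 4 gives
# t2i_input_height = 128 // 4 * 16 = 512, so after the adapter downscales its input
# by 16x, its deepest feature map matches the UNet's most-downscaled latent (32x32).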
# original idea by https://github.com/AmericanPresidentJimmyCarter
# TODO: research more for second order schedulers timesteps
@staticmethod
def init_scheduler(
scheduler: Union[Scheduler, ConfigMixin],
device: torch.device,
steps: int,
denoising_start: float,
denoising_end: float,
seed: int,
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
assert isinstance(scheduler, ConfigMixin)
if scheduler.config.get("cpu_only", False):
scheduler.set_timesteps(steps, device="cpu")
timesteps = scheduler.timesteps.to(device=device)
else:
scheduler.set_timesteps(steps, device=device)
timesteps = scheduler.timesteps
# skip greater order timesteps
_timesteps = timesteps[:: scheduler.order]
# get start timestep index
t_start_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_start)))
t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps)))
# get end timestep index
t_end_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_end)))
t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:])))
# apply order to indexes
t_start_idx *= scheduler.order
t_end_idx *= scheduler.order
init_timestep = timesteps[t_start_idx : t_start_idx + 1]
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
scheduler_step_kwargs: Dict[str, Any] = {}
scheduler_step_signature = inspect.signature(scheduler.step)
if "generator" in scheduler_step_signature.parameters:
# At some point, someone decided that schedulers that accept a generator should use the original seed with
# all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
# reproducibility.
#
# These Invoke-supported schedulers accept a generator as of 2024-06-04:
# - DDIMScheduler
# - DDPMScheduler
# - DPMSolverMultistepScheduler
# - EulerAncestralDiscreteScheduler
# - EulerDiscreteScheduler
# - KDPM2AncestralDiscreteScheduler
# - LCMScheduler
# - TCDScheduler
scheduler_step_kwargs.update({"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)})
if isinstance(scheduler, TCDScheduler):
scheduler_step_kwargs.update({"eta": 1.0})
return timesteps, init_timestep, scheduler_step_kwargs
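# Worked example (illustrative): for a first-order scheduler with steps=10,
# num_train_timesteps=1000, denoising_start=0.0 and denoising_end=0.5:
#   t_start_val = round(1000 * (1 - 0.0)) = 1000 -> t_start_idx = 0 (keep from the top)
#   t_end_val   = round(1000 * (1 - 0.5)) = 500  -> t_end_idx = 5 (five timesteps >= 500)
# so only the first half of the 10 timesteps is run, stopping denoising midway.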
def prep_inpaint_mask(
self, context: InvocationContext, latents: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
if self.denoise_mask is None:
return None, None, False
mask = context.tensors.load(self.denoise_mask.mask_name)
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
if self.denoise_mask.masked_latents_name is not None:
masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name)
else:
masked_latents = torch.where(mask < 0.5, 0.0, latents)
return 1 - mask, masked_latents, self.denoise_mask.gradient
@staticmethod
def prepare_noise_and_latents(
context: InvocationContext, noise_field: LatentsField | None, latents_field: LatentsField | None
) -> Tuple[int, torch.Tensor | None, torch.Tensor]:
"""Depending on the workflow, we expect different combinations of noise and latents to be provided. This
function handles preparing these values accordingly.
Expected workflows:
- Text-to-Image Denoising: `noise` is provided, `latents` is not. `latents` is initialized to zeros.
- Image-to-Image Denoising: `noise` and `latents` are both provided.
- Text-to-Image SDXL Refiner Denoising: `latents` is provided, `noise` is not.
- Image-to-Image SDXL Refiner Denoising: `latents` is provided, `noise` is not.
NOTE(ryand): I wrote this docstring, but I am not the original author of this code. There may be other workflows
I haven't considered.
"""
noise = None
if noise_field is not None:
noise = context.tensors.load(noise_field.latents_name)
if latents_field is not None:
latents = context.tensors.load(latents_field.latents_name)
elif noise is not None:
latents = torch.zeros_like(noise)
else:
raise ValueError("'latents' or 'noise' must be provided!")
if noise is not None and noise.shape[1:] != latents.shape[1:]:
raise ValueError(f"Incompatible 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
# The seed comes from (in order of priority): the noise field, the latents field, or 0.
seed = 0
if noise_field is not None and noise_field.seed is not None:
seed = noise_field.seed
elif latents_field is not None and latents_field.seed is not None:
seed = latents_field.seed
else:
seed = 0
return seed, noise, latents
@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def invoke(self, context: InvocationContext) -> LatentsOutput:
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
# below. Investigate whether this is appropriate.
t2i_adapter_data = self.run_t2i_adapters(
context,
self.t2i_adapter,
latents.shape,
do_classifier_free_guidance=True,
)
ip_adapters: List[IPAdapterField] = []
if self.ip_adapter is not None:
# ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
if isinstance(self.ip_adapter, list):
ip_adapters = self.ip_adapter
else:
ip_adapters = [self.ip_adapter]
# If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
# a series of image conditioning embeddings. This is being done here rather than in the
# big model context below in order to use less VRAM on low-VRAM systems.
# The image prompts are then passed to prep_ip_adapter_data().
image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
return
unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,
unet_info.model_on_device() as (model_state_dict, unet),
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
set_seamless(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(
unet,
loras=_lora_loader(),
model_state_dict=model_state_dict,
),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype)
if mask is not None:
mask = mask.to(device=unet.device, dtype=unet.dtype)
if masked_latents is not None:
masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
)
pipeline = self.create_pipeline(unet, scheduler)
_, _, latent_height, latent_width = latents.shape
conditioning_data = self.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)
controlnet_data = self.prep_control_data(
context=context,
control_input=self.control,
latents_shape=latents.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
ip_adapter_data = self.prep_ip_adapter_data(
context=context,
ip_adapters=ip_adapters,
image_prompts=image_prompts,
exit_stack=exit_stack,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
)
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
seed=seed,
)
result_latents = pipeline.latents_from_embeddings(
latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise,
seed=seed,
mask=mask,
masked_latents=masked_latents,
is_gradient_mask=gradient_mask,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
control_data=controlnet_data,
ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
callback=step_callback,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)

View File

@@ -160,7 +160,6 @@ class FieldDescriptions:
     fp32 = "Whether or not to use full float32 precision"
     precision = "Precision to use"
     tiled = "Processing using overlapping tiles (reduce memory consumption)"
-    vae_tile_size = "The tile size for VAE tiling in pixels (image space). If set to 0, the default tile size for the model will be used. Larger tile sizes generally produce better results at the cost of higher memory usage."
     detect_res = "Pixel resolution for detection"
     image_res = "Pixel resolution for output image"
     safe_mode = "Whether or not to use safe mode"

View File

@@ -1,65 +0,0 @@
import math
from typing import Tuple
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField
from invokeai.app.invocations.model import UNetField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType
@invocation_output("ideal_size_output")
class IdealSizeOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
width: int = OutputField(description="The ideal width of the image (in pixels)")
height: int = OutputField(description="The ideal height of the image (in pixels)")
@invocation(
"ideal_size",
title="Ideal Size",
tags=["latents", "math", "ideal_size"],
version="1.0.3",
)
class IdealSizeInvocation(BaseInvocation):
"""Calculates the ideal size for generation to avoid duplication"""
width: int = InputField(default=1024, description="Final image width")
height: int = InputField(default=576, description="Final image height")
unet: UNetField = InputField(default=None, description=FieldDescriptions.unet)
multiplier: float = InputField(
default=1.0,
description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in "
"initial generation artifacts if too large)",
)
def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR) -> Tuple[int, ...]:
return tuple((x - x % multiple_of) for x in args)
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
unet_config = context.models.get_config(self.unet.unet.key)
aspect = self.width / self.height
dimension: float = 512
if unet_config.base == BaseModelType.StableDiffusion2:
dimension = 768
elif unet_config.base == BaseModelType.StableDiffusionXL:
dimension = 1024
dimension = dimension * self.multiplier
min_dimension = math.floor(dimension * 0.5)
model_area = dimension * dimension # hardcoded for now since all models are trained on square images
if aspect > 1.0:
init_height = max(min_dimension, math.sqrt(model_area / aspect))
init_width = init_height * aspect
else:
init_width = max(min_dimension, math.sqrt(model_area * aspect))
init_height = init_width / aspect
scaled_width, scaled_height = self.trim_to_multiple_of(
math.floor(init_width),
math.floor(init_height),
)
return IdealSizeOutput(width=scaled_width, height=scaled_height)
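
A worked pass through the sizing math above (illustrative numbers: SDXL base dimension, 1920x1080 target, multiplier 1.0):

```python
import math

width, height = 1920, 1080
aspect = width / height           # ~1.778, so the aspect > 1.0 branch applies
dimension = 1024 * 1.0            # SDXL base dimension * multiplier
model_area = dimension * dimension

init_height = max(math.floor(dimension * 0.5), math.sqrt(model_area / aspect))  # 768.0
init_width = init_height * aspect                                               # ~1365.3

trim = lambda v: v - v % 8  # trim_to_multiple_of with LATENT_SCALE_FACTOR == 8
print(trim(math.floor(init_width)), trim(math.floor(init_height)))  # 1360 768
```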

View File

@@ -6,7 +6,6 @@ import cv2
 import numpy
 from PIL import Image, ImageChops, ImageFilter, ImageOps
 
-from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
 from invokeai.app.invocations.constants import IMAGE_MODES
 from invokeai.app.invocations.fields import (
     ColorField,
@@ -22,6 +21,8 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
 from invokeai.backend.image_util.safety_checker import SafetyChecker
 
+from .baseinvocation import BaseInvocation, Classification, invocation
+
 @invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.1")
 class ShowImageInvocation(BaseInvocation):

View File

@@ -1,143 +0,0 @@
from contextlib import nullcontext
from functools import singledispatchmethod
import einops
import torch
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
@invocation(
"i2l",
title="Image to Latents",
tags=["latents", "image", "vae", "i2l"],
category="latents",
version="1.1.0",
)
class ImageToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents."""
image: ImageField = InputField(
description="The image to encode",
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
# offer a way to directly set None values.
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@staticmethod
def vae_encode(
vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
) -> torch.Tensor:
with vae_info as vae:
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
orig_dtype = vae.dtype
if upcast:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
vae.post_quant_conv.to(orig_dtype)
vae.decoder.conv_in.to(orig_dtype)
vae.decoder.mid_block.to(orig_dtype)
# else:
# latents = latents.float()
else:
vae.to(dtype=torch.float16)
# latents = latents.half()
if tiled:
vae.enable_tiling()
else:
vae.disable_tiling()
tiling_context = nullcontext()
if tile_size > 0:
tiling_context = patch_vae_tiling_params(
vae,
tile_sample_min_size=tile_size,
tile_latent_min_size=tile_size // LATENT_SCALE_FACTOR,
tile_overlap_factor=0.25,
)
# non_noised_latents_from_image
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode(), tiling_context:
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
latents = vae.config.scaling_factor * latents
latents = latents.to(dtype=orig_dtype)
return latents
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(self.vae.vae)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
latents = self.vae_encode(
vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
)
latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
@singledispatchmethod
@staticmethod
def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
assert isinstance(vae, torch.nn.Module)
image_tensor_dist = vae.encode(image_tensor).latent_dist
latents: torch.Tensor = image_tensor_dist.sample().to(
dtype=vae.dtype
) # FIXME: uses torch.randn. make reproducible!
return latents
@_encode_to_tensor.register
@staticmethod
def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
assert isinstance(vae, torch.nn.Module)
latents: torch.FloatTensor = vae.encode(image_tensor).latents
return latents
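
The `@singledispatchmethod` pair above picks an encoder based on the runtime type of the `vae` argument. A standalone sketch of the same dispatch idea using plain `functools.singledispatch` (the class names here are hypothetical stand-ins, not InvokeAI types):

```python
from functools import singledispatch


class KLVae: ...     # stand-in for AutoencoderKL
class TinyVae: ...   # stand-in for AutoencoderTiny


@singledispatch
def encode_to_tensor(vae, image_tensor):
    raise NotImplementedError(f"no encoder registered for {type(vae).__name__}")


@encode_to_tensor.register
def _(vae: KLVae, image_tensor):
    return "sample from vae.encode(...).latent_dist"  # KL VAEs return a distribution


@encode_to_tensor.register
def _(vae: TinyVae, image_tensor):
    return "vae.encode(...).latents"  # tiny VAEs return latents directly


print(encode_to_tensor(KLVae(), None))    # sample from vae.encode(...).latent_dist
print(encode_to_tensor(TinyVae(), None))  # vae.encode(...).latents
```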

View File

@@ -3,9 +3,7 @@ from typing import Literal, get_args
 from PIL import Image
 
-from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import ColorField, ImageField, InputField, WithBoard, WithMetadata
-from invokeai.app.invocations.image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
+from invokeai.app.invocations.fields import ColorField, ImageField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.misc import SEED_MAX
@@ -16,6 +14,10 @@ from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch, in
 from invokeai.backend.image_util.infill_methods.tile import infill_tile
 from invokeai.backend.util.logging import InvokeAILogger
 
+from .baseinvocation import BaseInvocation, invocation
+from .fields import InputField, WithBoard, WithMetadata
+from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
+
 logger = InvokeAILogger.get_logger()
@@ -40,16 +42,15 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
     """Infill the image with the specified method"""
     pass
 
-    def load_image(self) -> tuple[Image.Image, bool]:
+    def load_image(self, context: InvocationContext) -> tuple[Image.Image, bool]:
         """Process the image to have an alpha channel before being infilled"""
-        image = self._context.images.get_pil(self.image.image_name)
+        image = context.images.get_pil(self.image.image_name)
         has_alpha = True if image.mode == "RGBA" else False
         return image, has_alpha
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        self._context = context
         # Retrieve and process image to be infilled
-        input_image, has_alpha = self.load_image()
+        input_image, has_alpha = self.load_image(context)
 
         # If the input image has no alpha channel, return it
         if has_alpha is False:
@@ -132,12 +133,8 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
     """Infills transparent areas of an image using the LaMa model"""
 
     def infill(self, image: Image.Image):
-        with self._context.models.load_remote_model(
-            source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
-            loader=LaMA.load_jit_model,
-        ) as model:
-            lama = LaMA(model)
-            return lama(image)
+        lama = LaMA()
+        return lama(image)
 
 @invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")

File diff suppressed because it is too large

View File

@@ -1,121 +0,0 @@
from contextlib import nullcontext
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion import set_seamless
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice
@invocation(
"l2i",
title="Latents to Image",
tags=["latents", "image", "vae", "l2i"],
category="latents",
version="1.3.0",
)
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
# offer a way to directly set None values.
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device)
if self.fp32:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
vae.post_quant_conv.to(latents.dtype)
vae.decoder.conv_in.to(latents.dtype)
vae.decoder.mid_block.to(latents.dtype)
else:
latents = latents.float()
else:
vae.to(dtype=torch.float16)
latents = latents.half()
if self.tiled or context.config.get().force_tiled_decode:
vae.enable_tiling()
else:
vae.disable_tiling()
tiling_context = nullcontext()
if self.tile_size > 0:
tiling_context = patch_vae_tiling_params(
vae,
tile_sample_min_size=self.tile_size,
tile_latent_min_size=self.tile_size // LATENT_SCALE_FACTOR,
tile_overlap_factor=0.25,
)
# clear memory as vae decode can request a lot
TorchDevice.empty_cache()
with torch.inference_mode(), tiling_context:
# copied from diffusers pipeline
latents = latents / vae.config.scaling_factor
image = vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1) # denormalize
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
image = VaeImageProcessor.numpy_to_pil(np_image)[0]
TorchDevice.empty_cache()
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)
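
The decode-side denormalization above is an affine remap from the VAE's roughly [-1, 1] output range into [0, 1]; a tiny sketch (illustrative tensor, no VAE involved):

```python
import torch

decoded = torch.tensor([[-1.0, 0.0, 1.0]])  # VAE decoder output is roughly in [-1, 1]
image = (decoded / 2 + 0.5).clamp(0, 1)     # denormalize -> [0.0, 0.5, 1.0]
np_image = (image.float().numpy() * 255).round().astype("uint8")  # -> [0, 128, 255]
```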

View File

@@ -5,11 +5,12 @@ from typing import Literal
 import numpy as np
 from pydantic import ValidationInfo, field_validator
 
-from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.fields import FieldDescriptions, InputField
 from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 
+from .baseinvocation import BaseInvocation, invocation
+
 @invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.1")
 class AddInvocation(BaseInvocation):

View File

@@ -14,7 +14,8 @@ from invokeai.app.invocations.fields import (
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES
-from invokeai.version.invokeai_version import __version__
+
+from ...version import __version__
 
 class MetadataItemField(BaseModel):

View File

@@ -3,17 +3,18 @@ from typing import List, Optional
 from pydantic import BaseModel, Field
 
-from invokeai.app.invocations.baseinvocation import (
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.app.shared.models import FreeUConfig
+from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
+
+from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
     Classification,
     invocation,
     invocation_output,
 )
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
-from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.app.shared.models import FreeUConfig
-from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
 
 class ModelIdentifierField(BaseModel):

View File

@@ -4,12 +4,18 @@
 import torch
 from pydantic import field_validator
 
-from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
 from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.misc import SEED_MAX
-from invokeai.backend.util.devices import TorchDevice
+
+from ...backend.util.devices import TorchDevice
+from .baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    invocation,
+    invocation_output,
+)
 
 """
 Utilities

View File

@@ -39,11 +39,12 @@ from easing_functions import (
 )
 from matplotlib.ticker import MaxNLocator
 
-from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import InputField
 from invokeai.app.invocations.primitives import FloatCollectionOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 
+from .baseinvocation import BaseInvocation, invocation
+from .fields import InputField
+
 @invocation(
     "float_range",
View File

@@ -4,7 +4,6 @@ from typing import Optional
 import torch
 
-from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
 from invokeai.app.invocations.fields import (
     ColorField,
@@ -22,6 +21,13 @@ from invokeai.app.invocations.fields import (
 from invokeai.app.services.images.images_common import ImageDTO
 from invokeai.app.services.shared.invocation_context import InvocationContext
 
+from .baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    invocation,
+    invocation_output,
+)
+
 """
 Primitives: Boolean, Integer, Float, String, Image, Latents, Conditioning, Color
 - primitive nodes

View File

@@ -5,11 +5,12 @@ import numpy as np
 from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
 from pydantic import field_validator
 
-from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import InputField, UIComponent
 from invokeai.app.invocations.primitives import StringCollectionOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 
+from .baseinvocation import BaseInvocation, invocation
+from .fields import InputField, UIComponent
+
 @invocation(
     "dynamic_prompt",

View File

@@ -1,103 +0,0 @@
from typing import Literal
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
)
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.util.devices import TorchDevice
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
@invocation(
"lresize",
title="Resize Latents",
tags=["latents", "resize"],
category="latents",
version="1.0.2",
)
class ResizeLatentsInvocation(BaseInvocation):
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
width: int = InputField(
ge=64,
multiple_of=LATENT_SCALE_FACTOR,
description=FieldDescriptions.width,
)
height: int = InputField(
ge=64,
multiple_of=LATENT_SCALE_FACTOR,
description=FieldDescriptions.height,
)
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)
device = TorchDevice.choose_torch_device()
resized_latents = torch.nn.functional.interpolate(
latents.to(device),
size=(self.height // LATENT_SCALE_FACTOR, self.width // LATENT_SCALE_FACTOR),
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@invocation(
"lscale",
title="Scale Latents",
tags=["latents", "resize"],
category="latents",
version="1.0.2",
)
class ScaleLatentsInvocation(BaseInvocation):
"""Scales latents by a given factor."""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor)
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)
device = TorchDevice.choose_torch_device()
# resizing
resized_latents = torch.nn.functional.interpolate(
latents.to(device),
scale_factor=self.scale_factor,
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
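
Both nodes above do their arithmetic in latent space: every pixel dimension is floor-divided by `LATENT_SCALE_FACTOR` before interpolation, and `antialias` is only honored for the `bilinear`/`bicubic` modes. A minimal, self-contained sketch of that mapping (assuming `LATENT_SCALE_FACTOR = 8`, the Stable Diffusion VAE stride; the helper name is illustrative, not InvokeAI API):

```python
import torch

LATENT_SCALE_FACTOR = 8  # assumed value, mirroring invokeai.app.invocations.constants

def resize_latents(latents: torch.Tensor, width: int, height: int, mode: str = "bilinear") -> torch.Tensor:
    """Resize a (B, C, H, W) latent tensor to the latent equivalent of width x height pixels."""
    return torch.nn.functional.interpolate(
        latents,
        size=(height // LATENT_SCALE_FACTOR, width // LATENT_SCALE_FACTOR),
        mode=mode,
        # antialias is only valid for bilinear/bicubic, as in the node above
        antialias=mode in ("bilinear", "bicubic"),
    )

latents = torch.randn(1, 4, 64, 64)              # a 512x512 image in latent space
print(resize_latents(latents, 1024, 768).shape)  # torch.Size([1, 4, 96, 128])
```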

View File

@@ -1,34 +0,0 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import (
FieldDescriptions,
InputField,
OutputField,
UIType,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@invocation_output("scheduler_output")
class SchedulerOutput(BaseInvocationOutput):
scheduler: SCHEDULER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
@invocation(
"scheduler",
title="Scheduler",
tags=["scheduler"],
category="latents",
version="1.0.0",
)
class SchedulerInvocation(BaseInvocation):
"""Selects a scheduler."""
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
def invoke(self, context: InvocationContext) -> SchedulerOutput:
return SchedulerOutput(scheduler=self.scheduler)

View File

@@ -1,9 +1,15 @@
-from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType
-from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, UNetField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import SubModelType
+from .baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    invocation,
+    invocation_output,
+)
+from .model import CLIPField, ModelIdentifierField, UNetField, VAEField
@invocation_output("sdxl_model_loader_output")
class SDXLModelLoaderOutput(BaseInvocationOutput):

View File

@@ -2,11 +2,17 @@
import re
-from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
-from invokeai.app.invocations.fields import InputField, OutputField, UIComponent
-from invokeai.app.invocations.primitives import StringOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
+from .baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    invocation,
+    invocation_output,
+)
+from .fields import InputField, OutputField, UIComponent
+from .primitives import StringOutput
@invocation_output("string_pos_neg_output")
class StringPosNegOutput(BaseInvocationOutput):

View File

@@ -1,282 +0,0 @@
import copy
from contextlib import ExitStack
from typing import Iterator, Tuple
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
from invokeai.app.invocations.fields import (
ConditioningField,
FieldDescriptions,
Input,
InputField,
LatentsField,
UIType,
)
from invokeai.app.invocations.model import UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline,
MultiDiffusionRegionConditioning,
)
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.tiles.tiles import (
calc_tiles_min_overlap,
)
from invokeai.backend.tiles.utils import TBLR
from invokeai.backend.util.devices import TorchDevice
def crop_controlnet_data(control_data: ControlNetData, latent_region: TBLR) -> ControlNetData:
"""Crop a ControlNetData object to a region."""
# Create a shallow copy of the control_data object.
control_data_copy = copy.copy(control_data)
# The ControlNet reference image is the only attribute that needs to be cropped.
control_data_copy.image_tensor = control_data.image_tensor[
:,
:,
latent_region.top * LATENT_SCALE_FACTOR : latent_region.bottom * LATENT_SCALE_FACTOR,
latent_region.left * LATENT_SCALE_FACTOR : latent_region.right * LATENT_SCALE_FACTOR,
]
return control_data_copy
@invocation(
"tiled_multi_diffusion_denoise_latents",
title="Tiled Multi-Diffusion Denoise Latents",
tags=["upscale", "denoise"],
category="latents",
classification=Classification.Beta,
version="1.0.0",
)
class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
"""Tiled Multi-Diffusion denoising.
This node handles automatically tiling the input image, and is primarily intended for global refinement of images
in tiled upscaling workflows. Future Multi-Diffusion nodes should allow the user to specify custom regions with
different parameters for each region to harness the full power of Multi-Diffusion.
This node has a similar interface to the `DenoiseLatents` node, but it has a reduced feature set (no IP-Adapter,
T2I-Adapter, masking, etc.).
"""
positive_conditioning: ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
noise: LatentsField | None = InputField(
default=None,
description=FieldDescriptions.noise,
input=Input.Connection,
)
latents: LatentsField | None = InputField(
default=None,
description=FieldDescriptions.latents,
input=Input.Connection,
)
tile_height: int = InputField(
default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Height of the tiles in image space."
)
tile_width: int = InputField(
default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Width of the tiles in image space."
)
tile_overlap: int = InputField(
default=32,
multiple_of=LATENT_SCALE_FACTOR,
gt=0,
description="The overlap between adjacent tiles in pixel space. (Of course, tile merging is applied in latent "
"space.) Tiles will be cropped during merging (if necessary) to ensure that they overlap by exactly this "
"amount.",
)
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
denoising_start: float = InputField(
default=0.0,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
unet: UNetField = InputField(
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
cfg_rescale_multiplier: float = InputField(
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
)
control: ControlField | list[ControlField] | None = InputField(
default=None,
input=Input.Connection,
)
@field_validator("cfg_scale")
def ge_one(cls, v: list[float] | float) -> list[float] | float:
"""Validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError("cfg_scale must be greater than 1")
else:
if v < 1:
raise ValueError("cfg_scale must be greater than 1")
return v
@staticmethod
def create_pipeline(
unet: UNet2DConditionModel,
scheduler: SchedulerMixin,
) -> MultiDiffusionPipeline:
# TODO(ryand): Get rid of this FakeVae hack.
class FakeVae:
class FakeVaeConfig:
def __init__(self) -> None:
self.block_out_channels = [0]
def __init__(self) -> None:
self.config = FakeVae.FakeVaeConfig()
return MultiDiffusionPipeline(
vae=FakeVae(),
text_encoder=None,
tokenizer=None,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
# Convert tile image-space dimensions to latent-space dimensions.
latent_tile_height = self.tile_height // LATENT_SCALE_FACTOR
latent_tile_width = self.tile_width // LATENT_SCALE_FACTOR
latent_tile_overlap = self.tile_overlap // LATENT_SCALE_FACTOR
seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
_, _, latent_height, latent_width = latents.shape
# Calculate the tile locations to cover the latent-space image.
tiles = calc_tiles_min_overlap(
image_height=latent_height,
image_width=latent_width,
tile_height=latent_tile_height,
tile_width=latent_tile_width,
min_overlap=latent_tile_overlap,
)
# Get the unet's config so that we can pass the base to sd_step_callback().
unet_config = context.models.get_config(self.unet.unet.key)
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)
# Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
# Load the UNet model.
unet_info = context.models.load(self.unet.unet)
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
)
pipeline = self.create_pipeline(unet=unet, scheduler=scheduler)
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=latent_tile_height,
latent_width=latent_tile_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)
controlnet_data = DenoiseLatentsInvocation.prep_control_data(
context=context,
control_input=self.control,
latents_shape=list(latents.shape),
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
# Split the controlnet_data into tiles.
# controlnet_data_tiles[t][c] is the c'th control data for the t'th tile.
controlnet_data_tiles: list[list[ControlNetData]] = []
for tile in tiles:
tile_controlnet_data = [crop_controlnet_data(cn, tile.coords) for cn in controlnet_data or []]
controlnet_data_tiles.append(tile_controlnet_data)
# Prepare the MultiDiffusionRegionConditioning list.
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning] = []
for tile, tile_controlnet_data in zip(tiles, controlnet_data_tiles, strict=True):
multi_diffusion_conditioning.append(
MultiDiffusionRegionConditioning(
region=tile,
text_conditioning_data=conditioning_data,
control_data=tile_controlnet_data,
)
)
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
seed=seed,
)
# Run Multi-Diffusion denoising.
result_latents = pipeline.multi_diffusion_denoise(
multi_diffusion_conditioning=multi_diffusion_conditioning,
target_overlap=latent_tile_overlap,
latents=latents,
scheduler_step_kwargs=scheduler_step_kwargs,
noise=noise,
timesteps=timesteps,
init_timestep=init_timestep,
callback=step_callback,
)
result_latents = result_latents.to("cpu")
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
TorchDevice.empty_cache()
name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
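
One detail of the removed file worth spelling out: `crop_controlnet_data()` slices an image-space tensor with latent-space tile coordinates, so each bound is multiplied by `LATENT_SCALE_FACTOR`. A hedged, standalone sketch of that indexing (the `TBLR` class below is a stand-in for `invokeai.backend.tiles.utils.TBLR`, and `LATENT_SCALE_FACTOR = 8` is assumed):

```python
from typing import NamedTuple

import torch

LATENT_SCALE_FACTOR = 8  # assumed value, mirroring invokeai.app.invocations.constants

class TBLR(NamedTuple):
    """Stand-in for invokeai.backend.tiles.utils.TBLR: top/bottom/left/right bounds."""
    top: int
    bottom: int
    left: int
    right: int

def crop_to_latent_region(image_tensor: torch.Tensor, region: TBLR) -> torch.Tensor:
    """Slice a (B, C, H, W) image-space tensor down to one latent-space tile region."""
    return image_tensor[
        :,
        :,
        region.top * LATENT_SCALE_FACTOR : region.bottom * LATENT_SCALE_FACTOR,
        region.left * LATENT_SCALE_FACTOR : region.right * LATENT_SCALE_FACTOR,
    ]

control_image = torch.randn(1, 3, 1024, 1024)      # ControlNet reference image, image space
tile = TBLR(top=0, bottom=64, left=64, right=128)  # one tile, latent space
print(crop_to_latent_region(control_image, tile).shape)  # torch.Size([1, 3, 512, 512])
```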

View File

@@ -1,4 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
+from pathlib import Path
from typing import Literal
import cv2
@@ -6,12 +7,16 @@ import numpy as np
from PIL import Image
from pydantic import ConfigDict
-from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
+from invokeai.app.invocations.fields import ImageField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
+from invokeai.backend.util.devices import TorchDevice
+from .baseinvocation import BaseInvocation, invocation
+from .fields import InputField, WithBoard, WithMetadata
# TODO: Populate this from disk?
# TODO: Use model manager to load?
@@ -47,6 +52,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
rrdbnet_model = None
netscale = None
+esrgan_model_path = None
if self.model_name in [
    "RealESRGAN_x4plus.pth",
@@ -89,25 +95,28 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
context.logger.error(msg)
raise ValueError(msg)
-loadnet = context.models.load_remote_model(
-    source=ESRGAN_MODEL_URLS[self.model_name],
-)
-with loadnet as loadnet_model:
-    upscaler = RealESRGAN(
-        scale=netscale,
-        loadnet=loadnet_model,
-        model=rrdbnet_model,
-        half=False,
-        tile=self.tile_size,
-    )
+esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}")
+# Downloads the ESRGAN model if it doesn't already exist
+download_with_progress_bar(
+    name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path
+)
+upscaler = RealESRGAN(
+    scale=netscale,
+    model_path=esrgan_model_path,
+    model=rrdbnet_model,
+    half=False,
+    tile=self.tile_size,
+)
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
# TODO: This strips the alpha... is that okay?
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
upscaled_image = upscaler.upscale(cv2_image)
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
+TorchDevice.empty_cache()
image_dto = context.images.save(image=pil_image)
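
The `# prepare image` comment in this hunk is doing real work: Real-ESRGAN operates on cv2 arrays, which are BGR-ordered, while PIL images are RGB, so the image must be converted on the way in and back on the way out. A small sketch of the round trip using only PIL and OpenCV (nothing InvokeAI-specific):

```python
import cv2
import numpy as np
from PIL import Image

pil_in = Image.new("RGB", (64, 64), (255, 0, 0))         # pure red, RGB order
bgr = cv2.cvtColor(np.array(pil_in), cv2.COLOR_RGB2BGR)  # each pixel becomes (0, 0, 255)
# ... a cv2-based upscaler such as RealESRGAN.upscale(bgr) would run here ...
pil_out = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)).convert("RGBA")
assert pil_out.getpixel((0, 0)) == (255, 0, 0, 255)      # red survives the round trip
```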

View File

@@ -2,11 +2,12 @@ import sqlite3
import threading
from typing import Optional, cast
-from invokeai.app.services.board_image_records.board_image_records_base import BoardImageRecordStorageBase
from invokeai.app.services.image_records.image_records_common import ImageRecord, deserialize_image_record
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
+from .board_image_records_base import BoardImageRecordStorageBase
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
    _conn: sqlite3.Connection

View File

@@ -1,8 +1,9 @@
from typing import Optional
-from invokeai.app.services.board_images.board_images_base import BoardImagesServiceABC
from invokeai.app.services.invoker import Invoker
+from .board_images_base import BoardImagesServiceABC
class BoardImagesService(BoardImagesServiceABC):
    __invoker: Invoker

View File

@@ -1,8 +1,9 @@
from abc import ABC, abstractmethod
-from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecord
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
+from .board_records_common import BoardChanges, BoardRecord
class BoardRecordStorageBase(ABC):
    """Low-level service responsible for interfacing with the board record store."""
@@ -39,12 +40,16 @@ class BoardRecordStorageBase(ABC):
    @abstractmethod
    def get_many(
-        self, offset: int = 0, limit: int = 10, include_archived: bool = False
+        self,
+        offset: int = 0,
+        limit: int = 10,
    ) -> OffsetPaginatedResults[BoardRecord]:
        """Gets many board records."""
        pass
    @abstractmethod
-    def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
+    def get_all(
+        self,
+    ) -> list[BoardRecord]:
        """Gets all board records."""
        pass

View File

@@ -22,10 +22,6 @@ class BoardRecord(BaseModelExcludeNull):
"""The updated timestamp of the image.""" """The updated timestamp of the image."""
cover_image_name: Optional[str] = Field(default=None, description="The name of the cover image of the board.") cover_image_name: Optional[str] = Field(default=None, description="The name of the cover image of the board.")
"""The name of the cover image of the board.""" """The name of the cover image of the board."""
archived: bool = Field(description="Whether or not the board is archived.")
"""Whether or not the board is archived."""
is_private: Optional[bool] = Field(default=None, description="Whether the board is private.")
"""Whether the board is private."""
def deserialize_board_record(board_dict: dict) -> BoardRecord: def deserialize_board_record(board_dict: dict) -> BoardRecord:
@@ -39,8 +35,6 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
created_at = board_dict.get("created_at", get_iso_timestamp()) created_at = board_dict.get("created_at", get_iso_timestamp())
updated_at = board_dict.get("updated_at", get_iso_timestamp()) updated_at = board_dict.get("updated_at", get_iso_timestamp())
deleted_at = board_dict.get("deleted_at", get_iso_timestamp()) deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
archived = board_dict.get("archived", False)
is_private = board_dict.get("is_private", False)
return BoardRecord( return BoardRecord(
board_id=board_id, board_id=board_id,
@@ -49,15 +43,12 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
created_at=created_at, created_at=created_at,
updated_at=updated_at, updated_at=updated_at,
deleted_at=deleted_at, deleted_at=deleted_at,
archived=archived,
is_private=is_private,
) )
class BoardChanges(BaseModel, extra="forbid"): class BoardChanges(BaseModel, extra="forbid"):
board_name: Optional[str] = Field(default=None, description="The board's new name.") board_name: Optional[str] = Field(default=None, description="The board's new name.")
cover_image_name: Optional[str] = Field(default=None, description="The name of the board's new cover image.") cover_image_name: Optional[str] = Field(default=None, description="The name of the board's new cover image.")
archived: Optional[bool] = Field(default=None, description="Whether or not the board is archived")
class BoardRecordNotFoundException(Exception): class BoardRecordNotFoundException(Exception):
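
The `deserialize_board_record()` hunk above shows a defensive deserialization pattern: columns are read with `dict.get()` and a default, so rows written before a schema migration (for example, before the `archived` column existed) still load. A simplified sketch under that assumption (the `BoardRecord` here is a stand-in, not the real model):

```python
from typing import Optional

from pydantic import BaseModel, Field

class BoardRecord(BaseModel):
    """Stand-in for the real BoardRecord, reduced to the fields used below."""
    board_id: str
    board_name: str
    archived: bool = Field(default=False, description="Whether or not the board is archived.")
    is_private: Optional[bool] = Field(default=None, description="Whether the board is private.")

def deserialize_board_record(board_dict: dict) -> BoardRecord:
    # .get() defaults let pre-migration rows (no "archived"/"is_private" keys) deserialize cleanly
    return BoardRecord(
        board_id=board_dict["board_id"],
        board_name=board_dict.get("board_name", "Untitled"),
        archived=board_dict.get("archived", False),
        is_private=board_dict.get("is_private", False),
    )

print(deserialize_board_record({"board_id": "abc", "board_name": "demo"}))
```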

View File

@@ -2,8 +2,12 @@ import sqlite3
import threading
from typing import Union, cast
-from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
-from invokeai.app.services.board_records.board_records_common import (
+from invokeai.app.services.shared.pagination import OffsetPaginatedResults
+from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
+from invokeai.app.util.misc import uuid_string
+from .board_records_base import BoardRecordStorageBase
+from .board_records_common import (
    BoardChanges,
    BoardRecord,
    BoardRecordDeleteException,
@@ -11,9 +15,6 @@ from invokeai.app.services.board_records.board_records_common import (
    BoardRecordSaveException,
    deserialize_board_record,
)
-from invokeai.app.services.shared.pagination import OffsetPaginatedResults
-from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
-from invokeai.app.util.misc import uuid_string
class SqliteBoardRecordStorage(BoardRecordStorageBase):
@@ -124,17 +125,6 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
    (changes.cover_image_name, board_id),
)
-# Change the archived status of a board
-if changes.archived is not None:
-    self._cursor.execute(
-        """--sql
-        UPDATE boards
-        SET archived = ?
-        WHERE board_id = ?;
-        """,
-        (changes.archived, board_id),
-    )
self._conn.commit()
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._conn.rollback()
@@ -144,49 +134,35 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
        return self.get(board_id)
    def get_many(
-        self, offset: int = 0, limit: int = 10, include_archived: bool = False
+        self,
+        offset: int = 0,
+        limit: int = 10,
    ) -> OffsetPaginatedResults[BoardRecord]:
        try:
            self._lock.acquire()
-            # Build base query
-            base_query = """
-                SELECT *
-                FROM boards
-                {archived_filter}
-                ORDER BY created_at DESC
-                LIMIT ? OFFSET ?;
-            """
-            # Determine archived filter condition
-            if include_archived:
-                archived_filter = ""
-            else:
-                archived_filter = "WHERE archived = 0"
-            final_query = base_query.format(archived_filter=archived_filter)
-            # Execute query to fetch boards
-            self._cursor.execute(final_query, (limit, offset))
+            # Get all the boards
+            self._cursor.execute(
+                """--sql
+                SELECT *
+                FROM boards
+                ORDER BY created_at DESC
+                LIMIT ? OFFSET ?;
+                """,
+                (limit, offset),
+            )
            result = cast(list[sqlite3.Row], self._cursor.fetchall())
            boards = [deserialize_board_record(dict(r)) for r in result]
-            # Determine count query
-            if include_archived:
-                count_query = """
-                    SELECT COUNT(*)
-                    FROM boards;
-                """
-            else:
-                count_query = """
-                    SELECT COUNT(*)
-                    FROM boards
-                    WHERE archived = 0;
-                """
-            # Execute count query
-            self._cursor.execute(count_query)
+            # Get the total number of boards
+            self._cursor.execute(
+                """--sql
+                SELECT COUNT(*)
+                FROM boards
+                WHERE 1=1;
+                """
+            )
            count = cast(int, self._cursor.fetchone()[0])
@@ -198,25 +174,20 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
        finally:
            self._lock.release()
-    def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
+    def get_all(
+        self,
+    ) -> list[BoardRecord]:
        try:
            self._lock.acquire()
-            base_query = """
-                SELECT *
-                FROM boards
-                {archived_filter}
-                ORDER BY created_at DESC
-            """
-            if include_archived:
-                archived_filter = ""
-            else:
-                archived_filter = "WHERE archived = 0"
-            final_query = base_query.format(archived_filter=archived_filter)
-            self._cursor.execute(final_query)
+            # Get all the boards
+            self._cursor.execute(
+                """--sql
+                SELECT *
+                FROM boards
+                ORDER BY created_at DESC
+                """
+            )
            result = cast(list[sqlite3.Row], self._cursor.fetchall())
            boards = [deserialize_board_record(dict(r)) for r in result]
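
On the removed (left) side of this hunk, the archived filter is spliced into the SQL with `str.format()` while the `LIMIT`/`OFFSET` values stay as bound parameters; the filter string is fixed code, never user input, so the `format()` call is not an injection vector. A runnable sketch of that construction:

```python
import sqlite3

def build_get_many_query(include_archived: bool) -> str:
    base_query = """
        SELECT *
        FROM boards
        {archived_filter}
        ORDER BY created_at DESC
        LIMIT ? OFFSET ?;
    """
    archived_filter = "" if include_archived else "WHERE archived = 0"
    return base_query.format(archived_filter=archived_filter)

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE boards (board_id TEXT, archived INTEGER, created_at TEXT)")
conn.execute("INSERT INTO boards VALUES ('a', 0, '2024-01-01'), ('b', 1, '2024-01-02')")
rows = conn.execute(build_get_many_query(include_archived=False), (10, 0)).fetchall()
print(rows)  # only the non-archived board 'a'
```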

View File

@@ -1,9 +1,10 @@
from abc import ABC, abstractmethod
from invokeai.app.services.board_records.board_records_common import BoardChanges
-from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
+from .boards_common import BoardDTO
class BoardServiceABC(ABC):
    """High-level service for board management."""
@@ -43,12 +44,16 @@ class BoardServiceABC(ABC):
    @abstractmethod
    def get_many(
-        self, offset: int = 0, limit: int = 10, include_archived: bool = False
+        self,
+        offset: int = 0,
+        limit: int = 10,
    ) -> OffsetPaginatedResults[BoardDTO]:
        """Gets many boards."""
        pass
    @abstractmethod
-    def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
+    def get_all(
+        self,
+    ) -> list[BoardDTO]:
        """Gets all boards."""
        pass

View File

@@ -2,7 +2,7 @@ from typing import Optional
from pydantic import Field
-from invokeai.app.services.board_records.board_records_common import BoardRecord
+from ..board_records.board_records_common import BoardRecord
class BoardDTO(BoardRecord):

View File

@@ -1,9 +1,11 @@
from invokeai.app.services.board_records.board_records_common import BoardChanges
-from invokeai.app.services.boards.boards_base import BoardServiceABC
-from invokeai.app.services.boards.boards_common import BoardDTO, board_record_to_dto
+from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
+from .boards_base import BoardServiceABC
+from .boards_common import board_record_to_dto
class BoardService(BoardServiceABC):
    __invoker: Invoker
@@ -46,10 +48,8 @@ class BoardService(BoardServiceABC):
    def delete(self, board_id: str) -> None:
        self.__invoker.services.board_records.delete(board_id)
-    def get_many(
-        self, offset: int = 0, limit: int = 10, include_archived: bool = False
-    ) -> OffsetPaginatedResults[BoardDTO]:
-        board_records = self.__invoker.services.board_records.get_many(offset, limit, include_archived)
+    def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]:
+        board_records = self.__invoker.services.board_records.get_many(offset, limit)
        board_dtos = []
        for r in board_records.items:
            cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
@@ -63,8 +63,8 @@ class BoardService(BoardServiceABC):
        return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
-    def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
-        board_records = self.__invoker.services.board_records.get_all(include_archived)
+    def get_all(self) -> list[BoardDTO]:
+        board_records = self.__invoker.services.board_records.get_all()
        board_dtos = []
        for r in board_records:
            cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)

View File

@@ -4,7 +4,6 @@ from typing import Optional, Union
from zipfile import ZipFile
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
-from invokeai.app.services.bulk_download.bulk_download_base import BulkDownloadBase
from invokeai.app.services.bulk_download.bulk_download_common import (
    DEFAULT_BULK_DOWNLOAD_ID,
    BulkDownloadException,
@@ -16,6 +15,8 @@ from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invoker import Invoker
from invokeai.app.util.misc import uuid_string
+from .bulk_download_base import BulkDownloadBase
class BulkDownloadService(BulkDownloadBase):
    def start(self, invoker: Invoker) -> None:

View File

@@ -1,6 +1,7 @@
"""Init file for InvokeAI configure package.""" """Init file for InvokeAI configure package."""
from invokeai.app.services.config.config_common import PagingArgumentParser from invokeai.app.services.config.config_common import PagingArgumentParser
from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config
from .config_default import InvokeAIAppConfig, get_config
__all__ = ["InvokeAIAppConfig", "get_config", "PagingArgumentParser"] __all__ = ["InvokeAIAppConfig", "get_config", "PagingArgumentParser"]

View File

@@ -3,7 +3,6 @@
from __future__ import annotations
-import copy
import locale
import os
import re
@@ -26,13 +25,14 @@ DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEFAULT_RAM_CACHE = 10.0
DEFAULT_VRAM_CACHE = 0.25
+DEFAULT_CONVERT_CACHE = 20.0
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
-CONFIG_SCHEMA_VERSION = "4.0.2"
+CONFIG_SCHEMA_VERSION = "4.0.1"
def get_default_ram_cache_size() -> float:
@@ -85,8 +85,7 @@ class InvokeAIAppConfig(BaseSettings):
log_tokenization: Enable logging of parsed prompt tokens.
patchmatch: Enable patchmatch inpaint code.
models_dir: Path to the models directory.
-convert_cache_dir: Path to the converted models cache directory (DEPRECATED, but do not delete because it is needed for migration from previous versions).
-download_cache_dir: Path to the directory that contains dynamically downloaded models.
+convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and stored on disk at this location.
legacy_conf_dir: Path to directory of legacy checkpoint config files.
db_dir: Path to InvokeAI databases directory.
outputs_dir: Path to directory for outputs.
@@ -102,6 +101,7 @@ class InvokeAIAppConfig(BaseSettings):
profiles_dir: Path to profiles output directory.
ram: Maximum memory amount used by memory model cache for rapid switching (GB).
vram: Amount of VRAM reserved for model storage (GB).
+convert_cache: Maximum size of on-disk converted models cache (GB).
lazy_offload: Keep models in VRAM until their space is needed.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
@@ -112,7 +112,6 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
max_queue_size: Maximum number of items in the session queue.
-clear_queue_on_startup: Empties session queue on startup.
allow_nodes: List of nodes to allow. Omit to allow all.
deny_nodes: List of nodes to deny. Omit to deny none.
node_cache_size: How many cached nodes to keep in memory.
@@ -147,8 +146,7 @@ class InvokeAIAppConfig(BaseSettings):
# PATHS
models_dir: Path = Field(default=Path("models"), description="Path to the models directory.")
-convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory (DEPRECATED, but do not delete because it is needed for migration from previous versions).")
-download_cache_dir: Path = Field(default=Path("models/.download_cache"), description="Path to the directory that contains dynamically downloaded models.")
+convert_cache_dir: Path = Field(default=Path("models/.cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and stored on disk at this location.")
legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.")
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
@@ -169,8 +167,9 @@ class InvokeAIAppConfig(BaseSettings):
profiles_dir: Path = Field(default=Path("profiles"), description="Path to profiles output directory.")
# CACHE
ram: float = Field(default_factory=get_default_ram_cache_size, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).")
vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
+convert_cache: float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB).")
lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.")
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
@@ -185,7 +184,6 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
-clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.")
# NODES
allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")
@@ -305,11 +303,6 @@ class InvokeAIAppConfig(BaseSettings):
"""Path to the converted cache models directory, resolved to an absolute path..""" """Path to the converted cache models directory, resolved to an absolute path.."""
return self._resolve(self.convert_cache_dir) return self._resolve(self.convert_cache_dir)
@property
def download_cache_path(self) -> Path:
"""Path to the downloaded models directory, resolved to an absolute path.."""
return self._resolve(self.download_cache_dir)
@property @property
def custom_nodes_path(self) -> Path: def custom_nodes_path(self) -> Path:
"""Path to the custom nodes directory, resolved to an absolute path..""" """Path to the custom nodes directory, resolved to an absolute path.."""
@@ -355,14 +348,14 @@ class DefaultInvokeAIAppConfig(InvokeAIAppConfig):
return (init_settings,)
-def migrate_v3_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
-    """Migrate a v3 config dictionary to a v4.0.0 config dictionary.
+def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
+    """Migrate a v3 config dictionary to a current config object.
    Args:
        config_dict: A dictionary of settings from a v3 config file.
    Returns:
-        An `InvokeAIAppConfig` config dict.
+        An instance of `InvokeAIAppConfig` with the migrated settings.
    """
    parsed_config_dict: dict[str, Any] = {}
@@ -396,41 +389,32 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
    elif k in InvokeAIAppConfig.model_fields:
        # skip unknown fields
        parsed_config_dict[k] = v
-    parsed_config_dict["schema_version"] = "4.0.0"
-    return parsed_config_dict
+    # When migrating the config file, we should not include currently-set environment variables.
+    config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
+    return config
-def migrate_v4_0_0_to_4_0_1_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
-    """Migrate v4.0.0 config dictionary to a v4.0.1 config dictionary.
+def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
+    """Migrate v4.0.0 config dictionary to a current config object.
    Args:
        config_dict: A dictionary of settings from a v4.0.0 config file.
    Returns:
-        A config dict with the settings migrated to v4.0.1.
+        An instance of `InvokeAIAppConfig` with the migrated settings.
    """
-    parsed_config_dict: dict[str, Any] = copy.deepcopy(config_dict)
-    # precision "autocast" was replaced by "auto" in v4.0.1
-    if parsed_config_dict.get("precision") == "autocast":
-        parsed_config_dict["precision"] = "auto"
-    parsed_config_dict["schema_version"] = "4.0.1"
-    return parsed_config_dict
+    parsed_config_dict: dict[str, Any] = {}
+    for k, v in config_dict.items():
+        # autocast was removed from precision in v4.0.1
+        if k == "precision" and v == "autocast":
+            parsed_config_dict["precision"] = "auto"
+        else:
+            parsed_config_dict[k] = v
+        if k == "schema_version":
+            parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
+    config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
+    return config
-def migrate_v4_0_1_to_4_0_2_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
-    """Migrate v4.0.1 config dictionary to a v4.0.2 config dictionary.
-    Args:
-        config_dict: A dictionary of settings from a v4.0.1 config file.
-    Returns:
-        A config dict with the settings migrated to v4.0.2.
-    """
-    parsed_config_dict: dict[str, Any] = copy.deepcopy(config_dict)
-    # convert_cache was removed in 4.0.2
-    parsed_config_dict.pop("convert_cache", None)
-    parsed_config_dict["schema_version"] = "4.0.2"
-    return parsed_config_dict
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
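
Read together, the removed functions form a dict-in/dict-out migration chain: each step rewrites only the settings it cares about, stamps the next `schema_version`, and the caller re-checks the version to decide whether another step applies. A condensed sketch of that chain (simplified stand-ins, not the real functions):

```python
import copy
from typing import Any

def migrate_v4_0_0_to_4_0_1(config_dict: dict[str, Any]) -> dict[str, Any]:
    migrated = copy.deepcopy(config_dict)
    if migrated.get("precision") == "autocast":  # "autocast" was replaced by "auto" in v4.0.1
        migrated["precision"] = "auto"
    migrated["schema_version"] = "4.0.1"
    return migrated

def migrate_v4_0_1_to_4_0_2(config_dict: dict[str, Any]) -> dict[str, Any]:
    migrated = copy.deepcopy(config_dict)
    migrated.pop("convert_cache", None)  # setting removed in v4.0.2
    migrated["schema_version"] = "4.0.2"
    return migrated

config = {"schema_version": "4.0.0", "precision": "autocast", "convert_cache": 20.0}
if config["schema_version"] == "4.0.0":
    config = migrate_v4_0_0_to_4_0_1(config)
if config["schema_version"] == "4.0.1":
    config = migrate_v4_0_1_to_4_0_2(config)
print(config)  # {'schema_version': '4.0.2', 'precision': 'auto'}
```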
@@ -444,31 +428,27 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
""" """
assert config_path.suffix == ".yaml" assert config_path.suffix == ".yaml"
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file: with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
loaded_config_dict: dict[str, Any] = yaml.safe_load(file) loaded_config_dict = yaml.safe_load(file)
assert isinstance(loaded_config_dict, dict) assert isinstance(loaded_config_dict, dict)
migrated = False
if "InvokeAI" in loaded_config_dict: if "InvokeAI" in loaded_config_dict:
migrated = True # This is a v3 config file, attempt to migrate it
loaded_config_dict = migrate_v3_config_dict(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
if loaded_config_dict["schema_version"] == "4.0.0":
migrated = True
loaded_config_dict = migrate_v4_0_0_to_4_0_1_config_dict(loaded_config_dict)
if loaded_config_dict["schema_version"] == "4.0.1":
migrated = True
loaded_config_dict = migrate_v4_0_1_to_4_0_2_config_dict(loaded_config_dict)
if migrated:
shutil.copy(config_path, config_path.with_suffix(".yaml.bak")) shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try: try:
# load and write without environment variables # loaded_config_dict could be the wrong shape, but we will catch all exceptions below
migrated_config = DefaultInvokeAIAppConfig.model_validate(loaded_config_dict) migrated_config = migrate_v3_config_dict(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
migrated_config.write_file(config_path)
except Exception as e: except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path) shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
migrated_config.write_file(config_path)
return migrated_config
if loaded_config_dict["schema_version"] == "4.0.0":
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
loaded_config_dict.write_file(config_path)
# Attempt to load as a v4 config file
try: try:
# Meta is not included in the model fields, so we need to validate it separately # Meta is not included in the model fields, so we need to validate it separately
config = InvokeAIAppConfig.model_validate(loaded_config_dict) config = InvokeAIAppConfig.model_validate(loaded_config_dict)

View File

@@ -1,17 +1,10 @@
"""Init file for download queue.""" """Init file for download queue."""
from invokeai.app.services.download.download_base import ( from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException
DownloadJob, from .download_default import DownloadQueueService, TqdmProgress
DownloadJobStatus,
DownloadQueueServiceBase,
MultiFileDownloadJob,
UnknownJobIDException,
)
from invokeai.app.services.download.download_default import DownloadQueueService, TqdmProgress
__all__ = [ __all__ = [
"DownloadJob", "DownloadJob",
"MultiFileDownloadJob",
"DownloadQueueServiceBase", "DownloadQueueServiceBase",
"DownloadQueueService", "DownloadQueueService",
"TqdmProgress", "TqdmProgress",

View File

@@ -5,13 +5,11 @@ from abc import ABC, abstractmethod
from enum import Enum
from functools import total_ordering
from pathlib import Path
-from typing import Any, Callable, List, Optional, Set, Union
+from typing import Any, Callable, List, Optional
from pydantic import BaseModel, Field, PrivateAttr
from pydantic.networks import AnyHttpUrl
-from invokeai.backend.model_manager.metadata import RemoteModelFile
class DownloadJobStatus(str, Enum):
    """State of a download job."""
@@ -35,23 +33,30 @@ class ServiceInactiveException(Exception):
"""This exception is raised when user attempts to initiate a download before the service is started.""" """This exception is raised when user attempts to initiate a download before the service is started."""
SingleFileDownloadEventHandler = Callable[["DownloadJob"], None] DownloadEventHandler = Callable[["DownloadJob"], None]
SingleFileDownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None] DownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
MultiFileDownloadEventHandler = Callable[["MultiFileDownloadJob"], None]
MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[Exception]], None]
DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler]
DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler]
class DownloadJobBase(BaseModel): @total_ordering
"""Base of classes to monitor and control downloads.""" class DownloadJob(BaseModel):
"""Class to monitor and control a model download request."""
# required variables to be passed in on creation
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
dest: Path = Field(description="Destination of downloaded model on local disk; a directory or file path")
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
# automatically assigned on creation # automatically assigned on creation
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path") # set internally during download process
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory")
status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download") status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file")
job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
job_ended: Optional[str] = Field(
default=None, description="Timestamp for when the download job ende1d (completed or errored)"
)
content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
bytes: int = Field(default=0, description="Bytes downloaded so far") bytes: int = Field(default=0, description="Bytes downloaded so far")
total_bytes: int = Field(default=0, description="Total file size (bytes)") total_bytes: int = Field(default=0, description="Total file size (bytes)")
@@ -69,6 +74,14 @@ class DownloadJobBase(BaseModel):
    _on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None)
    _on_error: Optional[DownloadExceptionHandler] = PrivateAttr(default=None)
+    def __hash__(self) -> int:
+        """Return hash of the string representation of this object, for indexing."""
+        return hash(str(self))
+    def __le__(self, other: "DownloadJob") -> bool:
+        """Return True if this job's priority is less than another's."""
+        return self.priority <= other.priority
    def cancel(self) -> None:
        """Call to cancel the job."""
        self._cancelled = True
@@ -85,11 +98,6 @@ class DownloadJobBase(BaseModel):
"""Return true if job completed without errors.""" """Return true if job completed without errors."""
return self.status == DownloadJobStatus.COMPLETED return self.status == DownloadJobStatus.COMPLETED
@property
def waiting(self) -> bool:
"""Return true if the job is waiting to run."""
return self.status == DownloadJobStatus.WAITING
@property @property
def running(self) -> bool: def running(self) -> bool:
"""Return true if the job is running.""" """Return true if the job is running."""
@@ -146,37 +154,6 @@ class DownloadJobBase(BaseModel):
self._on_cancelled = on_cancelled self._on_cancelled = on_cancelled
@total_ordering
class DownloadJob(DownloadJobBase):
"""Class to monitor and control a model download request."""
# required variables to be passed in on creation
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
# set internally during download process
job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
job_ended: Optional[str] = Field(
default=None, description="Timestamp for when the download job ende1d (completed or errored)"
)
content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
def __hash__(self) -> int:
"""Return hash of the string representation of this object, for indexing."""
return hash(str(self))
def __le__(self, other: "DownloadJob") -> bool:
"""Return True if this job's priority is less than another's."""
return self.priority <= other.priority
class MultiFileDownloadJob(DownloadJobBase):
"""Class to monitor and control multifile downloads."""
download_parts: Set[DownloadJob] = Field(default_factory=set, description="List of download parts.")
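The consolidated `DownloadJob` keeps its ordering behavior: `@total_ordering` expands the single `__le__` (plus equality) into the full comparison set, which is what lets a `PriorityQueue` pop the lowest `priority` value first. A minimal sketch with a hypothetical `Job` class, not the InvokeAI model itself:

```python
# Minimal sketch of how @total_ordering plus __le__ drives PriorityQueue ordering.
from functools import total_ordering
from queue import PriorityQueue


@total_ordering
class Job:
    def __init__(self, name: str, priority: int = 10) -> None:
        self.name = name
        self.priority = priority

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Job) and self.priority == other.priority

    def __le__(self, other: "Job") -> bool:
        return self.priority <= other.priority


q: PriorityQueue[Job] = PriorityQueue()
q.put(Job("background", priority=10))
q.put(Job("urgent", priority=1))
assert q.get().name == "urgent"  # lower value == higher priority
```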
class DownloadQueueServiceBase(ABC):
    """Multithreaded queue for downloading models via URL."""
@@ -224,48 +201,6 @@ class DownloadQueueServiceBase(ABC):
        """
        pass

-    @abstractmethod
-    def multifile_download(
-        self,
-        parts: List[RemoteModelFile],
-        dest: Path,
-        access_token: Optional[str] = None,
-        submit_job: bool = True,
-        on_start: Optional[DownloadEventHandler] = None,
-        on_progress: Optional[DownloadEventHandler] = None,
-        on_complete: Optional[DownloadEventHandler] = None,
-        on_cancelled: Optional[DownloadEventHandler] = None,
-        on_error: Optional[DownloadExceptionHandler] = None,
-    ) -> MultiFileDownloadJob:
-        """
-        Create and enqueue a multifile download job.
-
-        :param parts: Set of URL / filename pairs
-        :param dest: Path to download to. See below.
-        :param access_token: Access token to download the indicated files. If not provided,
-            each file's URL may be matched to an access token using the config file matching
-            system.
-        :param submit_job: If true [default] then submit the job for execution. Otherwise,
-            you will need to pass the job to submit_multifile_download().
-        :param on_start, on_progress, on_complete, on_error: Callbacks for the indicated
-            events.
-        :returns: A MultiFileDownloadJob object for monitoring the state of the download.
-
-        The `dest` argument is a Path object pointing to a directory. All downloads
-        will be placed inside this directory. The callbacks will receive the
-        MultiFileDownloadJob.
-        """
-        pass
-
-    @abstractmethod
-    def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
-        """
-        Enqueue a previously-created multi-file download job.
-
-        :param job: A MultiFileDownloadJob created with multifile_download()
-        """
-        pass
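For reference, a hypothetical invocation of the `multifile_download()` API being removed above. The `RemoteModelFile` field names and the callback shape are assumptions inferred from this diff, not a tested snippet:

```python
# Hypothetical usage of the pre-revert multifile_download() API.
from pathlib import Path

from invokeai.app.services.download import DownloadQueueService
from invokeai.backend.model_manager.metadata import RemoteModelFile

queue = DownloadQueueService()
queue.start()

parts = [  # one RemoteModelFile per URL / relative-path pair (fields assumed)
    RemoteModelFile(url="https://example.com/repo/config.json", path=Path("config.json")),
    RemoteModelFile(url="https://example.com/repo/weights.safetensors", path=Path("weights.safetensors")),
]
job = queue.multifile_download(
    parts=parts,
    dest=Path("/tmp/models"),  # all parts land inside this directory
    on_complete=lambda j: print(f"done: {j.download_path}"),
)
queue.wait_for_job(job)  # block until the MultiFileDownloadJob is terminal
```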
    @abstractmethod
    def submit_download_job(
        self,
@@ -317,7 +252,7 @@ class DownloadQueueServiceBase(ABC):
        pass

    @abstractmethod
-    def cancel_job(self, job: DownloadJobBase) -> None:
+    def cancel_job(self, job: DownloadJob) -> None:
        """Cancel the job, clearing partial downloads and putting it into ERROR state."""
        pass
@@ -327,7 +262,7 @@ class DownloadQueueServiceBase(ABC):
        pass

    @abstractmethod
-    def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
+    def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
        """Wait until the indicated download job has reached a terminal state.

        This will block until the indicated install job has completed,

View File
@@ -8,30 +8,29 @@ import time
import traceback
from pathlib import Path
from queue import Empty, PriorityQueue
-from typing import Any, Dict, List, Literal, Optional, Set
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set

import requests
from pydantic.networks import AnyHttpUrl
from requests import HTTPError
from tqdm import tqdm

-from invokeai.app.services.config import InvokeAIAppConfig, get_config
-from invokeai.app.services.download.download_base import (
+from invokeai.app.util.misc import get_iso_timestamp
+from invokeai.backend.util.logging import InvokeAILogger
+
+from .download_base import (
    DownloadEventHandler,
    DownloadExceptionHandler,
    DownloadJob,
-    DownloadJobBase,
    DownloadJobCancelledException,
    DownloadJobStatus,
    DownloadQueueServiceBase,
-    MultiFileDownloadJob,
    ServiceInactiveException,
    UnknownJobIDException,
)
-from invokeai.app.services.events.events_base import EventServiceBase
-from invokeai.app.util.misc import get_iso_timestamp
-from invokeai.backend.model_manager.metadata import RemoteModelFile
-from invokeai.backend.util.logging import InvokeAILogger
+
+if TYPE_CHECKING:
+    from invokeai.app.services.events.events_base import EventServiceBase

# Maximum number of bytes to download during each call to requests.iter_content()
DOWNLOAD_CHUNK_SIZE = 100000
@@ -43,24 +42,20 @@ class DownloadQueueService(DownloadQueueServiceBase):
    def __init__(
        self,
        max_parallel_dl: int = 5,
-        app_config: Optional[InvokeAIAppConfig] = None,
        event_bus: Optional["EventServiceBase"] = None,
        requests_session: Optional[requests.sessions.Session] = None,
    ):
        """
        Initialize DownloadQueue.

-        :param app_config: InvokeAIAppConfig object
        :param max_parallel_dl: Number of simultaneous downloads allowed [5].
        :param requests_session: Optional requests.sessions.Session object, for unit tests.
        """
-        self._app_config = app_config or get_config()
        self._jobs: Dict[int, DownloadJob] = {}
-        self._download_part2parent: Dict[AnyHttpUrl, MultiFileDownloadJob] = {}
        self._next_job_id = 0
        self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
        self._stop_event = threading.Event()
-        self._job_terminated_event = threading.Event()
+        self._job_completed_event = threading.Event()
        self._worker_pool: Set[threading.Thread] = set()
        self._lock = threading.Lock()
        self._logger = InvokeAILogger.get_logger("DownloadQueueService")
@@ -112,16 +107,18 @@ class DownloadQueueService(DownloadQueueServiceBase):
            raise ServiceInactiveException(
                "The download service is not currently accepting requests. Please call start() to initialize the service."
            )
-        job.id = self._next_id()
-        job.set_callbacks(
-            on_start=on_start,
-            on_progress=on_progress,
-            on_complete=on_complete,
-            on_cancelled=on_cancelled,
-            on_error=on_error,
-        )
-        self._jobs[job.id] = job
-        self._queue.put(job)
+        with self._lock:
+            job.id = self._next_job_id
+            self._next_job_id += 1
+            job.set_callbacks(
+                on_start=on_start,
+                on_progress=on_progress,
+                on_complete=on_complete,
+                on_cancelled=on_cancelled,
+                on_error=on_error,
+            )
+            self._jobs[job.id] = job
+            self._queue.put(job)
    def download(
        self,
@@ -144,7 +141,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
            source=source,
            dest=dest,
            priority=priority,
-            access_token=access_token or self._lookup_access_token(source),
+            access_token=access_token,
        )
        self.submit_download_job(
            job,
@@ -156,63 +153,10 @@ class DownloadQueueService(DownloadQueueServiceBase):
        )
        return job

-    def multifile_download(
-        self,
-        parts: List[RemoteModelFile],
-        dest: Path,
-        access_token: Optional[str] = None,
-        submit_job: bool = True,
-        on_start: Optional[DownloadEventHandler] = None,
-        on_progress: Optional[DownloadEventHandler] = None,
-        on_complete: Optional[DownloadEventHandler] = None,
-        on_cancelled: Optional[DownloadEventHandler] = None,
-        on_error: Optional[DownloadExceptionHandler] = None,
-    ) -> MultiFileDownloadJob:
-        mfdj = MultiFileDownloadJob(dest=dest, id=self._next_id())
-        mfdj.set_callbacks(
-            on_start=on_start,
-            on_progress=on_progress,
-            on_complete=on_complete,
-            on_cancelled=on_cancelled,
-            on_error=on_error,
-        )
-        for part in parts:
-            url = part.url
-            path = dest / part.path
-            assert path.is_relative_to(dest), "only relative download paths accepted"
-            job = DownloadJob(
-                source=url,
-                dest=path,
-                access_token=access_token or self._lookup_access_token(url),
-            )
-            mfdj.download_parts.add(job)
-            self._download_part2parent[job.source] = mfdj
-        if submit_job:
-            self.submit_multifile_download(mfdj)
-        return mfdj
-
-    def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
-        for download_job in job.download_parts:
-            self.submit_download_job(
-                download_job,
-                on_start=self._mfd_started,
-                on_progress=self._mfd_progress,
-                on_complete=self._mfd_complete,
-                on_cancelled=self._mfd_cancelled,
-                on_error=self._mfd_error,
-            )
-
    def join(self) -> None:
        """Wait for all jobs to complete."""
        self._queue.join()

-    def _next_id(self) -> int:
-        with self._lock:
-            id = self._next_job_id
-            self._next_job_id += 1
-        return id
-
    def list_jobs(self) -> List[DownloadJob]:
        """List all the jobs."""
        return list(self._jobs.values())
@@ -234,14 +178,14 @@ class DownloadQueueService(DownloadQueueServiceBase):
        except KeyError as excp:
            raise UnknownJobIDException("Unrecognized job") from excp

-    def cancel_job(self, job: DownloadJobBase) -> None:
+    def cancel_job(self, job: DownloadJob) -> None:
        """
        Cancel the indicated job.

        If it is running it will be stopped.
        job.status will be set to DownloadJobStatus.CANCELLED
        """
-        if job.status in [DownloadJobStatus.WAITING, DownloadJobStatus.RUNNING]:
+        with self._lock:
            job.cancel()

    def cancel_all_jobs(self) -> None:
@@ -250,12 +194,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
            if not job.in_terminal_state:
                self.cancel_job(job)

-    def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
+    def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
        """Block until the indicated job has reached terminal state, or when timeout limit reached."""
        start = time.time()
        while not job.in_terminal_state:
-            if self._job_terminated_event.wait(timeout=0.25):  # in case we miss an event
-                self._job_terminated_event.clear()
+            if self._job_completed_event.wait(timeout=0.25):  # in case we miss an event
+                self._job_completed_event.clear()
            if timeout > 0 and time.time() - start > timeout:
                raise TimeoutError("Timeout exceeded")
        return job
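Both sides of this hunk use the same wait pattern: poll the shared event with a short timeout so a `set()` that fires between checks can never hang the waiter, and re-check the job state on every wakeup. A standalone sketch of that loop (names are illustrative, not InvokeAI code):

```python
# Sketch of the polling wait loop above, with a hypothetical predicate.
import threading
import time


def wait_until(predicate, terminated: threading.Event, timeout: float = 0.0) -> None:
    """Block until predicate() is true, waking whenever `terminated` fires."""
    start = time.time()
    while not predicate():
        if terminated.wait(timeout=0.25):  # short poll: a missed set() can't block us
            terminated.clear()
        if timeout > 0 and time.time() - start > timeout:
            raise TimeoutError("Timeout exceeded")


finished = False
signal = threading.Event()


def finish() -> None:
    global finished
    finished = True
    signal.set()


threading.Timer(0.5, finish).start()
wait_until(lambda: finished, signal, timeout=5)
print("job reached terminal state")
```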
@@ -284,25 +228,22 @@ class DownloadQueueService(DownloadQueueServiceBase):
                job.job_started = get_iso_timestamp()
                self._do_download(job)
                self._signal_job_complete(job)
-            except DownloadJobCancelledException:
-                self._signal_job_cancelled(job)
-                self._cleanup_cancelled_job(job)
-            except Exception as excp:
+            except (OSError, HTTPError) as excp:
                job.error_type = excp.__class__.__name__ + f"({str(excp)})"
                job.error = traceback.format_exc()
                self._signal_job_error(job, excp)
+            except DownloadJobCancelledException:
+                self._signal_job_cancelled(job)
+                self._cleanup_cancelled_job(job)
            finally:
                job.job_ended = get_iso_timestamp()
-                self._job_terminated_event.set()  # signal a change to terminal state
-                self._download_part2parent.pop(job.source, None)  # if this is a subpart of a multipart job, remove it
-                self._job_terminated_event.set()
+                self._job_completed_event.set()  # signal a change to terminal state
                self._queue.task_done()
        self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")

    def _do_download(self, job: DownloadJob) -> None:
        """Do the actual download."""
        url = job.source
        header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
        open_mode = "wb"
@@ -394,29 +335,38 @@ class DownloadQueueService(DownloadQueueServiceBase):
    def _in_progress_path(self, path: Path) -> Path:
        return path.with_name(path.name + ".downloading")

-    def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]:
-        # Pull the token from config if it exists and matches the URL
-        token = None
-        for pair in self._app_config.remote_api_tokens or []:
-            if re.search(pair.url_regex, str(source)):
-                token = pair.token
-                break
-        return token
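The removed `_lookup_access_token()` (its logic reappears inline in `model_install_default.py` later in this diff) is a first-match regex scan over configured URL/token pairs. A self-contained sketch with the config dependency stubbed out:

```python
# Sketch of the first-match token lookup; the config model is a stand-in.
import re
from typing import NamedTuple, Optional


class URLRegexTokenPair(NamedTuple):  # hypothetical stand-in for the config entry
    url_regex: str
    token: str


REMOTE_API_TOKENS = [
    URLRegexTokenPair(url_regex=r"huggingface\.co", token="hf_xxx"),
    URLRegexTokenPair(url_regex=r"civitai\.com", token="civitai_xxx"),
]


def lookup_access_token(source: str) -> Optional[str]:
    for pair in REMOTE_API_TOKENS:
        if re.search(pair.url_regex, source):
            return pair.token  # first matching pattern wins
    return None


assert lookup_access_token("https://huggingface.co/foo/bar") == "hf_xxx"
assert lookup_access_token("https://example.com/model.bin") is None
```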
    def _signal_job_started(self, job: DownloadJob) -> None:
        job.status = DownloadJobStatus.RUNNING
-        self._execute_cb(job, "on_start")
+        if job.on_start:
+            try:
+                job.on_start(job)
+            except Exception as e:
+                self._logger.error(
+                    f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
+                )
        if self._event_bus:
            self._event_bus.emit_download_started(job)

    def _signal_job_progress(self, job: DownloadJob) -> None:
-        self._execute_cb(job, "on_progress")
+        if job.on_progress:
+            try:
+                job.on_progress(job)
+            except Exception as e:
+                self._logger.error(
+                    f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
+                )
        if self._event_bus:
            self._event_bus.emit_download_progress(job)

    def _signal_job_complete(self, job: DownloadJob) -> None:
        job.status = DownloadJobStatus.COMPLETED
-        self._execute_cb(job, "on_complete")
+        if job.on_complete:
+            try:
+                job.on_complete(job)
+            except Exception as e:
+                self._logger.error(
+                    f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
+                )
        if self._event_bus:
            self._event_bus.emit_download_complete(job)
@@ -424,21 +374,26 @@ class DownloadQueueService(DownloadQueueServiceBase):
        if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
            return
        job.status = DownloadJobStatus.CANCELLED
-        self._execute_cb(job, "on_cancelled")
+        if job.on_cancelled:
+            try:
+                job.on_cancelled(job)
+            except Exception as e:
+                self._logger.error(
+                    f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
+                )
        if self._event_bus:
            self._event_bus.emit_download_cancelled(job)

-        # if multifile download, then signal the parent
-        if parent_job := self._download_part2parent.get(job.source, None):
-            if not parent_job.in_terminal_state:
-                parent_job.status = DownloadJobStatus.CANCELLED
-                self._execute_cb(parent_job, "on_cancelled")
-
    def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
        job.status = DownloadJobStatus.ERROR
        self._logger.error(f"{str(job.source)}: {traceback.format_exception(excp)}")
-        self._execute_cb(job, "on_error", excp)
+        if job.on_error:
+            try:
+                job.on_error(job, excp)
+            except Exception as e:
+                self._logger.error(
+                    f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
+                )
        if self._event_bus:
            self._event_bus.emit_download_error(job)
@@ -451,97 +406,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
        except OSError as excp:
            self._logger.warning(excp)

-    ########################################
-    # callbacks used for multifile downloads
-    ########################################
-    def _mfd_started(self, download_job: DownloadJob) -> None:
-        self._logger.info(f"File download started: {download_job.source}")
-        with self._lock:
-            mf_job = self._download_part2parent[download_job.source]
-            if mf_job.waiting:
-                mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
-                mf_job.status = DownloadJobStatus.RUNNING
-                assert download_job.download_path is not None
-                path_relative_to_destdir = download_job.download_path.relative_to(mf_job.dest)
-                mf_job.download_path = (
-                    mf_job.dest / path_relative_to_destdir.parts[0]
-                )  # keep just the first component of the path
-                self._execute_cb(mf_job, "on_start")
-
-    def _mfd_progress(self, download_job: DownloadJob) -> None:
-        with self._lock:
-            mf_job = self._download_part2parent[download_job.source]
-            if mf_job.cancelled:
-                for part in mf_job.download_parts:
-                    self.cancel_job(part)
-            elif mf_job.running:
-                mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
-                mf_job.bytes = sum(x.total_bytes for x in mf_job.download_parts)
-                self._execute_cb(mf_job, "on_progress")
-
-    def _mfd_complete(self, download_job: DownloadJob) -> None:
-        self._logger.info(f"Download complete: {download_job.source}")
-        with self._lock:
-            mf_job = self._download_part2parent[download_job.source]
-
-            # are there any more active jobs left in this task?
-            if mf_job.running and all(x.complete for x in mf_job.download_parts):
-                mf_job.status = DownloadJobStatus.COMPLETED
-                self._execute_cb(mf_job, "on_complete")
-
-            # we're done with this sub-job
-            self._job_terminated_event.set()
-
-    def _mfd_cancelled(self, download_job: DownloadJob) -> None:
-        with self._lock:
-            mf_job = self._download_part2parent[download_job.source]
-            assert mf_job is not None
-
-            if not mf_job.in_terminal_state:
-                self._logger.warning(f"Download cancelled: {download_job.source}")
-                mf_job.cancel()
-
-            for s in mf_job.download_parts:
-                self.cancel_job(s)
-
-    def _mfd_error(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
-        with self._lock:
-            mf_job = self._download_part2parent[download_job.source]
-            assert mf_job is not None
-            if not mf_job.in_terminal_state:
-                mf_job.status = download_job.status
-                mf_job.error = download_job.error
-                mf_job.error_type = download_job.error_type
-                self._execute_cb(mf_job, "on_error", excp)
-                self._logger.error(
-                    f"Cancelling {mf_job.dest} due to an error while downloading {download_job.source}: {str(excp)}"
-                )
-                for s in [x for x in mf_job.download_parts if x.running]:
-                    self.cancel_job(s)
-                self._download_part2parent.pop(download_job.source)
-                self._job_terminated_event.set()
-
-    def _execute_cb(
-        self,
-        job: DownloadJob | MultiFileDownloadJob,
-        callback_name: Literal[
-            "on_start",
-            "on_progress",
-            "on_complete",
-            "on_cancelled",
-            "on_error",
-        ],
-        excp: Optional[Exception] = None,
-    ) -> None:
-        if callback := getattr(job, callback_name, None):
-            args = [job, excp] if excp else [job]
-            try:
-                callback(*args)
-            except Exception as e:
-                self._logger.error(
-                    f"An error occurred while processing the {callback_name} callback: {traceback.format_exception(e)}"
-                )
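The deleted `_execute_cb()` helper collapsed the five near-identical callback blocks reinstated earlier in this file into one `getattr` dispatch. A sketch of that pattern outside the service class:

```python
# Sketch of getattr-based callback dispatch: one helper serves all five
# callback slots and swallows exceptions raised by user callbacks.
import logging
import traceback
from typing import Any, Optional

logger = logging.getLogger(__name__)


def execute_cb(job: Any, callback_name: str, excp: Optional[Exception] = None) -> None:
    if callback := getattr(job, callback_name, None):
        args = [job, excp] if excp else [job]
        try:
            callback(*args)
        except Exception:
            logger.error(
                "An error occurred while processing the %s callback: %s",
                callback_name,
                traceback.format_exc(),
            )
```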
def get_pc_name_max(directory: str) -> int:
    if hasattr(os, "pathconf"):
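`get_pc_name_max()` is truncated by the extract above; a sketch of the `os.pathconf` probe it opens with, where the 255 fallback is an assumption for platforms (e.g. Windows) that lack `pathconf`:

```python
# Sketch: query the filesystem's maximum file-name length for a directory.
import os


def pc_name_max(directory: str) -> int:
    if hasattr(os, "pathconf"):
        return os.pathconf(directory, "PC_NAME_MAX")
    return 255  # assumed common NAME_MAX default where pathconf is unavailable


print(pc_name_max("."))
```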
View File
@@ -22,7 +22,6 @@ from invokeai.app.services.events.events_common import (
    ModelInstallCompleteEvent,
    ModelInstallDownloadProgressEvent,
    ModelInstallDownloadsCompleteEvent,
-    ModelInstallDownloadStartedEvent,
    ModelInstallErrorEvent,
    ModelInstallStartedEvent,
    ModelLoadCompleteEvent,
@@ -35,6 +34,7 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineInterme
if TYPE_CHECKING:
    from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
    from invokeai.app.services.download.download_base import DownloadJob
+    from invokeai.app.services.events.events_common import EventBase
    from invokeai.app.services.model_install.model_install_common import ModelInstallJob
    from invokeai.app.services.session_processor.session_processor_common import ProgressImage
    from invokeai.app.services.session_queue.session_queue_common import (
@@ -145,10 +145,6 @@ class EventServiceBase:
    # region Model install

-    def emit_model_install_download_started(self, job: "ModelInstallJob") -> None:
-        """Emitted at intervals while the install job is started (remote models only)."""
-        self.dispatch(ModelInstallDownloadStartedEvent.build(job))
-
    def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
        """Emitted at intervals while the install job is in progress (remote models only)."""
        self.dispatch(ModelInstallDownloadProgressEvent.build(job))

View File
@@ -3,8 +3,9 @@ from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, P
from fastapi_events.handlers.local import local_handler
from fastapi_events.registry.payload_schema import registry as payload_schema
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, field_validator

+from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
    QUEUE_ITEM_STATUS,
@@ -13,7 +14,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
    SessionQueueItem,
    SessionQueueStatus,
)
-from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@@ -98,9 +98,17 @@ class InvocationEventBase(QueueItemEventBase):
    item_id: int = Field(description="The ID of the queue item")
    batch_id: str = Field(description="The ID of the queue batch")
    session_id: str = Field(description="The ID of the session (aka graph execution state)")
-    invocation: AnyInvocation = Field(description="The ID of the invocation")
+    invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation")
    invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")

+    @field_validator("invocation", mode="plain")
+    @classmethod
+    def validate_invocation(cls, v: Any):
+        """Validates the invocation using the dynamic type adapter."""
+        invocation = BaseInvocation.get_typeadapter().validate_python(v)
+        return invocation
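A `mode="plain"` validator replaces pydantic's own validation for the field entirely, delegating to a `TypeAdapter` so a polymorphic payload deserializes to the concrete subclass. A minimal sketch using a hypothetical `Animal`/`Dog`/`Cat` hierarchy rather than InvokeAI's invocation types:

```python
# Sketch: plain-mode field_validator + TypeAdapter for polymorphic fields.
from typing import Any, Literal, Union

from pydantic import BaseModel, Field, SerializeAsAny, TypeAdapter, field_validator


class Animal(BaseModel):
    kind: str


class Dog(Animal):
    kind: Literal["dog"] = "dog"


class Cat(Animal):
    kind: Literal["cat"] = "cat"


ANIMAL_ADAPTER: TypeAdapter[Union[Dog, Cat]] = TypeAdapter(Union[Dog, Cat])


class Event(BaseModel):
    animal: SerializeAsAny[Animal] = Field(description="Polymorphic payload")

    @field_validator("animal", mode="plain")
    @classmethod
    def validate_animal(cls, v: Any) -> Animal:
        # bypass the declared annotation and validate through the adapter
        return ANIMAL_ADAPTER.validate_python(v)


event = Event.model_validate({"animal": {"kind": "dog"}})
assert isinstance(event.animal, Dog)  # concrete subclass, not plain Animal
```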
@payload_schema.register
class InvocationStartedEvent(InvocationEventBase):
@@ -109,7 +117,7 @@ class InvocationStartedEvent(InvocationEventBase):
    __event_name__ = "invocation_started"

    @classmethod
-    def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "InvocationStartedEvent":
+    def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent":
        return cls(
            queue_id=queue_item.queue_id,
            item_id=queue_item.item_id,
@@ -136,7 +144,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
    def build(
        cls,
        queue_item: SessionQueueItem,
-        invocation: AnyInvocation,
+        invocation: BaseInvocation,
        intermediate_state: PipelineIntermediateState,
        progress_image: ProgressImage,
    ) -> "InvocationDenoiseProgressEvent":
@@ -174,11 +182,19 @@ class InvocationCompleteEvent(InvocationEventBase):
    __event_name__ = "invocation_complete"

-    result: AnyInvocationOutput = Field(description="The result of the invocation")
+    result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation")

+    @field_validator("result", mode="plain")
+    @classmethod
+    def validate_results(cls, v: Any):
+        """Validates the invocation result using the dynamic type adapter."""
+        result = BaseInvocationOutput.get_typeadapter().validate_python(v)
+        return result
+
    @classmethod
    def build(
-        cls, queue_item: SessionQueueItem, invocation: AnyInvocation, result: AnyInvocationOutput
+        cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput
    ) -> "InvocationCompleteEvent":
        return cls(
            queue_id=queue_item.queue_id,
@@ -207,7 +223,7 @@ class InvocationErrorEvent(InvocationEventBase):
    def build(
        cls,
        queue_item: SessionQueueItem,
-        invocation: AnyInvocation,
+        invocation: BaseInvocation,
        error_type: str,
        error_message: str,
        error_traceback: str,
@@ -417,42 +433,6 @@ class ModelLoadCompleteEvent(ModelEventBase):
        return cls(config=config, submodel_type=submodel_type)

-@payload_schema.register
-class ModelInstallDownloadStartedEvent(ModelEventBase):
-    """Event model for model_install_download_started"""
-
-    __event_name__ = "model_install_download_started"
-
-    id: int = Field(description="The ID of the install job")
-    source: str = Field(description="Source of the model; local path, repo_id or url")
-    local_path: str = Field(description="Where model is downloading to")
-    bytes: int = Field(description="Number of bytes downloaded so far")
-    total_bytes: int = Field(description="Total size of download, including all files")
-    parts: list[dict[str, int | str]] = Field(
-        description="Progress of downloading URLs that comprise the model, if any"
-    )
-
-    @classmethod
-    def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadStartedEvent":
-        parts: list[dict[str, str | int]] = [
-            {
-                "url": str(x.source),
-                "local_path": str(x.download_path),
-                "bytes": x.bytes,
-                "total_bytes": x.total_bytes,
-            }
-            for x in job.download_parts
-        ]
-        return cls(
-            id=job.id,
-            source=str(job.source),
-            local_path=job.local_path.as_posix(),
-            parts=parts,
-            bytes=job.bytes,
-            total_bytes=job.total_bytes,
-        )
-
@payload_schema.register
class ModelInstallDownloadProgressEvent(ModelEventBase):
    """Event model for model_install_download_progress"""

View File
@@ -6,11 +6,12 @@ from queue import Empty, Queue
from fastapi_events.dispatcher import dispatch

-from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.events.events_common import (
    EventBase,
)

+from .events_base import EventServiceBase
+
class FastAPIEventService(EventServiceBase):
    def __init__(self, event_handler_id: int) -> None:

View File
@@ -7,15 +7,12 @@ from PIL import Image, PngImagePlugin
from PIL.Image import Image as PILImageType
from send2trash import send2trash

-from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
-from invokeai.app.services.image_files.image_files_common import (
-    ImageFileDeleteException,
-    ImageFileNotFoundException,
-    ImageFileSaveException,
-)
from invokeai.app.services.invoker import Invoker
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail

+from .image_files_base import ImageFileStorageBase
+from .image_files_common import ImageFileDeleteException, ImageFileNotFoundException, ImageFileSaveException
+
class DiskImageFileStorage(ImageFileStorageBase):
    """Stores images on disk"""

View File
@@ -3,14 +3,9 @@ from datetime import datetime
from typing import Optional

from invokeai.app.invocations.fields import MetadataField
-from invokeai.app.services.image_records.image_records_common import (
-    ImageCategory,
-    ImageRecord,
-    ImageRecordChanges,
-    ResourceOrigin,
-)
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
-from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection

+from .image_records_common import ImageCategory, ImageRecord, ImageRecordChanges, ResourceOrigin

class ImageRecordStorageBase(ABC):
@@ -42,13 +37,10 @@ class ImageRecordStorageBase(ABC):
        self,
        offset: int = 0,
        limit: int = 10,
-        starred_first: bool = True,
-        order_dir: SQLiteDirection = SQLiteDirection.Descending,
        image_origin: Optional[ResourceOrigin] = None,
        categories: Optional[list[ImageCategory]] = None,
        is_intermediate: Optional[bool] = None,
        board_id: Optional[str] = None,
-        search_term: Optional[str] = None,
    ) -> OffsetPaginatedResults[ImageRecord]:
        """Gets a page of image records."""
        pass

View File
@@ -4,8 +4,11 @@ from datetime import datetime
from typing import Optional, Union, cast

from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
-from invokeai.app.services.image_records.image_records_base import ImageRecordStorageBase
-from invokeai.app.services.image_records.image_records_common import (
+from invokeai.app.services.shared.pagination import OffsetPaginatedResults
+from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
+
+from .image_records_base import ImageRecordStorageBase
+from .image_records_common import (
    IMAGE_DTO_COLS,
    ImageCategory,
    ImageRecord,
@@ -16,9 +19,6 @@ from invokeai.app.services.image_records.image_records_common import (
    ResourceOrigin,
    deserialize_image_record,
)
-from invokeai.app.services.shared.pagination import OffsetPaginatedResults
-from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
-from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase

class SqliteImageRecordStorage(ImageRecordStorageBase):
@@ -144,13 +144,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
        self,
        offset: int = 0,
        limit: int = 10,
-        starred_first: bool = True,
-        order_dir: SQLiteDirection = SQLiteDirection.Descending,
        image_origin: Optional[ResourceOrigin] = None,
        categories: Optional[list[ImageCategory]] = None,
        is_intermediate: Optional[bool] = None,
        board_id: Optional[str] = None,
-        search_term: Optional[str] = None,
    ) -> OffsetPaginatedResults[ImageRecord]:
        try:
            self._lock.acquire()
@@ -211,21 +208,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
            """
            query_params.append(board_id)

-        # Search term condition
-        if search_term:
-            query_conditions += """--sql
-                AND images.metadata LIKE ?
-            """
-            query_params.append(f"%{search_term.lower()}%")
-
-        if starred_first:
-            query_pagination = f"""--sql
-                ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
-            """
-        else:
-            query_pagination = f"""--sql
-                ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
-            """
+        query_pagination = """--sql
+            ORDER BY images.starred DESC, images.created_at DESC LIMIT ? OFFSET ?
+        """

        # Final images query with pagination
        images_query += query_conditions + query_pagination + ";"
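The removed branch builds the ORDER BY clause from trusted, enum-backed values while user-supplied values still travel as bound parameters. A compact sketch of that split (table and column names assumed from the hunk above):

```python
# Sketch: interpolate only trusted sort tokens; bind everything user-supplied.
def page_query(starred_first: bool, order_dir: str) -> str:
    assert order_dir in ("ASC", "DESC")  # enum-backed in the real service
    if starred_first:
        return f"ORDER BY images.starred DESC, images.created_at {order_dir} LIMIT ? OFFSET ?;"
    return f"ORDER BY images.created_at {order_dir} LIMIT ? OFFSET ?;"


print(page_query(starred_first=True, order_dir="DESC"))  # starred rows first
print(page_query(starred_first=False, order_dir="ASC"))  # oldest first
```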
View File
@@ -12,7 +12,6 @@ from invokeai.app.services.image_records.image_records_common import (
)
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
-from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection

class ImageServiceABC(ABC):
@@ -117,13 +116,10 @@ class ImageServiceABC(ABC):
        self,
        offset: int = 0,
        limit: int = 10,
-        starred_first: bool = True,
-        order_dir: SQLiteDirection = SQLiteDirection.Descending,
        image_origin: Optional[ResourceOrigin] = None,
        categories: Optional[list[ImageCategory]] = None,
        is_intermediate: Optional[bool] = None,
        board_id: Optional[str] = None,
-        search_term: Optional[str] = None,
    ) -> OffsetPaginatedResults[ImageDTO]:
        """Gets a paginated list of image DTOs."""
        pass

View File
@@ -3,12 +3,15 @@ from typing import Optional
from PIL.Image import Image as PILImageType

from invokeai.app.invocations.fields import MetadataField
-from invokeai.app.services.image_files.image_files_common import (
+from invokeai.app.services.invoker import Invoker
+from invokeai.app.services.shared.pagination import OffsetPaginatedResults
+
+from ..image_files.image_files_common import (
    ImageFileDeleteException,
    ImageFileNotFoundException,
    ImageFileSaveException,
)
-from invokeai.app.services.image_records.image_records_common import (
+from ..image_records.image_records_common import (
    ImageCategory,
    ImageRecord,
    ImageRecordChanges,
@@ -19,11 +22,8 @@ from invokeai.app.services.image_records.image_records_common import (
    InvalidOriginException,
    ResourceOrigin,
)
-from invokeai.app.services.images.images_base import ImageServiceABC
-from invokeai.app.services.images.images_common import ImageDTO, image_record_to_dto
-from invokeai.app.services.invoker import Invoker
-from invokeai.app.services.shared.pagination import OffsetPaginatedResults
-from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
+from .images_base import ImageServiceABC
+from .images_common import ImageDTO, image_record_to_dto

class ImageService(ImageServiceABC):
@@ -73,12 +73,7 @@ class ImageService(ImageServiceABC):
                session_id=session_id,
            )
            if board_id is not None:
-                try:
-                    self.__invoker.services.board_image_records.add_image_to_board(
-                        board_id=board_id, image_name=image_name
-                    )
-                except Exception as e:
-                    self.__invoker.services.logger.warn(f"Failed to add image to board {board_id}: {str(e)}")
+                self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
            self.__invoker.services.image_files.save(
                image_name=image_name, image=image, metadata=metadata, workflow=workflow, graph=graph
            )
@@ -207,25 +202,19 @@ class ImageService(ImageServiceABC):
        self,
        offset: int = 0,
        limit: int = 10,
-        starred_first: bool = True,
-        order_dir: SQLiteDirection = SQLiteDirection.Descending,
        image_origin: Optional[ResourceOrigin] = None,
        categories: Optional[list[ImageCategory]] = None,
        is_intermediate: Optional[bool] = None,
        board_id: Optional[str] = None,
-        search_term: Optional[str] = None,
    ) -> OffsetPaginatedResults[ImageDTO]:
        try:
            results = self.__invoker.services.image_records.get_many(
                offset,
                limit,
-                starred_first,
-                order_dir,
                image_origin,
                categories,
                is_intermediate,
                board_id,
-                search_term,
            )
            image_dtos = [

View File
@@ -10,28 +10,29 @@ if TYPE_CHECKING:
    import torch

-    from invokeai.app.services.board_image_records.board_image_records_base import BoardImageRecordStorageBase
-    from invokeai.app.services.board_images.board_images_base import BoardImagesServiceABC
-    from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
-    from invokeai.app.services.boards.boards_base import BoardServiceABC
-    from invokeai.app.services.bulk_download.bulk_download_base import BulkDownloadBase
-    from invokeai.app.services.config import InvokeAIAppConfig
-    from invokeai.app.services.download import DownloadQueueServiceBase
-    from invokeai.app.services.events.events_base import EventServiceBase
-    from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
-    from invokeai.app.services.image_records.image_records_base import ImageRecordStorageBase
-    from invokeai.app.services.images.images_base import ImageServiceABC
-    from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
-    from invokeai.app.services.invocation_stats.invocation_stats_base import InvocationStatsServiceBase
-    from invokeai.app.services.model_images.model_images_base import ModelImageFileStorageBase
-    from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
-    from invokeai.app.services.names.names_base import NameServiceBase
-    from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
-    from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
-    from invokeai.app.services.urls.urls_base import UrlServiceBase
-    from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData

+    from .board_image_records.board_image_records_base import BoardImageRecordStorageBase
+    from .board_images.board_images_base import BoardImagesServiceABC
+    from .board_records.board_records_base import BoardRecordStorageBase
+    from .boards.boards_base import BoardServiceABC
+    from .bulk_download.bulk_download_base import BulkDownloadBase
+    from .config import InvokeAIAppConfig
+    from .download import DownloadQueueServiceBase
+    from .events.events_base import EventServiceBase
+    from .image_files.image_files_base import ImageFileStorageBase
+    from .image_records.image_records_base import ImageRecordStorageBase
+    from .images.images_base import ImageServiceABC
+    from .invocation_cache.invocation_cache_base import InvocationCacheBase
+    from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
+    from .model_images.model_images_base import ModelImageFileStorageBase
+    from .model_manager.model_manager_base import ModelManagerServiceBase
+    from .names.names_base import NameServiceBase
+    from .session_processor.session_processor_base import SessionProcessorBase
+    from .session_queue.session_queue_base import SessionQueueBase
+    from .urls.urls_base import UrlServiceBase
+    from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase

class InvocationServices:
    """Services that can be used by invocations"""

View File
@@ -9,8 +9,11 @@ import torch

import invokeai.backend.util.logging as logger
from invokeai.app.invocations.baseinvocation import BaseInvocation
-from invokeai.app.services.invocation_stats.invocation_stats_base import InvocationStatsServiceBase
-from invokeai.app.services.invocation_stats.invocation_stats_common import (
+from invokeai.app.services.invoker import Invoker
+from invokeai.backend.model_manager.load.model_cache import CacheStats
+
+from .invocation_stats_base import InvocationStatsServiceBase
+from .invocation_stats_common import (
    GESStatsNotFoundError,
    GraphExecutionStats,
    GraphExecutionStatsSummary,
@@ -19,8 +22,6 @@ from invokeai.app.services.invocation_stats.invocation_stats_common import (
    NodeExecutionStats,
    NodeExecutionStatsSummary,
)
-from invokeai.app.services.invoker import Invoker
-from invokeai.backend.model_manager.load.model_cache import CacheStats

# Size of 1GB in bytes.
GB = 2**30

View File
@@ -1,7 +1,7 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

-from invokeai.app.services.invocation_services import InvocationServices
+from .invocation_services import InvocationServices

class Invoker:

View File
@@ -5,14 +5,15 @@ from PIL.Image import Image as PILImageType
from send2trash import send2trash

from invokeai.app.services.invoker import Invoker
-from invokeai.app.services.model_images.model_images_base import ModelImageFileStorageBase
-from invokeai.app.services.model_images.model_images_common import (
+from invokeai.app.util.misc import uuid_string
+from invokeai.app.util.thumbnails import make_thumbnail
+
+from .model_images_base import ModelImageFileStorageBase
+from .model_images_common import (
    ModelImageFileDeleteException,
    ModelImageFileNotFoundException,
    ModelImageFileSaveException,
)
-from invokeai.app.util.misc import uuid_string
-from invokeai.app.util.thumbnails import make_thumbnail

class ModelImageFileStorageDisk(ModelImageFileStorageBase):

View File
@@ -1,7 +1,9 @@
"""Initialization file for model install service package."""

-from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
-from invokeai.app.services.model_install.model_install_common import (
+from .model_install_base import (
+    ModelInstallServiceBase,
+)
+from .model_install_common import (
    HFModelSource,
    InstallStatus,
    LocalModelSource,
@@ -10,7 +12,7 @@ from invokeai.app.services.model_install.model_install_common import (
    UnknownInstallJobException,
    URLModelSource,
)
-from invokeai.app.services.model_install.model_install_default import ModelInstallService
+from .model_install_default import ModelInstallService

__all__ = [
    "ModelInstallServiceBase",

View File
@@ -13,7 +13,7 @@ from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import ModelRecordServiceBase
-from invokeai.backend.model_manager import AnyModelConfig
+from invokeai.backend.model_manager.config import AnyModelConfig

class ModelInstallServiceBase(ABC):
@@ -243,11 +243,12 @@ class ModelInstallServiceBase(ABC):
        """

    @abstractmethod
-    def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
+    def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path:
        """
        Download the model file located at source to the models cache and return its Path.

-        :param source: A string representing a URL or repo_id.
+        :param source: A Url or a string that can be converted into one.
+        :param access_token: Optional access token to access restricted resources.

        The model file will be downloaded into the system-wide model cache
        (`models/.cache`) if it isn't already there. Note that the model cache

View File
@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated

-from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
+from invokeai.app.services.download import DownloadJob
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
@@ -26,6 +26,13 @@ class InstallStatus(str, Enum):
    CANCELLED = "cancelled"  # terminated with an error message

+class ModelInstallPart(BaseModel):
+    url: AnyHttpUrl
+    path: Path
+    bytes: int = 0
+    total_bytes: int = 0
+
class UnknownInstallJobException(Exception):
    """Raised when the status of an unknown job is requested."""
@@ -162,7 +169,6 @@ class ModelInstallJob(BaseModel):
    )
    # internal flags and transitory settings
    _install_tmpdir: Optional[Path] = PrivateAttr(default=None)
-    _multifile_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None)
    _exception: Optional[Exception] = PrivateAttr(default=None)

    def set_error(self, e: Exception) -> None:

View File
@@ -5,34 +5,23 @@ import os
import re
import threading
import time
+from hashlib import sha256
from pathlib import Path
from queue import Empty, Queue
from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import torch
import yaml
from huggingface_hub import HfFolder
from pydantic.networks import AnyHttpUrl
-from pydantic_core import Url
from requests import Session

from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
-from invokeai.app.services.events.events_base import EventServiceBase
+from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
-from invokeai.app.services.model_install.model_install_common import (
-    MODEL_SOURCE_TO_TYPE_MAP,
-    HFModelSource,
-    InstallStatus,
-    LocalModelSource,
-    ModelInstallJob,
-    ModelSource,
-    StringLikeSource,
-    URLModelSource,
-)
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.model_manager.config import (
@@ -55,10 +44,23 @@ from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.catch_sigint import catch_sigint
from invokeai.backend.util.devices import TorchDevice
-from invokeai.backend.util.util import slugify
+
+from .model_install_common import (
+    MODEL_SOURCE_TO_TYPE_MAP,
+    HFModelSource,
+    InstallStatus,
+    LocalModelSource,
+    ModelInstallJob,
+    ModelSource,
+    StringLikeSource,
+    URLModelSource,
+)

TMPDIR_PREFIX = "tmpinstall_"

+if TYPE_CHECKING:
+    from invokeai.app.services.events.events_base import EventServiceBase
+
class ModelInstallService(ModelInstallServiceBase):
    """class for InvokeAI model installation."""
@@ -89,7 +91,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._downloads_changed_event = threading.Event() self._downloads_changed_event = threading.Event()
self._install_completed_event = threading.Event() self._install_completed_event = threading.Event()
self._download_queue = download_queue self._download_queue = download_queue
self._download_cache: Dict[int, ModelInstallJob] = {} self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
self._running = False self._running = False
self._session = session self._session = session
self._install_thread: Optional[threading.Thread] = None self._install_thread: Optional[threading.Thread] = None
@@ -208,12 +210,33 @@ class ModelInstallService(ModelInstallServiceBase):
access_token: Optional[str] = None,
inplace: Optional[bool] = False,
) -> ModelInstallJob:
-"""Install a model using pattern matching to infer the type of source."""
-source_obj = self._guess_source(source)
-if isinstance(source_obj, LocalModelSource):
-source_obj.inplace = inplace
-elif isinstance(source_obj, HFModelSource) or isinstance(source_obj, URLModelSource):
-source_obj.access_token = access_token
+variants = "|".join(ModelRepoVariant.__members__.values())
+hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
+source_obj: Optional[StringLikeSource] = None
+if Path(source).exists(): # A local file or directory
+source_obj = LocalModelSource(path=Path(source), inplace=inplace)
+elif match := re.match(hf_repoid_re, source):
+source_obj = HFModelSource(
+repo_id=match.group(1),
+variant=match.group(2) if match.group(2) else None, # pass None rather than ''
+subfolder=Path(match.group(3)) if match.group(3) else None,
+access_token=access_token,
+)
+elif re.match(r"^https?://[^/]+", source):
+# Pull the token from config if it exists and matches the URL
+_token = access_token
+if _token is None:
+for pair in self.app_config.remote_api_tokens or []:
+if re.search(pair.url_regex, source):
+_token = pair.token
+break
+source_obj = URLModelSource(
+url=AnyHttpUrl(source),
+access_token=_token,
+)
+else:
+raise ValueError(f"Unsupported model source: '{source}'")
return self.import_model(source_obj, config)
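The three-way heuristic in the new column is easiest to see with concrete inputs. Below is a minimal, self-contained sketch of the same checks: the repo-id regex is copied from the hunk above, while the variant list and the `classify` helper are illustrative stand-ins for `ModelRepoVariant` and the real source classes.

```python
import re
from pathlib import Path

# Illustrative stand-in for "|".join(ModelRepoVariant.__members__.values())
VARIANTS = "|".join(["fp16", "fp32", "onnx"])
HF_REPOID_RE = f"^([^/:]+/[^/:]+)(?::({VARIANTS})?(?::/?([^:]+))?)?$"

def classify(source: str) -> str:
    """Mirrors the check order above: local path, then HF repo id, then URL."""
    if Path(source).exists():  # a local file or directory
        return "LocalModelSource"
    if match := re.match(HF_REPOID_RE, source):
        repo_id, variant, subfolder = match.groups()
        return f"HFModelSource(repo_id={repo_id!r}, variant={variant!r}, subfolder={subfolder!r})"
    if re.match(r"^https?://[^/]+", source):
        return "URLModelSource"
    raise ValueError(f"Unsupported model source: '{source}'")

print(classify("stabilityai/sdxl-turbo:fp16:vae"))
print(classify("https://civitai.com/api/download/models/12345"))
```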
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
@@ -274,9 +297,8 @@ class ModelInstallService(ModelInstallServiceBase):
def cancel_job(self, job: ModelInstallJob) -> None:
"""Cancel the indicated job."""
job.cancel()
-self._logger.warning(f"Cancelling {job.source}")
-if dj := job._multifile_job:
-self._download_queue.cancel_job(dj)
+with self._lock:
+self._cancel_download_parts(job)
def prune_jobs(self) -> None:
"""Prune all completed and errored jobs."""
@@ -324,7 +346,7 @@ class ModelInstallService(ModelInstallServiceBase):
legacy_config_path = stanza.get("config")
if legacy_config_path:
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
-legacy_config_path = self._app_config.root_path / legacy_config_path
+legacy_config_path: Path = self._app_config.root_path / legacy_config_path
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
config["config_path"] = str(legacy_config_path)
@@ -364,95 +386,38 @@ class ModelInstallService(ModelInstallServiceBase):
rmtree(model_path)
self.unregister(key)
-@classmethod
-def _download_cache_path(cls, source: Union[str, AnyHttpUrl], app_config: InvokeAIAppConfig) -> Path:
-escaped_source = slugify(str(source))
-return app_config.download_cache_path / escaped_source
-def download_and_cache_model(
-self,
-source: str | AnyHttpUrl,
-) -> Path:
+def download_and_cache(
+self,
+source: Union[str, AnyHttpUrl],
+access_token: Optional[str] = None,
+timeout: int = 0,
+) -> Path:
"""Download the model file located at source to the models cache and return its Path."""
-model_path = self._download_cache_path(str(source), self._app_config)
+model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32]
+model_path = self._app_config.convert_cache_path / model_hash
-# We expect the cache directory to contain one and only one downloaded file or directory.
+# We expect the cache directory to contain one and only one downloaded file.
# We don't know the file's name in advance, as it is set by the download
# content-disposition header.
if model_path.exists():
-contents: List[Path] = list(model_path.iterdir())
+contents = [x for x in model_path.iterdir() if x.is_file()]
if len(contents) > 0:
return contents[0]
model_path.mkdir(parents=True, exist_ok=True)
-model_source = self._guess_source(str(source))
-remote_files, _ = self._remote_files_from_source(model_source)
-job = self._multifile_download(
-dest=model_path,
-remote_files=remote_files,
-subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
-)
-files_string = "file" if len(remote_files) == 1 else "files"
-self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
-self._download_queue.wait_for_job(job)
+job = self._download_queue.download(
+source=AnyHttpUrl(str(source)),
+dest=model_path,
+access_token=access_token,
+on_progress=TqdmProgress().update,
+)
+self._download_queue.wait_for_job(job, timeout)
if job.complete:
assert job.download_path is not None
return job.download_path
else:
raise Exception(job.error)
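The two columns key the cache directory differently: the old column slugifies the source string, while the new one uses a truncated SHA-256 of it. A small sketch of the hashed variant (only `hashlib` is assumed; the URL is illustrative):

```python
from hashlib import sha256

source = "https://example.com/models/some-model.safetensors"

# 32 hex chars of the SHA-256 digest, exactly as in the hunk above.
model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32]
print(model_hash)  # same source string -> same cache directory on every run
```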
-def _remote_files_from_source(
-self, source: ModelSource
-) -> Tuple[List[RemoteModelFile], Optional[AnyModelRepoMetadata]]:
-metadata = None
-if isinstance(source, HFModelSource):
-metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
-assert isinstance(metadata, ModelMetadataWithFiles)
-return (
-metadata.download_urls(
-variant=source.variant or self._guess_variant(),
-subfolder=source.subfolder,
-session=self._session,
-),
-metadata,
-)
-if isinstance(source, URLModelSource):
-try:
-fetcher = self.get_fetcher_from_url(str(source.url))
-kwargs: dict[str, Any] = {"session": self._session}
-metadata = fetcher(**kwargs).from_url(source.url)
-assert isinstance(metadata, ModelMetadataWithFiles)
-return metadata.download_urls(session=self._session), metadata
-except ValueError:
-pass
-return [RemoteModelFile(url=source.url, path=Path("."), size=0)], None
-raise Exception(f"No files associated with {source}")
-def _guess_source(self, source: str) -> ModelSource:
-"""Turn a source string into a ModelSource object."""
-variants = "|".join(ModelRepoVariant.__members__.values())
-hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
-source_obj: Optional[StringLikeSource] = None
-if Path(source).exists(): # A local file or directory
-source_obj = LocalModelSource(path=Path(source))
-elif match := re.match(hf_repoid_re, source):
-source_obj = HFModelSource(
-repo_id=match.group(1),
-variant=ModelRepoVariant(match.group(2)) if match.group(2) else None, # pass None rather than ''
-subfolder=Path(match.group(3)) if match.group(3) else None,
-)
-elif re.match(r"^https?://[^/]+", source):
-source_obj = URLModelSource(
-url=Url(source),
-)
-else:
-raise ValueError(f"Unsupported model source: '{source}'")
-return source_obj
# --------------------------------------------------------------------------------------------
# Internal functions that manage the installer threads
# --------------------------------------------------------------------------------------------
@@ -513,19 +478,16 @@ class ModelInstallService(ModelInstallServiceBase):
job.config_out = self.record_store.get_model(key)
self._signal_job_completed(job)
-def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None:
-multifile_download_job = install_job._multifile_job
-if multifile_download_job and any(
-x.content_type is not None and "text/html" in x.content_type for x in multifile_download_job.download_parts
-):
-install_job.set_error(
-InvalidModelConfigException(
-f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
-)
-)
-else:
-install_job.set_error(excp)
-self._signal_job_errored(install_job)
+def _set_error(self, job: ModelInstallJob, excp: Exception) -> None:
+if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts):
+job.set_error(
+InvalidModelConfigException(
+f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
+)
+)
+else:
+job.set_error(excp)
+self._signal_job_errored(job)
# --------------------------------------------------------------------------------------------
# Internal functions that manage the models directory
@@ -551,6 +513,7 @@ class ModelInstallService(ModelInstallServiceBase):
This is typically only used during testing with a new DB or when using the memory DB, because those are the
only situations in which we may have orphaned models in the models directory.
"""
installed_model_paths = {
(self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models()
}
@@ -562,13 +525,8 @@ class ModelInstallService(ModelInstallServiceBase):
if resolved_path in installed_model_paths:
return True
# Skip core models entirely - these aren't registered with the model manager.
-for special_directory in [
-self.app_config.models_path / "core",
-self.app_config.convert_cache_dir,
-self.app_config.download_cache_dir,
-]:
-if resolved_path.is_relative_to(special_directory):
-return False
+if str(resolved_path).startswith(str(self.app_config.models_path / "core")):
+return False
try:
model_id = self.register_path(model_path)
self._logger.info(f"Registered {model_path.name} with id {model_id}")
@@ -683,15 +641,20 @@ class ModelInstallService(ModelInstallServiceBase):
inplace=source.inplace or False,
)
-def _import_from_hf(
-self,
-source: HFModelSource,
-config: Optional[Dict[str, Any]] = None,
-) -> ModelInstallJob:
+def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
# Add user's cached access token to HuggingFace requests
-if source.access_token is None:
-source.access_token = HfFolder.get_token()
-remote_files, metadata = self._remote_files_from_source(source)
+source.access_token = source.access_token or HfFolder.get_token()
+if not source.access_token:
+self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
+metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
+assert isinstance(metadata, ModelMetadataWithFiles)
+remote_files = metadata.download_urls(
+variant=source.variant or self._guess_variant(),
+subfolder=source.subfolder,
+session=self._session,
+)
return self._import_remote_model(
source=source,
config=config,
@@ -699,12 +662,22 @@ class ModelInstallService(ModelInstallServiceBase):
metadata=metadata,
)
-def _import_from_url(
-self,
-source: URLModelSource,
-config: Optional[Dict[str, Any]],
-) -> ModelInstallJob:
-remote_files, metadata = self._remote_files_from_source(source)
+def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
+# URLs from HuggingFace will be handled specially
+metadata = None
+fetcher = None
+try:
+fetcher = self.get_fetcher_from_url(str(source.url))
+except ValueError:
+pass
+kwargs: dict[str, Any] = {"session": self._session}
+if fetcher is not None:
+metadata = fetcher(**kwargs).from_url(source.url)
+self._logger.debug(f"metadata={metadata}")
+if metadata and isinstance(metadata, ModelMetadataWithFiles):
+remote_files = metadata.download_urls(session=self._session)
+else:
+remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
return self._import_remote_model(
source=source,
config=config,
@@ -719,9 +692,12 @@ class ModelInstallService(ModelInstallServiceBase):
metadata: Optional[AnyModelRepoMetadata],
config: Optional[Dict[str, Any]],
) -> ModelInstallJob:
+# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
+# Currently the tmpdir isn't automatically removed at exit because it is
+# being held in a daemon thread.
if len(remote_files) == 0:
raise ValueError(f"{source}: No downloadable files found")
-destdir = Path(
+tmpdir = Path(
mkdtemp(
dir=self._app_config.models_path,
prefix=TMPDIR_PREFIX,
@@ -732,28 +708,55 @@ class ModelInstallService(ModelInstallServiceBase):
source=source,
config_in=config or {},
source_metadata=metadata,
-local_path=destdir, # local path may change once the download has started due to content-disposition handling
+local_path=tmpdir, # local path may change once the download has started due to content-disposition handling
bytes=0,
total_bytes=0,
)
-# remember the temporary directory for later removal
-install_job._install_tmpdir = destdir
-install_job.total_bytes = sum((x.size or 0) for x in remote_files)
-multifile_job = self._multifile_download(
-remote_files=remote_files,
-dest=destdir,
-subfolder=source.subfolder if isinstance(source, HFModelSource) else None,
-access_token=source.access_token,
-submit_job=False, # Important! Don't submit the job until we have set our _download_cache dict
-)
-self._download_cache[multifile_job.id] = install_job
-install_job._multifile_job = multifile_job
-files_string = "file" if len(remote_files) == 1 else "files"
-self._logger.info(f"Queueing model install: {source} ({len(remote_files)} {files_string})")
-self._logger.debug(f"remote_files={remote_files}")
-self._download_queue.submit_multifile_download(multifile_job)
+# In the event that there is a subfolder specified in the source,
+# we need to remove it from the destination path in order to avoid
+# creating unwanted subfolders
+if isinstance(source, HFModelSource) and source.subfolder:
+root = Path(remote_files[0].path.parts[0])
+subfolder = root / source.subfolder
+else:
+root = Path(".")
+subfolder = Path(".")
+# we remember the path up to the top of the tmpdir so that it may be
+# removed safely at the end of the install process.
+install_job._install_tmpdir = tmpdir
+assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below
+files_string = "file" if len(remote_files) == 1 else "file"
+self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
+self._logger.debug(f"remote_files={remote_files}")
+for model_file in remote_files:
+url = model_file.url
+path = root / model_file.path.relative_to(subfolder)
+self._logger.debug(f"Downloading {url} => {path}")
+install_job.total_bytes += model_file.size
+assert hasattr(source, "access_token")
+dest = tmpdir / path.parent
+dest.mkdir(parents=True, exist_ok=True)
+download_job = DownloadJob(
+source=url,
+dest=dest,
+access_token=source.access_token,
+)
+self._download_cache[download_job.source] = install_job # matches a download job to an install job
+install_job.download_parts.add(download_job)
+# only start the jobs once install_job.download_parts is fully populated
+for download_job in install_job.download_parts:
+self._download_queue.submit_download_job(
+download_job,
+on_start=self._download_started_callback,
+on_progress=self._download_progress_callback,
+on_complete=self._download_complete_callback,
+on_error=self._download_error_callback,
+on_cancelled=self._download_cancelled_callback,
+)
return install_job
def _stat_size(self, path: Path) -> int:
@@ -765,104 +768,87 @@ class ModelInstallService(ModelInstallServiceBase):
size += sum(self._stat_size(Path(root, x)) for x in files)
return size
-def _multifile_download(
-self,
-remote_files: List[RemoteModelFile],
-dest: Path,
-subfolder: Optional[Path] = None,
-access_token: Optional[str] = None,
-submit_job: bool = True,
-) -> MultiFileDownloadJob:
-# HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and
-# we are installing the "vae" subfolder, we do not want to create an additional folder level, such
-# as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
-# So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
-if subfolder:
-top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/"
-path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/
-path_to_add = Path(f"{top}_{subfolder}")
-else:
-path_to_remove = Path(".")
-path_to_add = Path(".")
-parts: List[RemoteModelFile] = []
-for model_file in remote_files:
-assert model_file.size is not None
-parts.append(
-RemoteModelFile(
-url=model_file.url, # if a subfolder, then sdxl-turbo_vae/config.json
-path=path_to_add / model_file.path.relative_to(path_to_remove),
-)
-)
-return self._download_queue.multifile_download(
-parts=parts,
-dest=dest,
-access_token=access_token,
-submit_job=submit_job,
-on_start=self._download_started_callback,
-on_progress=self._download_progress_callback,
-on_complete=self._download_complete_callback,
-on_error=self._download_error_callback,
-on_cancelled=self._download_cancelled_callback,
-)
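The subfolder handling in `_multifile_download` is easier to follow with a worked example; the names mirror the "sdxl-turbo"/"vae" case from the comment above and are illustrative only.

```python
from pathlib import Path

remote_path = Path("sdxl-turbo/vae/config.json")
subfolder = Path("vae")

top = Path(remote_path.parts[0])            # sdxl-turbo
path_to_remove = top / subfolder.parts[-1]  # sdxl-turbo/vae
path_to_add = Path(f"{top}_{subfolder}")    # sdxl-turbo_vae

# The repo/subfolder prefix is replaced by a single synthesized folder level.
print(path_to_add / remote_path.relative_to(path_to_remove))  # sdxl-turbo_vae/config.json
```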
# ------------------------------------------------------------------
# Callbacks are executed by the download queue in a separate thread
# ------------------------------------------------------------------
-def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None:
-with self._lock:
-if install_job := self._download_cache.get(download_job.id, None):
-install_job.status = InstallStatus.DOWNLOADING
-if install_job.local_path == install_job._install_tmpdir: # first time
-assert download_job.download_path
-install_job.local_path = download_job.download_path
-install_job.download_parts = download_job.download_parts
-install_job.bytes = sum(x.bytes for x in download_job.download_parts)
-install_job.total_bytes = download_job.total_bytes
-self._signal_job_download_started(install_job)
+def _download_started_callback(self, download_job: DownloadJob) -> None:
+self._logger.info(f"Model download started: {download_job.source}")
+with self._lock:
+install_job = self._download_cache[download_job.source]
+install_job.status = InstallStatus.DOWNLOADING
+assert download_job.download_path
+if install_job.local_path == install_job._install_tmpdir:
+partial_path = download_job.download_path.relative_to(install_job._install_tmpdir)
+dest_name = partial_path.parts[0]
+install_job.local_path = install_job._install_tmpdir / dest_name
+# Update the total bytes count for remote sources.
+if not install_job.total_bytes:
+install_job.total_bytes = sum(x.total_bytes for x in install_job.download_parts)
-def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
-with self._lock:
-if install_job := self._download_cache.get(download_job.id, None):
-if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
-self._download_queue.cancel_job(download_job)
-else:
-# update sizes
-install_job.bytes = sum(x.bytes for x in download_job.download_parts)
-install_job.total_bytes = sum(x.total_bytes for x in download_job.download_parts)
-self._signal_job_downloading(install_job)
+def _download_progress_callback(self, download_job: DownloadJob) -> None:
+with self._lock:
+install_job = self._download_cache[download_job.source]
+if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
+self._cancel_download_parts(install_job)
+else:
+# update sizes
+install_job.bytes = sum(x.bytes for x in install_job.download_parts)
+self._signal_job_downloading(install_job)
-def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None:
-with self._lock:
-if install_job := self._download_cache.pop(download_job.id, None):
-self._signal_job_downloads_done(install_job)
-self._put_in_queue(install_job) # this starts the installation and registration
-# Let other threads know that the number of downloads has changed
-self._downloads_changed_event.set()
+def _download_complete_callback(self, download_job: DownloadJob) -> None:
+self._logger.info(f"Model download complete: {download_job.source}")
+with self._lock:
+install_job = self._download_cache[download_job.source]
+# are there any more active jobs left in this task?
+if install_job.downloading and all(x.complete for x in install_job.download_parts):
+self._signal_job_downloads_done(install_job)
+self._put_in_queue(install_job)
+# Let other threads know that the number of downloads has changed
+self._download_cache.pop(download_job.source, None)
+self._downloads_changed_event.set()
-def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
-with self._lock:
-if install_job := self._download_cache.pop(download_job.id, None):
-assert excp is not None
-self._set_error(install_job, excp)
-self._download_queue.cancel_job(download_job)
-# Let other threads know that the number of downloads has changed
-self._downloads_changed_event.set()
+def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
+with self._lock:
+install_job = self._download_cache.pop(download_job.source, None)
+assert install_job is not None
+assert excp is not None
+install_job.set_error(excp)
+self._logger.error(
+f"Cancelling {install_job.source} due to an error while downloading {download_job.source}: {str(excp)}"
+)
+self._cancel_download_parts(install_job)
+# Let other threads know that the number of downloads has changed
+self._downloads_changed_event.set()
-def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None:
-with self._lock:
-if install_job := self._download_cache.pop(download_job.id, None):
-self._downloads_changed_event.set()
-# if install job has already registered an error, then do not replace its status with cancelled
-if not install_job.errored:
-install_job.cancel()
-# Let other threads know that the number of downloads has changed
-self._downloads_changed_event.set()
+def _download_cancelled_callback(self, download_job: DownloadJob) -> None:
+with self._lock:
+install_job = self._download_cache.pop(download_job.source, None)
+if not install_job:
+return
+self._downloads_changed_event.set()
+self._logger.warning(f"Model download canceled: {download_job.source}")
+# if install job has already registered an error, then do not replace its status with cancelled
+if not install_job.errored:
+install_job.cancel()
+self._cancel_download_parts(install_job)
+# Let other threads know that the number of downloads has changed
+self._downloads_changed_event.set()
+def _cancel_download_parts(self, install_job: ModelInstallJob) -> None:
+# on multipart downloads, _cancel_components() will get called repeatedly from the download callbacks
+# do not lock here because it gets called within a locked context
+for s in install_job.download_parts:
+self._download_queue.cancel_job(s)
+if all(x.in_terminal_state for x in install_job.download_parts):
+# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
+self._put_in_queue(install_job)
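The `in_terminal_state` guard at the end of `_cancel_download_parts` is why the temporary directory is only finalized once, no matter how many part callbacks fire. A hedged sketch of that pattern (`DownloadPart` is a stand-in, not the real download job class):

```python
from dataclasses import dataclass

@dataclass
class DownloadPart:
    status: str  # e.g. "running", "completed", "error", "cancelled"

    @property
    def in_terminal_state(self) -> bool:
        return self.status in ("completed", "error", "cancelled")

parts = [DownloadPart("completed"), DownloadPart("cancelled")]
# Only when every part has stopped moving do we hand the job off for cleanup.
if all(p.in_terminal_state for p in parts):
    print("finalize install job: clean up tmpdir")  # mirrors _put_in_queue(install_job)
```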
# ------------------------------------------------------------------------------------------------
# Internal methods that put events on the event bus
@@ -873,18 +859,8 @@ class ModelInstallService(ModelInstallServiceBase):
if self._event_bus:
self._event_bus.emit_model_install_started(job)
-def _signal_job_download_started(self, job: ModelInstallJob) -> None:
-if self._event_bus:
-assert job._multifile_job is not None
-assert job.bytes is not None
-assert job.total_bytes is not None
-self._event_bus.emit_model_install_download_started(job)
def _signal_job_downloading(self, job: ModelInstallJob) -> None:
if self._event_bus:
-assert job._multifile_job is not None
-assert job.bytes is not None
-assert job.total_bytes is not None
self._event_bus.emit_model_install_download_progress(job)
def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
@@ -899,8 +875,6 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info(f"Model install complete: {job.source}")
self._logger.debug(f"{job.local_path} registered key {job.config_out.key}")
if self._event_bus:
-assert job.local_path is not None
-assert job.config_out is not None
self._event_bus.emit_model_install_complete(job)
def _signal_job_errored(self, job: ModelInstallJob) -> None:
@@ -916,13 +890,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._event_bus.emit_model_install_cancelled(job)
@staticmethod
-def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]:
-"""
-Return a metadata fetcher appropriate for provided url.
-This used to be more useful, but the number of supported model
-sources has been reduced to HuggingFace alone.
-"""
+def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
return HuggingFaceMetadataFetch
raise ValueError(f"Unsupported model source: '{url}'")


@@ -1,6 +1,6 @@
"""Initialization file for model load service module.""" """Initialization file for model load service module."""
from invokeai.app.services.model_load.model_load_base import ModelLoadServiceBase from .model_load_base import ModelLoadServiceBase
from invokeai.app.services.model_load.model_load_default import ModelLoadService from .model_load_default import ModelLoadService
__all__ = ["ModelLoadServiceBase", "ModelLoadService"] __all__ = ["ModelLoadServiceBase", "ModelLoadService"]


@@ -2,11 +2,11 @@
"""Base class for model loader.""" """Base class for model loader."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from typing import Optional
from typing import Callable, Optional
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig from invokeai.backend.model_manager.load import LoadedModel
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
@@ -27,25 +27,7 @@ class ModelLoadServiceBase(ABC):
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the RAM cache used by this loader."""
+@property
@abstractmethod
-def load_model_from_path(
-self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
-) -> LoadedModelWithoutConfig:
-"""
-Load the model file or directory located at the indicated Path.
-This will load an arbitrary model file into the RAM cache. If the optional loader
-argument is provided, the loader will be invoked to load the model into
-memory. Otherwise the method will call safetensors.torch.load_file() or
-torch.load() as appropriate to the file suffix.
-Be aware that this returns a LoadedModelWithoutConfig object, which is the same as
-LoadedModel, but without the config attribute.
-Args:
-model_path: A pathlib.Path to a checkpoint-style models file
-loader: A Callable that expects a Path and returns a Dict[str, Tensor]
-Returns:
-A LoadedModel object.
-"""
+def convert_cache(self) -> ModelConvertCacheBase:
+"""Return the checkpoint convert cache used by this loader."""


@@ -1,28 +1,22 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
"""Implementation of model loader service."""
-from pathlib import Path
-from typing import Callable, Optional, Type
+from typing import Optional, Type
-from picklescan.scanner import scan_file_path
-from safetensors.torch import load_file as safetensors_load_file
-from torch import load as torch_load
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker
-from invokeai.app.services.model_load.model_load_base import ModelLoadServiceBase
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import (
LoadedModel,
-LoadedModelWithoutConfig,
ModelLoaderRegistry,
ModelLoaderRegistryBase,
)
+from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
-from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
-from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
+from .model_load_base import ModelLoadServiceBase
class ModelLoadService(ModelLoadServiceBase):
"""Wrapper around ModelLoaderRegistry."""
@@ -31,6 +25,7 @@ class ModelLoadService(ModelLoadServiceBase):
self,
app_config: InvokeAIAppConfig,
ram_cache: ModelCacheBase[AnyModel],
+convert_cache: ModelConvertCacheBase,
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
):
"""Initialize the model load service."""
@@ -39,6 +34,7 @@ class ModelLoadService(ModelLoadServiceBase):
self._logger = logger
self._app_config = app_config
self._ram_cache = ram_cache
+self._convert_cache = convert_cache
self._registry = registry
def start(self, invoker: Invoker) -> None:
@@ -49,6 +45,11 @@ class ModelLoadService(ModelLoadServiceBase):
"""Return the RAM cache used by this loader.""" """Return the RAM cache used by this loader."""
return self._ram_cache return self._ram_cache
@property
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""
return self._convert_cache
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
""" """
Given a model's configuration, load it and return the LoadedModel object. Given a model's configuration, load it and return the LoadedModel object.
@@ -67,47 +68,10 @@ class ModelLoadService(ModelLoadServiceBase):
app_config=self._app_config,
logger=self._logger,
ram_cache=self._ram_cache,
+convert_cache=self._convert_cache,
).load_model(model_config, submodel_type)
if hasattr(self, "_invoker"):
self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)
return loaded_model
-def load_model_from_path(
-self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
-) -> LoadedModelWithoutConfig:
-cache_key = str(model_path)
-ram_cache = self.ram_cache
-try:
-return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
-except IndexError:
-pass
-def torch_load_file(checkpoint: Path) -> AnyModel:
-scan_result = scan_file_path(checkpoint)
-if scan_result.infected_files != 0:
-raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
-result = torch_load(checkpoint, map_location="cpu")
-return result
-def diffusers_load_directory(directory: Path) -> AnyModel:
-load_class = GenericDiffusersLoader(
-app_config=self._app_config,
-logger=self._logger,
-ram_cache=self._ram_cache,
-convert_cache=self.convert_cache,
-).get_hf_load_class(directory)
-return load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())
-loader = loader or (
-diffusers_load_directory
-if model_path.is_dir()
-else torch_load_file
-if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
-else lambda path: safetensors_load_file(path, device="cpu")
-)
-assert loader is not None
-raw_model = loader(model_path)
-ram_cache.put(key=cache_key, model=raw_model)
-return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
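The loader-selection ternary removed above is dense; rewritten as a plain chain (the returned labels are illustrative, not the real loader callables), the dispatch order is: directory first, then pickled-checkpoint suffixes, then safetensors as the fallback.

```python
from pathlib import Path

def pick_loader_kind(model_path: Path) -> str:
    """Mirrors the fallback chain in load_model_from_path."""
    if model_path.is_dir():
        return "diffusers directory loader"
    if model_path.suffix in (".ckpt", ".pt", ".pth", ".bin"):
        return "malware-scanned torch.load"
    return "safetensors loader"

print(pick_loader_kind(Path("model.safetensors")))  # safetensors loader
print(pick_loader_kind(Path("model.ckpt")))         # malware-scanned torch.load
```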


@@ -1,9 +1,10 @@
"""Initialization file for model manager service.""" """Initialization file for model manager service."""
from invokeai.app.services.model_manager.model_manager_default import ModelManagerService, ModelManagerServiceBase
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load import LoadedModel from invokeai.backend.model_manager.load import LoadedModel
from .model_manager_default import ModelManagerService, ModelManagerServiceBase
__all__ = [ __all__ = [
"ModelManagerServiceBase", "ModelManagerServiceBase",
"ModelManagerService", "ModelManagerService",


@@ -5,13 +5,14 @@ from abc import ABC, abstractmethod
import torch
from typing_extensions import Self
-from invokeai.app.services.config.config_default import InvokeAIAppConfig
-from invokeai.app.services.download.download_base import DownloadQueueServiceBase
-from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
-from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
-from invokeai.app.services.model_load.model_load_base import ModelLoadServiceBase
-from invokeai.app.services.model_records.model_records_base import ModelRecordServiceBase
+from ..config import InvokeAIAppConfig
+from ..download import DownloadQueueServiceBase
+from ..events.events_base import EventServiceBase
+from ..model_install import ModelInstallServiceBase
+from ..model_load import ModelLoadServiceBase
+from ..model_records import ModelRecordServiceBase
class ModelManagerServiceBase(ABC):


@@ -6,20 +6,19 @@ from typing import Optional
import torch
from typing_extensions import Self
-from invokeai.app.services.config.config_default import InvokeAIAppConfig
-from invokeai.app.services.download.download_base import DownloadQueueServiceBase
-from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
-from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
-from invokeai.app.services.model_install.model_install_default import ModelInstallService
-from invokeai.app.services.model_load.model_load_base import ModelLoadServiceBase
-from invokeai.app.services.model_load.model_load_default import ModelLoadService
-from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
-from invokeai.app.services.model_records.model_records_base import ModelRecordServiceBase
-from invokeai.backend.model_manager.load import ModelCache, ModelLoaderRegistry
+from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
+from ..config import InvokeAIAppConfig
+from ..download import DownloadQueueServiceBase
+from ..events.events_base import EventServiceBase
+from ..model_install import ModelInstallService, ModelInstallServiceBase
+from ..model_load import ModelLoadService, ModelLoadServiceBase
+from ..model_records import ModelRecordServiceBase
+from .model_manager_base import ModelManagerServiceBase
class ModelManagerService(ModelManagerServiceBase):
"""
@@ -87,9 +86,11 @@ class ModelManagerService(ModelManagerServiceBase):
logger=logger,
execution_device=execution_device or TorchDevice.choose_torch_device(),
)
+convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
loader = ModelLoadService(
app_config=app_config,
ram_cache=ram_cache,
+convert_cache=convert_cache,
registry=ModelLoaderRegistry,
)
installer = ModelInstallService(


@@ -12,13 +12,15 @@ from pydantic import BaseModel, Field
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
-from invokeai.backend.model_manager.config import (
+from invokeai.backend.model_manager import (
AnyModelConfig,
BaseModelType,
-ControlAdapterDefaultSettings,
-MainModelDefaultSettings,
ModelFormat,
ModelType,
+)
+from invokeai.backend.model_manager.config import (
+ControlAdapterDefaultSettings,
+MainModelDefaultSettings,
ModelVariantType,
SchedulerPredictionType,
)


@@ -40,24 +40,12 @@ Typical usage:
""" """
import json import json
import logging
import sqlite3 import sqlite3
from math import ceil from math import ceil
from pathlib import Path from pathlib import Path
from typing import List, Optional, Union from typing import List, Optional, Union
import pydantic
from invokeai.app.services.model_records.model_records_base import (
DuplicateModelException,
ModelRecordChanges,
ModelRecordOrderBy,
ModelRecordServiceBase,
ModelSummary,
UnknownModelException,
)
from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
AnyModelConfig, AnyModelConfig,
BaseModelType, BaseModelType,
@@ -66,11 +54,21 @@ from invokeai.backend.model_manager.config import (
ModelType,
)
+from ..shared.sqlite.sqlite_database import SqliteDatabase
+from .model_records_base import (
+DuplicateModelException,
+ModelRecordChanges,
+ModelRecordOrderBy,
+ModelRecordServiceBase,
+ModelSummary,
+UnknownModelException,
+)
class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database."""
-def __init__(self, db: SqliteDatabase, logger: logging.Logger):
+def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
@@ -79,7 +77,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
super().__init__()
self._db = db
self._cursor = db.conn.cursor()
-self._logger = logger
@property
def db(self) -> SqliteDatabase:
@@ -295,20 +292,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
tuple(bindings),
)
result = self._cursor.fetchall()
-# Parse the model configs.
-results: list[AnyModelConfig] = []
-for row in result:
-try:
-model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1])
-except pydantic.ValidationError:
-# We catch this error so that the app can still run if there are invalid model configs in the database.
-# One reason that an invalid model config might be in the database is if someone had to rollback from a
-# newer version of the app that added a new model type.
-self._logger.warning(f"Found an invalid model config in the database. Ignoring this model. ({row[0]})")
-else:
-results.append(model_config)
+results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in result]
return results
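The left column's tolerant loop is the interesting part of this hunk: a single invalid row no longer aborts the whole query. A self-contained sketch of the same pattern (the `Config` model and the rows are made up for illustration; pydantic v2 is assumed):

```python
import json
import pydantic

class Config(pydantic.BaseModel):
    name: str

# (json_blob, timestamp) rows as they might come back from the cursor.
rows = [('{"name": "good"}', 1), ('{"wrong_field": 1}', 2)]

results: list[Config] = []
for row in rows:
    try:
        results.append(Config.model_validate(json.loads(row[0])))
    except pydantic.ValidationError:
        print(f"Skipping invalid config: {row[0]}")  # logged, not fatal

print([c.name for c in results])  # ['good']
```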
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:


@@ -1,6 +1,7 @@
-from invokeai.app.services.names.names_base import NameServiceBase
from invokeai.app.util.misc import uuid_string
+from .names_base import NameServiceBase
class SimpleNameService(NameServiceBase):
"""Creates image names from UUIDs."""


@@ -13,24 +13,24 @@ from invokeai.app.services.events.events_common import (
register_events,
)
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
-from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_processor.session_processor_base import (
-InvocationServices,
OnAfterRunNode,
OnAfterRunSession,
OnBeforeRunNode,
OnBeforeRunSession,
OnNodeError,
OnNonFatalProcessorError,
-SessionProcessorBase,
-SessionRunnerBase,
)
-from invokeai.app.services.session_processor.session_processor_common import CanceledException, SessionProcessorStatus
+from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem, SessionQueueItemNotFoundError
from invokeai.app.services.shared.graph import NodeInputError
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
from invokeai.app.util.profiler import Profiler
+from ..invoker import Invoker
+from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase
+from .session_processor_common import SessionProcessorStatus
class DefaultSessionRunner(SessionRunnerBase):
"""Processes a single session's invocations."""


@@ -37,14 +37,10 @@ class SqliteSessionQueue(SessionQueueBase):
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
self._set_in_progress_to_canceled()
-if self.__invoker.services.configuration.clear_queue_on_startup:
-clear_result = self.clear(DEFAULT_QUEUE_ID)
-if clear_result.deleted > 0:
-self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items")
-else:
-prune_result = self.prune(DEFAULT_QUEUE_ID)
-if prune_result.deleted > 0:
-self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
+prune_result = self.prune(DEFAULT_QUEUE_ID)
+if prune_result.deleted > 0:
+self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()


@@ -2,19 +2,18 @@
import copy
import itertools
-from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
+from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
import networkx as nx
from pydantic import (
BaseModel,
-GetCoreSchemaHandler,
GetJsonSchemaHandler,
ValidationError,
field_validator,
)
from pydantic.fields import Field
from pydantic.json_schema import JsonSchemaValue
-from pydantic_core import core_schema
+from pydantic_core import CoreSchema
# Importing * is bad karma but needed here for node detection
from invokeai.app.invocations import * # noqa: F401 F403
@@ -278,58 +277,73 @@ class CollectInvocation(BaseInvocation):
return CollectInvocationOutput(collection=copy.copy(self.collection))
-class AnyInvocation(BaseInvocation):
-@classmethod
-def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
-def validate_invocation(v: Any) -> "AnyInvocation":
-return BaseInvocation.get_typeadapter().validate_python(v)
-return core_schema.no_info_plain_validator_function(validate_invocation)
-@classmethod
-def __get_pydantic_json_schema__(
-cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
-) -> JsonSchemaValue:
-# Nodes are too powerful, we have to make our own OpenAPI schema manually
-# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
-oneOf: list[dict[str, str]] = []
-names = [i.__name__ for i in BaseInvocation.get_invocations()]
-for name in sorted(names):
-oneOf.append({"$ref": f"#/components/schemas/{name}"})
-return {"oneOf": oneOf}
-class AnyInvocationOutput(BaseInvocationOutput):
-@classmethod
-def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
-def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
-return BaseInvocationOutput.get_typeadapter().validate_python(v)
-return core_schema.no_info_plain_validator_function(validate_invocation_output)
-@classmethod
-def __get_pydantic_json_schema__(
-cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
-) -> JsonSchemaValue:
-# Nodes are too powerful, we have to make our own OpenAPI schema manually
-# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
-oneOf: list[dict[str, str]] = []
-names = [i.__name__ for i in BaseInvocationOutput.get_outputs()]
-for name in sorted(names):
-oneOf.append({"$ref": f"#/components/schemas/{name}"})
-return {"oneOf": oneOf}
class Graph(BaseModel):
id: str = Field(description="The id of this graph", default_factory=uuid_string)
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
-nodes: dict[str, AnyInvocation] = Field(description="The nodes in this graph", default_factory=dict)
+nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict)
edges: list[Edge] = Field(
description="The connections between nodes and their fields in this graph",
default_factory=list,
)
@field_validator("nodes", mode="plain")
@classmethod
def validate_nodes(cls, v: dict[str, Any]):
"""Validates the nodes in the graph by retrieving a union of all node types and validating each node."""
# Invocations register themselves as their python modules are executed. The union of all invocations is
# constructed at runtime. We use pydantic to validate `Graph.nodes` using that union.
#
# It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If
# we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing
# invocations will cause a graph to fail if they are used.
#
# We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the
# pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime.
#
# This same pattern is used in `GraphExecutionState`.
nodes: dict[str, BaseInvocation] = {}
typeadapter = BaseInvocation.get_typeadapter()
for node_id, node in v.items():
nodes[node_id] = typeadapter.validate_python(node)
return nodes
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
# We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for
# fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to
# the generated schema as options for the `nodes` field.
#
# The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and
# with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as
# expected.
#
# You might be tempted to do something like this:
#
# ```py
# cloned_model = create_model(cls.__name__, __base__=cls, nodes=...)
# delattr(cloned_model, "validate_nodes")
# cloned_model.model_rebuild(force=True)
# json_schema = handler(cloned_model.__pydantic_core_schema__)
# ```
#
# Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts
# to build the JSON Schema for the cloned model. Instead, we have to manually clone the model.
#
# This same pattern is used in `GraphExecutionState`.
class Graph(BaseModel):
id: Optional[str] = Field(default=None, description="The id of this graph")
nodes: dict[
str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")]
] = Field(description="The nodes in this graph")
edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph")
json_schema = handler(Graph.__pydantic_core_schema__)
json_schema = handler.resolve_ref_schema(json_schema)
return json_schema
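The "plain" validator trick documented in this hunk is worth a standalone illustration: validation is deferred to a runtime-built `TypeAdapter` instead of the schema pydantic compiled at import time. A minimal sketch under pydantic v2, with made-up `A`/`B` models standing in for the invocation classes:

```python
from typing import Annotated, Any, Literal, Union
from pydantic import BaseModel, Field, TypeAdapter, field_validator

class A(BaseModel):
    type: Literal["a"] = "a"

class B(BaseModel):
    type: Literal["b"] = "b"

def get_typeadapter() -> TypeAdapter:
    # In the real code this union grows as invocation modules register themselves.
    return TypeAdapter(Annotated[Union[A, B], Field(discriminator="type")])

class Container(BaseModel):
    nodes: dict[str, BaseModel]

    @field_validator("nodes", mode="plain")
    @classmethod
    def validate_nodes(cls, v: dict[str, Any]):
        # mode="plain" replaces pydantic's own validation entirely,
        # so the union is resolved at call time, not import time.
        adapter = get_typeadapter()
        return {k: adapter.validate_python(node) for k, node in v.items()}

print(Container(nodes={"n1": {"type": "b"}}).nodes["n1"])  # a B instance
```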
def add_node(self, node: BaseInvocation) -> None:
"""Adds a node to a graph
@@ -652,7 +666,7 @@ class Graph(BaseModel):
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
# Input type must be a list
-if get_origin(input_field) is not list:
+if get_origin(input_field) != list:
return False
# Validate that all outputs match the input type
@@ -760,7 +774,7 @@ class GraphExecutionState(BaseModel):
)
# The results of executed nodes
-results: dict[str, AnyInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
+results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
# Errors raised when executing nodes
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
@@ -777,12 +791,52 @@ class GraphExecutionState(BaseModel):
default_factory=dict, default_factory=dict,
) )
@field_validator("results", mode="plain")
@classmethod
def validate_results(cls, v: dict[str, BaseInvocationOutput]):
"""Validates the results in the GES by retrieving a union of all output types and validating each result."""
# See the comment in `Graph.validate_nodes` for an explanation of this logic.
results: dict[str, BaseInvocationOutput] = {}
typeadapter = BaseInvocationOutput.get_typeadapter()
for result_id, result in v.items():
results[result_id] = typeadapter.validate_python(result)
return results
@field_validator("graph") @field_validator("graph")
def graph_is_valid(cls, v: Graph): def graph_is_valid(cls, v: Graph):
"""Validates that the graph is valid""" """Validates that the graph is valid"""
v.validate_self() v.validate_self()
return v return v
+    @classmethod
+    def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
+        # See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic.
+        class GraphExecutionState(BaseModel):
+            """Tracks the state of a graph execution"""
+
+            id: str = Field(description="The id of the execution state")
+            graph: Graph = Field(description="The graph being executed")
+            execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes")
+            executed: set[str] = Field(description="The set of node ids that have been executed")
+            executed_history: list[str] = Field(
+                description="The list of node ids that have been executed, in order of execution"
+            )
+            results: dict[
+                str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")]
+            ] = Field(description="The results of node executions")
+            errors: dict[str, str] = Field(description="Errors raised when executing nodes")
+            prepared_source_mapping: dict[str, str] = Field(
+                description="The map of prepared nodes to original graph nodes"
+            )
+            source_prepared_mapping: dict[str, set[str]] = Field(
+                description="The map of original graph nodes to prepared nodes"
+            )
+
+        json_schema = handler(GraphExecutionState.__pydantic_core_schema__)
+        json_schema = handler.resolve_ref_schema(json_schema)
+        return json_schema
     def next(self) -> Optional[BaseInvocation]:
         """Gets the next node ready to execute."""

View File

@@ -3,7 +3,6 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Callable, Optional, Union

 from PIL.Image import Image
-from pydantic.networks import AnyHttpUrl
 from torch import Tensor

 from invokeai.app.invocations.constants import IMAGE_MODES
@@ -15,15 +14,8 @@ from invokeai.app.services.images.images_common import ImageDTO
 from invokeai.app.services.invocation_services import InvocationServices
 from invokeai.app.services.model_records.model_records_base import UnknownModelException
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
-from invokeai.backend.model_manager.config import (
-    AnyModel,
-    AnyModelConfig,
-    BaseModelType,
-    ModelFormat,
-    ModelType,
-    SubModelType,
-)
-from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
+from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
+from invokeai.backend.model_manager.load.load_base import LoadedModel
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
@@ -328,10 +320,8 @@ class ConditioningInterface(InvocationContextInterface):
 class ModelsInterface(InvocationContextInterface):
-    """Common API for loading, downloading and managing models."""
-
     def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
-        """Check if a model exists.
+        """Checks if a model exists.

         Args:
             identifier: The key or ModelField representing the model.
@@ -341,13 +331,13 @@ class ModelsInterface(InvocationContextInterface):
         """
         if isinstance(identifier, str):
             return self._services.model_manager.store.exists(identifier)
-        else:
-            return self._services.model_manager.store.exists(identifier.key)
+        return self._services.model_manager.store.exists(identifier.key)

     def load(
         self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
     ) -> LoadedModel:
-        """Load a model.
+        """Loads a model.

         Args:
             identifier: The key or ModelField representing the model.
@@ -371,7 +361,7 @@ class ModelsInterface(InvocationContextInterface):
     def load_by_attrs(
         self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
     ) -> LoadedModel:
-        """Load a model by its attributes.
+        """Loads a model by its attributes.

         Args:
             name: Name of the model.
@@ -394,7 +384,7 @@ class ModelsInterface(InvocationContextInterface):
         return self._services.model_manager.load.load_model(configs[0], submodel_type)

     def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
-        """Get a model's config.
+        """Gets a model's config.

         Args:
             identifier: The key or ModelField representing the model.
@@ -404,11 +394,11 @@ class ModelsInterface(InvocationContextInterface):
         """
         if isinstance(identifier, str):
             return self._services.model_manager.store.get_model(identifier)
-        else:
-            return self._services.model_manager.store.get_model(identifier.key)
+        return self._services.model_manager.store.get_model(identifier.key)

     def search_by_path(self, path: Path) -> list[AnyModelConfig]:
-        """Search for models by path.
+        """Searches for models by path.

         Args:
             path: The path to search for.
@@ -425,7 +415,7 @@ class ModelsInterface(InvocationContextInterface):
         type: Optional[ModelType] = None,
         format: Optional[ModelFormat] = None,
     ) -> list[AnyModelConfig]:
-        """Search for models by attributes.
+        """Searches for models by attributes.

         Args:
             name: The name to search for (exact match).
@@ -444,72 +434,6 @@ class ModelsInterface(InvocationContextInterface):
             model_format=format,
         )

-    def download_and_cache_model(
-        self,
-        source: str | AnyHttpUrl,
-    ) -> Path:
-        """
-        Download the model file located at source to the models cache and return its Path.
-
-        This can be used to install single-file models and other resources of arbitrary types
-        which should not get registered with the database. If the model is already
-        installed, the cached path will be returned. Otherwise it will be downloaded.
-
-        Args:
-            source: A URL that points to the model, or a huggingface repo_id.
-
-        Returns:
-            Path to the downloaded model
-        """
-        return self._services.model_manager.install.download_and_cache_model(source=source)
-
-    def load_local_model(
-        self,
-        model_path: Path,
-        loader: Optional[Callable[[Path], AnyModel]] = None,
-    ) -> LoadedModelWithoutConfig:
-        """
-        Load the model file located at the indicated path.
-
-        If a loader callable is provided, it will be invoked to load the model. Otherwise,
-        `safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
-
-        Be aware that the LoadedModelWithoutConfig object has no `config` attribute.
-
-        Args:
-            model_path: A model Path
-            loader: A Callable that expects a Path and returns a dict[str|int, Any]
-
-        Returns:
-            A LoadedModelWithoutConfig object.
-        """
-        return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
-
-    def load_remote_model(
-        self,
-        source: str | AnyHttpUrl,
-        loader: Optional[Callable[[Path], AnyModel]] = None,
-    ) -> LoadedModelWithoutConfig:
-        """
-        Download, cache, and load the model file located at the indicated URL or repo_id.
-
-        If the model is already downloaded, it will be loaded from the cache.
-
-        If a loader callable is provided, it will be invoked to load the model. Otherwise,
-        `safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
-
-        Be aware that the LoadedModelWithoutConfig object has no `config` attribute.
-
-        Args:
-            source: A URL or huggingface repo_id.
-            loader: A Callable that expects a Path and returns a dict[str|int, Any]
-
-        Returns:
-            A LoadedModelWithoutConfig object.
-        """
-        model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
-        return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
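Before their removal here, these helpers let an invocation pull down and load loose model files without registering anything in the model database. A hypothetical usage sketch, grounded only in the docstrings above, assuming an invocation `context` exposing this interface as `context.models` (the URL and loader are placeholders):

```python
from pathlib import Path

import torch

def invoke(self, context) -> None:  # hypothetical invocation method
    # Download to the models cache (or reuse an earlier download); nothing is
    # registered in the database.
    path: Path = context.models.download_and_cache_model(
        source="https://example.com/checkpoints/upscaler.safetensors"  # placeholder URL
    )

    # Load with the default loader (safetensors.torch.load_file() / torch.load()).
    loaded = context.models.load_local_model(model_path=path)

    # Or download, cache, and load in one step with a custom loader callable.
    loaded = context.models.load_remote_model(
        source="https://example.com/checkpoints/upscaler.safetensors",  # placeholder
        loader=lambda p: torch.load(p, map_location="cpu"),
    )
    # Per the docstrings, the returned LoadedModelWithoutConfig has no `config` attribute.
```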
 class ConfigInterface(InvocationContextInterface):
     def get(self) -> InvokeAIAppConfig:

View File

@@ -13,9 +13,6 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import
 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8
 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9
 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10
-from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
-from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
-from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -46,9 +43,6 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
     migrator.register_migration(build_migration_8(app_config=config))
     migrator.register_migration(build_migration_9())
     migrator.register_migration(build_migration_10())
-    migrator.register_migration(build_migration_11(app_config=config, logger=logger))
-    migrator.register_migration(build_migration_12(app_config=config))
-    migrator.register_migration(build_migration_13())
     migrator.run_migrations()

     return db

View File

@@ -1,75 +0,0 @@
-import shutil
-import sqlite3
-from logging import Logger
-
-from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
-
-LEGACY_CORE_MODELS = [
-    # OpenPose
-    "any/annotators/dwpose/yolox_l.onnx",
-    "any/annotators/dwpose/dw-ll_ucoco_384.onnx",
-    # DepthAnything
-    "any/annotators/depth_anything/depth_anything_vitl14.pth",
-    "any/annotators/depth_anything/depth_anything_vitb14.pth",
-    "any/annotators/depth_anything/depth_anything_vits14.pth",
-    # Lama inpaint
-    "core/misc/lama/lama.pt",
-    # RealESRGAN upscale
-    "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
-    "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
-    "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
-    "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
-]
-
-class Migration11Callback:
-    def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
-        self._app_config = app_config
-        self._logger = logger
-
-    def __call__(self, cursor: sqlite3.Cursor) -> None:
-        self._remove_convert_cache()
-        self._remove_downloaded_models()
-        self._remove_unused_core_models()
-
-    def _remove_convert_cache(self) -> None:
-        """Rename models/.cache to models/.convert_cache."""
-        self._logger.info("Removing .cache directory. Converted models will now be cached in .convert_cache.")
-        legacy_convert_path = self._app_config.root_path / "models" / ".cache"
-        shutil.rmtree(legacy_convert_path, ignore_errors=True)
-
-    def _remove_downloaded_models(self) -> None:
-        """Remove models from their old locations; they will re-download when needed."""
-        self._logger.info(
-            "Removing legacy just-in-time models. Downloaded models will now be cached in .download_cache."
-        )
-        for model_path in LEGACY_CORE_MODELS:
-            legacy_dest_path = self._app_config.models_path / model_path
-            legacy_dest_path.unlink(missing_ok=True)
-
-    def _remove_unused_core_models(self) -> None:
-        """Remove unused core models and their directories."""
-        self._logger.info("Removing defunct core models.")
-        for dir in ["face_restoration", "misc", "upscaling"]:
-            path_to_remove = self._app_config.models_path / "core" / dir
-            shutil.rmtree(path_to_remove, ignore_errors=True)
-        shutil.rmtree(self._app_config.models_path / "any" / "annotators", ignore_errors=True)
-
-def build_migration_11(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
-    """
-    Build the migration from database version 10 to 11.
-
-    This migration does the following:
-    - Moves "core" models previously downloaded with download_with_progress_bar() into new
-      "models/.download_cache" directory.
-    - Renames "models/.cache" to "models/.convert_cache".
-    """
-    migration_11 = Migration(
-        from_version=10,
-        to_version=11,
-        callback=Migration11Callback(app_config=app_config, logger=logger),
-    )
-    return migration_11
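For reference, migrations follow the shape of the removed file above: a callable callback that receives a `sqlite3.Cursor`, wrapped in a `Migration` with explicit version bounds and registered in `init_db`. A hedged sketch with a hypothetical migration 14 (the version numbers and table are illustrative only):

```python
import sqlite3

from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration

class Migration14Callback:
    """Hypothetical callback; the migrator invokes it with a sqlite3.Cursor."""

    def __call__(self, cursor: sqlite3.Cursor) -> None:
        cursor.execute("CREATE TABLE IF NOT EXISTS example (id TEXT PRIMARY KEY)")

def build_migration_14() -> Migration:
    # from_version must match the schema version the database is currently at.
    return Migration(from_version=13, to_version=14, callback=Migration14Callback())

# Registered alongside the others in init_db:
#   migrator.register_migration(build_migration_14())
#   migrator.run_migrations()
```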

Some files were not shown because too many files have changed in this diff.