Compare commits


4 Commits

Author           SHA1        Message                                              Date
psychedelicious  e04c25eba7  Merge branch 'main' into fix/diffusers-embeddings    2023-08-01 22:14:43 +10:00
Lincoln Stein    be61ffdbf6  Merge branch 'main' into fix/diffusers-embeddings    2023-08-01 00:21:50 -04:00
Lincoln Stein    823b879329  Merge branch 'main' into fix/diffusers-embeddings    2023-07-31 21:04:08 -04:00
Lincoln Stein    17c901aaf7  fix diffusers-style textual embeddings               2023-07-31 21:00:12 -04:00
                             - also fix a couple places where the wrong base
                               was used for relative model paths
212 changed files with 4230 additions and 6313 deletions

View File

@@ -1,14 +1,13 @@
-name: style checks
-# just formatting for now
-# TODO: add isort and flake8 later
+name: Black # TODO: add isort and flake8 later

 on:
-  pull_request:
+  pull_request: {}
   push:
-    branches: main
-    tags: "*"
+    branches: master

 jobs:
-  black:
+  test:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v3
@@ -20,7 +19,8 @@ jobs:
       - name: Install dependencies with pip
         run: |
-          pip install black
+          pip install --upgrade pip wheel
+          pip install .[test]
       # - run: isort --check-only .
       - run: black --check .

View File

@@ -0,0 +1,50 @@
+name: Test invoke.py pip
+# This is a dummy stand-in for the actual tests
+# we don't need to run python tests on non-Python changes
+# But PRs require passing tests to be mergeable
+
+on:
+  pull_request:
+    paths:
+      - '**'
+      - '!pyproject.toml'
+      - '!invokeai/**'
+      - '!tests/**'
+      - 'invokeai/frontend/web/**'
+  merge_group:
+  workflow_dispatch:
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true
+
+jobs:
+  matrix:
+    if: github.event.pull_request.draft == false
+    strategy:
+      matrix:
+        python-version:
+          - '3.10'
+        pytorch:
+          - linux-cuda-11_7
+          - linux-rocm-5_2
+          - linux-cpu
+          - macos-default
+          - windows-cpu
+        include:
+          - pytorch: linux-cuda-11_7
+            os: ubuntu-22.04
+          - pytorch: linux-rocm-5_2
+            os: ubuntu-22.04
+          - pytorch: linux-cpu
+            os: ubuntu-22.04
+          - pytorch: macos-default
+            os: macOS-12
+          - pytorch: windows-cpu
+            os: windows-2022
+    name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
+    runs-on: ${{ matrix.os }}
+    steps:
+      - name: skip
+        run: echo "no build required"

View File

@@ -3,7 +3,16 @@ on:
   push:
     branches:
       - 'main'
+    paths:
+      - 'pyproject.toml'
+      - 'invokeai/**'
+      - '!invokeai/frontend/web/**'
   pull_request:
+    paths:
+      - 'pyproject.toml'
+      - 'invokeai/**'
+      - 'tests/**'
+      - '!invokeai/frontend/web/**'
     types:
       - 'ready_for_review'
       - 'opened'
@@ -56,23 +65,10 @@ jobs:
         id: checkout-sources
         uses: actions/checkout@v3

-      - name: Check for changed python files
-        id: changed-files
-        uses: tj-actions/changed-files@v37
-        with:
-          files_yaml: |
-            python:
-              - 'pyproject.toml'
-              - 'invokeai/**'
-              - '!invokeai/frontend/web/**'
-              - 'tests/**'
-
       - name: set test prompt to main branch validation
-        if: steps.changed-files.outputs.python_any_changed == 'true'
         run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}

       - name: setup python
-        if: steps.changed-files.outputs.python_any_changed == 'true'
         uses: actions/setup-python@v4
         with:
           python-version: ${{ matrix.python-version }}
@@ -80,7 +76,6 @@ jobs:
           cache-dependency-path: pyproject.toml

       - name: install invokeai
-        if: steps.changed-files.outputs.python_any_changed == 'true'
         env:
           PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
         run: >
@@ -88,7 +83,6 @@ jobs:
           --editable=".[test]"

       - name: run pytest
-        if: steps.changed-files.outputs.python_any_changed == 'true'
         id: run-pytest
         run: pytest

View File

@@ -161,7 +161,7 @@ the command `npm install -g yarn` if needed)
    _For Windows/Linux with an NVIDIA GPU:_

    ```terminal
-   pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
+   pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
    ```

    _For Linux with an AMD GPU:_
@@ -184,9 +184,8 @@ the command `npm install -g yarn` if needed)
 6. Configure InvokeAI and install a starting set of image generation models (you only need to do this once):

    ```terminal
-   invokeai-configure --root .
+   invokeai-configure
    ```
-   Don't miss the dot at the end!

 7. Launch the web server (do it every time you run InvokeAI):
@@ -194,9 +193,15 @@ the command `npm install -g yarn` if needed)
    invokeai-web
    ```

-8. Point your browser to http://localhost:9090 to bring up the web interface.
-9. Type `banana sushi` in the box on the top left and click `Invoke`.
+8. Build Node.js assets
+
+   ```terminal
+   cd invokeai/frontend/web/
+   yarn vite build
+   ```
+
+9. Point your browser to http://localhost:9090 to bring up the web interface.
+10. Type `banana sushi` in the box on the top left and click `Invoke`.

 Be sure to activate the virtual environment each time before re-launching InvokeAI,
 using `source .venv/bin/activate` or `.venv\Scripts\activate`.
@@ -306,30 +311,13 @@ InvokeAI. The second will prepare the 2.3 directory for use with 3.0.
 You may now launch the WebUI in the usual way, by selecting option [1]
 from the launcher script

-#### Migrating Images
+#### Migration Caveats

 The migration script will migrate your invokeai settings and models,
 including textual inversion models, LoRAs and merges that you may have
 installed previously. However it does **not** migrate the generated
-images stored in your 2.3-format outputs directory. To do this, you
-need to run an additional step:
-
-1. From a working InvokeAI 3.0 root directory, start the launcher and
-   enter menu option [8] to open the "developer's console".
-
-2. At the developer's console command line, type the command:
-
-   ```bash
-   invokeai-import-images
-   ```
-
-3. This will lead you through the process of confirming the desired
-   source and destination for the imported images. The images will
-   appear in the gallery board of your choice, and contain the
-   original prompt, model name, and other parameters used to generate
-   the image.
-
-(Many kudos to **techjedi** for contributing this script.)
+images stored in your 2.3-format outputs directory. You will need to
+manually import selected images into the 3.0 gallery via drag-and-drop.

 ## Hardware Requirements

View File

@@ -264,7 +264,7 @@ experimental versions later.
   you can create several levels of subfolders and drop your models into
   whichever ones you want.

-- ***LICENSE***
+- ***Autoimport Folder***

   At the bottom of the screen you will see a checkbox for accepting
   the CreativeML Responsible AI Licenses. You need to accept the license
@@ -471,7 +471,7 @@ Then type the following commands:
 === "NVIDIA System"

     ```bash
-    pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu118
+    pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu117
     pip install xformers
     ```

View File

@@ -148,7 +148,7 @@ manager, please follow these steps:
 === "CUDA (NVidia)"

     ```bash
-    pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
+    pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
    ```

 === "ROCm (AMD)"
@@ -192,11 +192,9 @@ manager, please follow these steps:
    your outputs.

    ```terminal
-   invokeai-configure --root .
+   invokeai-configure
    ```
-
-   Don't miss the dot at the end of the command!

    The script `invokeai-configure` will interactively guide you through the
    process of downloading and installing the weights files needed for InvokeAI.
    Note that the main Stable Diffusion weights file is protected by a license
@@ -227,6 +225,12 @@ manager, please follow these steps:
    !!! warning "Make sure that the virtual environment is activated, which should create `(.venv)` in front of your prompt!"

+   === "CLI"
+
+       ```bash
+       invokeai
+       ```
+
    === "local Webserver"

        ```bash
@@ -239,12 +243,6 @@ manager, please follow these steps:
        invokeai --web --host 0.0.0.0
        ```

-   === "CLI"
-
-       ```bash
-       invokeai
-       ```
-
    If you choose the run the web interface, point your browser at
    http://localhost:9090 in order to load the GUI.
@@ -312,7 +310,7 @@ installation protocol (important!)
 === "CUDA (NVidia)"

     ```bash
-    pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
+    pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
    ```

 === "ROCm (AMD)"
@@ -356,7 +354,7 @@ you can do so using this unsupported recipe:
 mkdir ~/invokeai
 conda create -n invokeai python=3.10
 conda activate invokeai
-pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
+pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
 invokeai-configure --root ~/invokeai
 invokeai --root ~/invokeai --web
 ```

View File

@@ -34,11 +34,11 @@ directly from NVIDIA. **Do not try to install Ubuntu's
 nvidia-cuda-toolkit package. It is out of date and will cause
 conflicts among the NVIDIA driver and binaries.**

-Go to [CUDA Toolkit
-Downloads](https://developer.nvidia.com/cuda-downloads), and use the
-target selection wizard to choose your operating system, hardware
-platform, and preferred installation method (e.g. "local" versus
-"network").
+Go to [CUDA Toolkit 11.7
+Downloads](https://developer.nvidia.com/cuda-11-7-0-download-archive),
+and use the target selection wizard to choose your operating system,
+hardware platform, and preferred installation method (e.g. "local"
+versus "network").

 This will provide you with a downloadable install file or, depending
 on your choices, a recipe for downloading and running a install shell
@@ -61,7 +61,7 @@ Runtime Site](https://developer.nvidia.com/nvidia-container-runtime)
 When installing torch and torchvision manually with `pip`, remember to provide
 the argument `--extra-index-url
-https://download.pytorch.org/whl/cu118` as described in the [Manual
+https://download.pytorch.org/whl/cu117` as described in the [Manual
 Installation Guide](020_INSTALL_MANUAL.md).

 ## :simple-amd: ROCm

View File

@@ -124,7 +124,7 @@ installation. Examples:
 invokeai-model-install --list controlnet

 # (install the model at the indicated URL)
-invokeai-model-install --add https://civitai.com/api/download/models/128713
+invokeai-model-install --add http://civitai.com/2860

 # (delete the named model)
 invokeai-model-install --delete sd-1/main/analog-diffusion

View File

@@ -28,21 +28,18 @@ command line, then just be sure to activate it's virtual environment.
 Then run the following three commands:

 ```sh
-pip install xformers~=0.0.19
-pip install triton # WON'T WORK ON WINDOWS
+pip install xformers==0.0.16rc425
+pip install triton
 python -m xformers.info output
 ```

 The first command installs `xformers`, the second installs the
 `triton` training accelerator, and the third prints out the `xformers`
-installation status. On Windows, please omit the `triton` package,
-which is not available on that platform.
-
-If all goes well, you'll see a report like the
+installation status. If all goes well, you'll see a report like the
 following:

 ```sh
-xFormers 0.0.20
+xFormers 0.0.16rc425
 memory_efficient_attention.cutlassF: available
 memory_efficient_attention.cutlassB: available
 memory_efficient_attention.flshattF: available
@@ -51,28 +48,22 @@ memory_efficient_attention.smallkF: available
 memory_efficient_attention.smallkB: available
 memory_efficient_attention.tritonflashattF: available
 memory_efficient_attention.tritonflashattB: available
-indexing.scaled_index_addF: available
-indexing.scaled_index_addB: available
-indexing.index_select: available
-swiglu.dual_gemm_silu: available
-swiglu.gemm_fused_operand_sum: available
 swiglu.fused.p.cpp: available
 is_triton_available: True
 is_functorch_available: False
-pytorch.version: 2.0.1+cu118
+pytorch.version: 1.13.1+cu117
 pytorch.cuda: available
-gpu.compute_capability: 8.9
-gpu.name: NVIDIA GeForce RTX 4070
+gpu.compute_capability: 8.6
+gpu.name: NVIDIA RTX A2000 12GB
 build.info: available
-build.cuda_version: 1108
-build.python_version: 3.10.11
-build.torch_version: 2.0.1+cu118
+build.cuda_version: 1107
+build.python_version: 3.10.9
+build.torch_version: 1.13.1+cu117
 build.env.TORCH_CUDA_ARCH_LIST: 5.0+PTX 6.0 6.1 7.0 7.5 8.0 8.6
 build.env.XFORMERS_BUILD_TYPE: Release
 build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS: None
 build.env.NVCC_FLAGS: None
-build.env.XFORMERS_PACKAGE_FROM: wheel-v0.0.20
-build.nvcc_version: 11.8.89
+build.env.XFORMERS_PACKAGE_FROM: wheel-v0.0.16rc425
 source.privacy: open source
 ```
@@ -92,14 +83,14 @@ installed from source. These instructions were written for a system
 running Ubuntu 22.04, but other Linux distributions should be able to
 adapt this recipe.

-#### 1. Install CUDA Toolkit 11.8
+#### 1. Install CUDA Toolkit 11.7

 You will need the CUDA developer's toolkit in order to compile and
 install xFormers. **Do not try to install Ubuntu's nvidia-cuda-toolkit
 package.** It is out of date and will cause conflicts among the NVIDIA
 driver and binaries. Instead install the CUDA Toolkit package provided
-by NVIDIA itself. Go to [CUDA Toolkit 11.8
-Downloads](https://developer.nvidia.com/cuda-11-8-0-download-archive)
+by NVIDIA itself. Go to [CUDA Toolkit 11.7
+Downloads](https://developer.nvidia.com/cuda-11-7-0-download-archive)
 and use the target selection wizard to choose your platform and Linux
 distribution. Select an installer type of "runfile (local)" at the
 last step.
@@ -110,17 +101,17 @@ example, the install script recipe for Ubuntu 22.04 running on a
 x86_64 system is:

 ```
-wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
-sudo sh cuda_11.8.0_520.61.05_linux.run
+wget https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run
+sudo sh cuda_11.7.0_515.43.04_linux.run
 ```

 Rather than cut-and-paste this example, We recommend that you walk
 through the toolkit wizard in order to get the most up to date
 installer for your system.

-#### 2. Confirm/Install pyTorch 2.01 with CUDA 11.8 support
+#### 2. Confirm/Install pyTorch 1.13 with CUDA 11.7 support

-If you are using InvokeAI 3.0.2 or higher, these will already be
+If you are using InvokeAI 2.3 or higher, these will already be
 installed. If not, you can check whether you have the needed libraries
 using a quick command. Activate the invokeai virtual environment,
 either by entering the "developer's console", or manually with a
@@ -133,7 +124,7 @@ Then run the command:
 python -c 'exec("import torch\nprint(torch.__version__)")'
 ```

-If it prints __1.13.1+cu118__ you're good. If not, you can install the
+If it prints __1.13.1+cu117__ you're good. If not, you can install the
 most up to date libraries with this command:

 ```sh

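A quick way to confirm the torch/xformers pairing described above, beyond `python -m xformers.info output`, is to query both packages from inside the virtual environment. A minimal sketch (assumes torch and xformers are already installed):

```python
# Print the installed torch build, its CUDA runtime tag, and the xformers version.
# The CUDA suffix (e.g. +cu117) should match the --extra-index-url wheel index used above.
import torch
import xformers

print(torch.__version__)     # e.g. 1.13.1+cu117
print(torch.version.cuda)    # e.g. 11.7
print(xformers.__version__)  # e.g. 0.0.16rc425
```
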
View File

@@ -34,10 +34,6 @@
       cudaPackages.cudnn
       cudaPackages.cuda_nvrtc
       cudatoolkit
-      pkgconfig
-      libconfig
-      cmake
-      blas
       freeglut
       glib
       gperf
@@ -46,12 +42,6 @@
       libGLU
       linuxPackages.nvidia_x11
       python
-      (opencv4.override {
-        enableGtk3 = true;
-        enableFfmpeg = true;
-        enableCuda = true;
-        enableUnfree = true;
-      })
       stdenv.cc
       stdenv.cc.cc.lib
       xorg.libX11

View File

@@ -348,7 +348,7 @@ class InvokeAiInstance:
         introduction()

-        from invokeai.frontend.install.invokeai_configure import invokeai_configure
+        from invokeai.frontend.install import invokeai_configure

         # NOTE: currently the config script does its own arg parsing! this means the command-line switches
         # from the installer will also automatically propagate down to the config script.
@@ -463,10 +463,10 @@ def get_torch_source() -> (Union[str, None], str):
     url = "https://download.pytorch.org/whl/cpu"

     if device == "cuda":
-        url = "https://download.pytorch.org/whl/cu118"
+        url = "https://download.pytorch.org/whl/cu117"
         optional_modules = "[xformers,onnx-cuda]"
     if device == "cuda_and_dml":
-        url = "https://download.pytorch.org/whl/cu118"
+        url = "https://download.pytorch.org/whl/cu117"
         optional_modules = "[xformers,onnx-directml]"

     # in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13

View File

@@ -2,6 +2,7 @@
 from typing import Optional
 from logging import Logger
+import os

 from invokeai.app.services.board_image_record_storage import (
     SqliteBoardImageRecordStorage,
 )
@@ -29,7 +30,6 @@ from ..services.invoker import Invoker
 from ..services.processor import DefaultInvocationProcessor
 from ..services.sqlite import SqliteItemStorage
 from ..services.model_manager_service import ModelManagerService
-from ..services.invocation_stats import InvocationStatsService

 from .events import FastAPIEventService
@@ -55,7 +55,7 @@ logger = InvokeAILogger.getLogger()
 class ApiDependencies:
     """Contains and initializes all dependencies for the API"""

-    invoker: Invoker
+    invoker: Optional[Invoker] = None

     @staticmethod
     def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
@@ -68,9 +68,8 @@ class ApiDependencies:
         output_folder = config.output_path

         # TODO: build a file/path manager?
-        db_path = config.db_path
-        db_path.parent.mkdir(parents=True, exist_ok=True)
-        db_location = str(db_path)
+        db_location = config.db_path
+        db_location.parent.mkdir(parents=True, exist_ok=True)

         graph_execution_manager = SqliteItemStorage[GraphExecutionState](
             filename=db_location, table_name="graph_executions"
         )
@@ -129,7 +128,6 @@ class ApiDependencies:
             graph_execution_manager=graph_execution_manager,
             processor=DefaultInvocationProcessor(),
             configuration=config,
-            performance_statistics=InvocationStatsService(graph_execution_manager),
             logger=logger,
         )

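The `db_path` hunk above normalizes `config.db_path` (a `pathlib.Path`) to a string before it is handed to `SqliteItemStorage` as `filename`. A minimal sketch of the pattern, with an illustrative path standing in for `config.db_path`:

```python
import sqlite3
from pathlib import Path

db_path = Path("databases/invokeai.db")  # illustrative stand-in for config.db_path
db_path.parent.mkdir(parents=True, exist_ok=True)  # ensure the parent directory exists first
db_location = str(db_path)  # normalize to str for consumers that expect a plain filename

conn = sqlite3.connect(db_location)  # sqlite3 accepts the string form everywhere
conn.close()
```
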
View File

@@ -1,30 +1,24 @@
-from fastapi import Body, HTTPException
+from fastapi import Body, HTTPException, Path, Query
 from fastapi.routing import APIRouter
-from pydantic import BaseModel, Field
+from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges
+from invokeai.app.services.image_record_storage import OffsetPaginatedResults
+from invokeai.app.services.models.board_record import BoardDTO
+from invokeai.app.services.models.image_record import ImageDTO

 from ..dependencies import ApiDependencies

 board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])

-class AddImagesToBoardResult(BaseModel):
-    board_id: str = Field(description="The id of the board the images were added to")
-    added_image_names: list[str] = Field(description="The image names that were added to the board")
-
-class RemoveImagesFromBoardResult(BaseModel):
-    removed_image_names: list[str] = Field(description="The image names that were removed from their board")
-
 @board_images_router.post(
     "/",
-    operation_id="add_image_to_board",
+    operation_id="create_board_image",
     responses={
         201: {"description": "The image was added to a board successfully"},
     },
     status_code=201,
 )
-async def add_image_to_board(
+async def create_board_image(
     board_id: str = Body(description="The id of the board to add to"),
     image_name: str = Body(description="The name of the image to add"),
 ):
@@ -35,78 +29,26 @@ async def add_image_to_board(
         )
         return result
     except Exception as e:
-        raise HTTPException(status_code=500, detail="Failed to add image to board")
+        raise HTTPException(status_code=500, detail="Failed to add to board")

 @board_images_router.delete(
     "/",
-    operation_id="remove_image_from_board",
+    operation_id="remove_board_image",
     responses={
         201: {"description": "The image was removed from the board successfully"},
     },
     status_code=201,
 )
-async def remove_image_from_board(
-    image_name: str = Body(description="The name of the image to remove", embed=True),
+async def remove_board_image(
+    board_id: str = Body(description="The id of the board"),
+    image_name: str = Body(description="The name of the image to remove"),
 ):
-    """Removes an image from its board, if it had one"""
+    """Deletes a board_image"""
     try:
-        result = ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
-        return result
-    except Exception as e:
-        raise HTTPException(status_code=500, detail="Failed to remove image from board")
-
-@board_images_router.post(
-    "/batch",
-    operation_id="add_images_to_board",
-    responses={
-        201: {"description": "Images were added to board successfully"},
-    },
-    status_code=201,
-    response_model=AddImagesToBoardResult,
-)
-async def add_images_to_board(
-    board_id: str = Body(description="The id of the board to add to"),
-    image_names: list[str] = Body(description="The names of the images to add", embed=True),
-) -> AddImagesToBoardResult:
-    """Adds a list of images to a board"""
-    try:
-        added_image_names: list[str] = []
-        for image_name in image_names:
-            try:
-                ApiDependencies.invoker.services.board_images.add_image_to_board(
-                    board_id=board_id, image_name=image_name
-                )
-                added_image_names.append(image_name)
-            except:
-                pass
-        return AddImagesToBoardResult(board_id=board_id, added_image_names=added_image_names)
-    except Exception as e:
-        raise HTTPException(status_code=500, detail="Failed to add images to board")
-
-@board_images_router.post(
-    "/batch/delete",
-    operation_id="remove_images_from_board",
-    responses={
-        201: {"description": "Images were removed from board successfully"},
-    },
-    status_code=201,
-    response_model=RemoveImagesFromBoardResult,
-)
-async def remove_images_from_board(
-    image_names: list[str] = Body(description="The names of the images to remove", embed=True),
-) -> RemoveImagesFromBoardResult:
-    """Removes a list of images from their board, if they had one"""
-    try:
-        removed_image_names: list[str] = []
-        for image_name in image_names:
-            try:
-                ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
-                removed_image_names.append(image_name)
-            except:
-                pass
-        return RemoveImagesFromBoardResult(removed_image_names=removed_image_names)
+        result = ApiDependencies.invoker.services.board_images.remove_image_from_board(
+            board_id=board_id, image_name=image_name
+        )
+        return result
     except Exception as e:
-        raise HTTPException(status_code=500, detail="Failed to remove images from board")
+        raise HTTPException(status_code=500, detail="Failed to update board")

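The removed batch endpoints above rely on FastAPI's `Body(..., embed=True)`, which wraps a single body parameter in a JSON object keyed by the parameter name instead of accepting a bare value. A minimal sketch of that behavior (hypothetical route, not part of the diff):

```python
from fastapi import Body, FastAPI

app = FastAPI()

# With embed=True the request body must be {"image_names": ["a.png", "b.png"]};
# without it, FastAPI would expect the bare JSON list ["a.png", "b.png"].
@app.post("/batch/delete")
async def remove_images(
    image_names: list[str] = Body(description="The names of the images to remove", embed=True),
):
    return {"removed_image_names": image_names}
```
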
View File

@@ -1,20 +1,21 @@
 import io
 from typing import Optional

-from PIL import Image
 from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
 from fastapi.responses import FileResponse
 from fastapi.routing import APIRouter
-from pydantic import BaseModel
+from PIL import Image

 from invokeai.app.invocations.metadata import ImageMetadata
 from invokeai.app.models.image import ImageCategory, ResourceOrigin
 from invokeai.app.services.image_record_storage import OffsetPaginatedResults
+from invokeai.app.services.item_storage import PaginatedResults
 from invokeai.app.services.models.image_record import (
     ImageDTO,
     ImageRecordChanges,
     ImageUrlsDTO,
 )

 from ..dependencies import ApiDependencies

 images_router = APIRouter(prefix="/v1/images", tags=["images"])
@@ -24,7 +25,7 @@ IMAGE_MAX_AGE = 31536000
 @images_router.post(
-    "/upload",
+    "/",
     operation_id="upload_image",
     responses={
         201: {"description": "The image was uploaded successfully"},
@@ -76,7 +77,7 @@ async def upload_image(
         raise HTTPException(status_code=500, detail="Failed to create image")

-@images_router.delete("/i/{image_name}", operation_id="delete_image")
+@images_router.delete("/{image_name}", operation_id="delete_image")
 async def delete_image(
     image_name: str = Path(description="The name of the image to delete"),
 ) -> None:
@@ -102,7 +103,7 @@ async def clear_intermediates() -> int:
 @images_router.patch(
-    "/i/{image_name}",
+    "/{image_name}",
     operation_id="update_image",
     response_model=ImageDTO,
 )
@@ -119,7 +120,7 @@ async def update_image(
 @images_router.get(
-    "/i/{image_name}",
+    "/{image_name}",
     operation_id="get_image_dto",
     response_model=ImageDTO,
 )
@@ -135,7 +136,7 @@ async def get_image_dto(
 @images_router.get(
-    "/i/{image_name}/metadata",
+    "/{image_name}/metadata",
     operation_id="get_image_metadata",
     response_model=ImageMetadata,
 )
@@ -150,9 +151,8 @@ async def get_image_metadata(
     raise HTTPException(status_code=404)

-@images_router.api_route(
-    "/i/{image_name}/full",
-    methods=["GET", "HEAD"],
+@images_router.get(
+    "/{image_name}/full",
     operation_id="get_image_full",
     response_class=Response,
     responses={
@@ -187,7 +187,7 @@ async def get_image_full(
 @images_router.get(
-    "/i/{image_name}/thumbnail",
+    "/{image_name}/thumbnail",
     operation_id="get_image_thumbnail",
     response_class=Response,
     responses={
@@ -216,7 +216,7 @@ async def get_image_thumbnail(
 @images_router.get(
-    "/i/{image_name}/urls",
+    "/{image_name}/urls",
     operation_id="get_image_urls",
     response_model=ImageUrlsDTO,
 )
@@ -265,24 +265,3 @@ async def list_image_dtos(
     )

     return image_dtos
-
-class DeleteImagesFromListResult(BaseModel):
-    deleted_images: list[str]
-
-@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesFromListResult)
-async def delete_images_from_list(
-    image_names: list[str] = Body(description="The list of names of images to delete", embed=True),
-) -> DeleteImagesFromListResult:
-    try:
-        deleted_images: list[str] = []
-        for image_name in image_names:
-            try:
-                ApiDependencies.invoker.services.images.delete(image_name)
-                deleted_images.append(image_name)
-            except:
-                pass
-        return DeleteImagesFromListResult(deleted_images=deleted_images)
-    except Exception as e:
-        raise HTTPException(status_code=500, detail="Failed to delete images")

View File

@@ -104,12 +104,8 @@ async def update_model(
         ):  # model manager moved model path during rename - don't overwrite it
             info.path = new_info.get("path")

-        # replace empty string values with None/null to avoid phenomenon of vae: ''
-        info_dict = info.dict()
-        info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
-
         ApiDependencies.invoker.services.model_manager.update_model(
-            model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info_dict
+            model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info.dict()
         )

         model_raw = ApiDependencies.invoker.services.model_manager.list_model(

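The removed lines above map empty-string model attributes to None before `update_model` is called, to avoid persisting values like `vae: ''`. Note that the truthiness test converts every falsy value, not just empty strings. A quick illustration with made-up attributes:

```python
# Truthiness-based cleanup: '' becomes None, but so would 0, 0.0, or an empty list.
info_dict = {"vae": "", "path": "models/sd-1/main/analog-diffusion", "format": "diffusers"}
cleaned = {k: (v if v else None) for k, v in info_dict.items()}
print(cleaned)  # {'vae': None, 'path': 'models/sd-1/main/analog-diffusion', 'format': 'diffusers'}
```
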
View File

@@ -37,7 +37,6 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
 from invokeai.app.services.images import ImageService, ImageServiceDependencies
 from invokeai.app.services.resource_name import SimpleNameService
 from invokeai.app.services.urls import LocalUrlService
-from invokeai.app.services.invocation_stats import InvocationStatsService

 from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
 from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
@@ -312,7 +311,6 @@ def invoke_cli():
         graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
         graph_execution_manager=graph_execution_manager,
         processor=DefaultInvocationProcessor(),
-        performance_statistics=InvocationStatsService(graph_execution_manager),
         logger=logger,
         configuration=config,
     )

View File

@@ -109,15 +109,12 @@ class CompelInvocation(BaseInvocation):
                 name = trigger[1:-1]
                 try:
                     ti_list.append(
-                        (
-                            name,
-                            context.services.model_manager.get_model(
-                                model_name=name,
-                                base_model=self.clip.text_encoder.base_model,
-                                model_type=ModelType.TextualInversion,
-                                context=context,
-                            ).context.model,
-                        )
+                        context.services.model_manager.get_model(
+                            model_name=name,
+                            base_model=self.clip.text_encoder.base_model,
+                            model_type=ModelType.TextualInversion,
+                            context=context,
+                        ).context.model
                     )
                 except ModelNotFoundException:
                     # print(e)
@@ -176,7 +173,7 @@ class CompelInvocation(BaseInvocation):
 class SDXLPromptInvocationBase:
-    def run_clip_raw(self, context, clip_field, prompt, get_pooled, lora_prefix):
+    def run_clip_raw(self, context, clip_field, prompt, get_pooled):
         tokenizer_info = context.services.model_manager.get_model(
             **clip_field.tokenizer.dict(),
             context=context,
@@ -200,15 +197,12 @@ class SDXLPromptInvocationBase:
                 name = trigger[1:-1]
                 try:
                     ti_list.append(
-                        (
-                            name,
-                            context.services.model_manager.get_model(
-                                model_name=name,
-                                base_model=clip_field.text_encoder.base_model,
-                                model_type=ModelType.TextualInversion,
-                                context=context,
-                            ).context.model,
-                        )
+                        context.services.model_manager.get_model(
+                            model_name=name,
+                            base_model=clip_field.text_encoder.base_model,
+                            model_type=ModelType.TextualInversion,
+                            context=context,
+                        ).context.model
                     )
                 except ModelNotFoundException:
                     # print(e)
@@ -216,8 +210,8 @@ class SDXLPromptInvocationBase:
                     # print(traceback.format_exc())
                     print(f'Warn: trigger: "{trigger}" not found')

-        with ModelPatcher.apply_lora(
-            text_encoder_info.context.model, _lora_loader(), lora_prefix
+        with ModelPatcher.apply_lora_text_encoder(
+            text_encoder_info.context.model, _lora_loader()
         ), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
             tokenizer,
             ti_manager,
@@ -253,7 +247,7 @@ class SDXLPromptInvocationBase:
         return c, c_pooled, None

-    def run_clip_compel(self, context, clip_field, prompt, get_pooled, lora_prefix):
+    def run_clip_compel(self, context, clip_field, prompt, get_pooled):
         tokenizer_info = context.services.model_manager.get_model(
             **clip_field.tokenizer.dict(),
             context=context,
@@ -277,15 +271,12 @@ class SDXLPromptInvocationBase:
                 name = trigger[1:-1]
                 try:
                     ti_list.append(
-                        (
-                            name,
-                            context.services.model_manager.get_model(
-                                model_name=name,
-                                base_model=clip_field.text_encoder.base_model,
-                                model_type=ModelType.TextualInversion,
-                                context=context,
-                            ).context.model,
-                        )
+                        context.services.model_manager.get_model(
+                            model_name=name,
+                            base_model=clip_field.text_encoder.base_model,
+                            model_type=ModelType.TextualInversion,
+                            context=context,
+                        ).context.model
                     )
                 except ModelNotFoundException:
                     # print(e)
@@ -293,8 +284,8 @@ class SDXLPromptInvocationBase:
                     # print(traceback.format_exc())
                     print(f'Warn: trigger: "{trigger}" not found')

-        with ModelPatcher.apply_lora(
-            text_encoder_info.context.model, _lora_loader(), lora_prefix
+        with ModelPatcher.apply_lora_text_encoder(
+            text_encoder_info.context.model, _lora_loader()
         ), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
             tokenizer,
             ti_manager,
@@ -366,11 +357,11 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> CompelOutput:
-        c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_")
+        c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False)
         if self.style.strip() == "":
-            c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True, "lora_te2_")
+            c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True)
         else:
-            c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_")
+            c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)

         original_size = (self.original_height, self.original_width)
         crop_coords = (self.crop_top, self.crop_left)
@@ -424,8 +415,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> CompelOutput:
-        # TODO: if there will appear lora for refiner - write proper prefix
-        c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>")
+        c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)

         original_size = (self.original_height, self.original_width)
         crop_coords = (self.crop_top, self.crop_left)
@@ -477,11 +467,11 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> CompelOutput:
-        c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False, "lora_te1_")
+        c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False)
         if self.style.strip() == "":
-            c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True, "lora_te2_")
+            c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True)
         else:
-            c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "lora_te2_")
+            c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)

         original_size = (self.original_height, self.original_width)
         crop_coords = (self.crop_top, self.crop_left)
@@ -535,8 +525,7 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> CompelOutput:
-        # TODO: if there will appear lora for refiner - write proper prefix
-        c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "<NONE>")
+        c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)

         original_size = (self.original_height, self.original_width)
         crop_coords = (self.crop_top, self.crop_left)

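In the hunks above, the removed (-) side builds `ti_list` out of (trigger name, model) tuples, so the patcher can associate each embedding with its trigger token, while the added (+) side appends bare model objects. A minimal sketch of the tuple shape, with `load_ti_model` as a hypothetical stand-in for the `model_manager.get_model(...).context.model` chain:

```python
def load_ti_model(name: str) -> object:
    # Hypothetical loader standing in for context.services.model_manager.get_model(...)
    return f"<embedding weights for {name}>"

ti_list = []
for trigger in ["<night-style>", "<my-face>"]:   # hypothetical trigger tokens in a prompt
    name = trigger[1:-1]                          # strip the angle brackets, as in the diff
    ti_list.append((name, load_ti_model(name)))   # keep the trigger name paired with its model

print(ti_list)
```
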
View File

@@ -1,23 +1,26 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

-from contextlib import contextmanager, ContextDecorator
 from functools import partial
 from typing import Literal, Optional, get_args

+import torch
 from pydantic import Field

 from invokeai.app.models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
 from invokeai.backend.generator.inpaint import infill_methods

-from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
-from .compel import ConditioningField
-from .image import ImageOutput
-from .model import UNetField, VaeField
-from ..util.step_callback import stable_diffusion_step_callback
 from ...backend.generator import Inpaint, InvokeAIGenerator
-from ...backend.model_management.lora import ModelPatcher
 from ...backend.stable_diffusion import PipelineIntermediateState
+from ..util.step_callback import stable_diffusion_step_callback
+from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
+from .image import ImageOutput
+from ...backend.model_management.lora import ModelPatcher
 from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
+from .model import UNetField, VaeField
+from .compel import ConditioningField
+from contextlib import contextmanager, ExitStack, ContextDecorator

 SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
 INFILL_METHODS = Literal[tuple(infill_methods())]
@@ -181,8 +184,6 @@ class InpaintInvocation(BaseInvocation):
             device = context.services.model_manager.mgr.cache.execution_device
             dtype = context.services.model_manager.mgr.cache.precision

-            vae.to(dtype=unet.dtype)
-
             pipeline = StableDiffusionGeneratorPipeline(
                 vae=vae,
                 text_encoder=None,
@@ -192,6 +193,8 @@ class InpaintInvocation(BaseInvocation):
                 safety_checker=None,
                 feature_extractor=None,
                 requires_safety_checker=False,
+                precision="float16" if dtype == torch.float16 else "float32",
+                execution_device=device,
             )

             yield OldModelInfo(

View File

@@ -3,7 +3,6 @@
 from typing import Literal, Optional

 import numpy
-import cv2
 from PIL import Image, ImageFilter, ImageOps, ImageChops
 from pydantic import Field
 from pathlib import Path
@@ -501,7 +500,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
         image = context.services.images.get_pil_image(self.image.image_name)

         image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
-        image_arr = image_arr * (self.max - self.min) + self.min
+        image_arr = image_arr * (self.max - self.min) + self.max

         lerp_image = Image.fromarray(numpy.uint8(image_arr))
@@ -651,143 +650,3 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
             width=image_dto.width,
             height=image_dto.height,
         )
-
-class ImageHueAdjustmentInvocation(BaseInvocation):
-    """Adjusts the Hue of an image."""
-
-    # fmt: off
-    type: Literal["img_hue_adjust"] = "img_hue_adjust"
-
-    # Inputs
-    image: ImageField = Field(default=None, description="The image to adjust")
-    hue: int = Field(default=0, description="The degrees by which to rotate the hue, 0-360")
-    # fmt: on
-
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        pil_image = context.services.images.get_pil_image(self.image.image_name)
-
-        # Convert image to HSV color space
-        hsv_image = numpy.array(pil_image.convert("HSV"))
-
-        # Convert hue from 0..360 to 0..256
-        hue = int(256 * ((self.hue % 360) / 360))
-
-        # Increment each hue and wrap around at 255
-        hsv_image[:, :, 0] = (hsv_image[:, :, 0] + hue) % 256
-
-        # Convert back to PIL format and to original color mode
-        pil_image = Image.fromarray(hsv_image, mode="HSV").convert("RGBA")
-
-        image_dto = context.services.images.create(
-            image=pil_image,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            node_id=self.id,
-            is_intermediate=self.is_intermediate,
-            session_id=context.graph_execution_state_id,
-        )
-
-        return ImageOutput(
-            image=ImageField(
-                image_name=image_dto.image_name,
-            ),
-            width=image_dto.width,
-            height=image_dto.height,
-        )
-
-class ImageLuminosityAdjustmentInvocation(BaseInvocation):
-    """Adjusts the Luminosity (Value) of an image."""
-
-    # fmt: off
-    type: Literal["img_luminosity_adjust"] = "img_luminosity_adjust"
-
-    # Inputs
-    image: ImageField = Field(default=None, description="The image to adjust")
-    luminosity: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)")
-    # fmt: on
-
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        pil_image = context.services.images.get_pil_image(self.image.image_name)
-
-        # Convert PIL image to OpenCV format (numpy array), note color channel
-        # ordering is changed from RGB to BGR
-        image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
-
-        # Convert image to HSV color space
-        hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
-
-        # Adjust the luminosity (value)
-        hsv_image[:, :, 2] = numpy.clip(hsv_image[:, :, 2] * self.luminosity, 0, 255)
-
-        # Convert image back to BGR color space
-        image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
-
-        # Convert back to PIL format and to original color mode
-        pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
-
-        image_dto = context.services.images.create(
-            image=pil_image,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            node_id=self.id,
-            is_intermediate=self.is_intermediate,
-            session_id=context.graph_execution_state_id,
-        )
-
-        return ImageOutput(
-            image=ImageField(
-                image_name=image_dto.image_name,
-            ),
-            width=image_dto.width,
-            height=image_dto.height,
-        )
-
-class ImageSaturationAdjustmentInvocation(BaseInvocation):
-    """Adjusts the Saturation of an image."""
-
-    # fmt: off
-    type: Literal["img_saturation_adjust"] = "img_saturation_adjust"
-
-    # Inputs
-    image: ImageField = Field(default=None, description="The image to adjust")
-    saturation: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
-    # fmt: on
-
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        pil_image = context.services.images.get_pil_image(self.image.image_name)
-
-        # Convert PIL image to OpenCV format (numpy array), note color channel
-        # ordering is changed from RGB to BGR
-        image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
-
-        # Convert image to HSV color space
-        hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
-
-        # Adjust the saturation
-        hsv_image[:, :, 1] = numpy.clip(hsv_image[:, :, 1] * self.saturation, 0, 255)
-
-        # Convert image back to BGR color space
-        image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
-
-        # Convert back to PIL format and to original color mode
-        pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
-
-        image_dto = context.services.images.create(
-            image=pil_image,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            node_id=self.id,
-            is_intermediate=self.is_intermediate,
-            session_id=context.graph_execution_state_id,
-        )
-
-        return ImageOutput(
-            image=ImageField(
-                image_name=image_dto.image_name,
-            ),
-            width=image_dto.width,
-            height=image_dto.height,
-        )

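As context for the ImageLerpInvocation hunk above: a linear interpolation that maps normalized values into [min, max] must add min as the final term, so that 0 lands on min and 1 lands on max. A quick numpy check with illustrative values:

```python
import numpy as np

arr = np.asarray([0.0, 0.5, 1.0], dtype=np.float32)  # normalized pixel values
lo, hi = 32.0, 224.0
print(arr * (hi - lo) + lo)  # [ 32. 128. 224.] -- adding lo maps 0 -> lo and 1 -> hi
```
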
View File

@@ -5,27 +5,16 @@ from typing import List, Literal, Optional, Union
 import einops
 import torch
-from diffusers import ControlNetModel
 from diffusers.image_processor import VaeImageProcessor
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    XFormersAttnProcessor,
-)
 from diffusers.schedulers import SchedulerMixin as Scheduler
 from pydantic import BaseModel, Field, validator

 from invokeai.app.invocations.metadata import CoreMetadata
-from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend.model_management.models import ModelType, SilenceWarnings

-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
-from .compel import ConditioningField
-from .controlnet_image_processors import ControlField
-from .image import ImageOutput
-from .model import ModelInfo, UNetField, VaeField
-from ..models.image import ImageCategory, ImageField, ResourceOrigin
-from ...backend.model_management import ModelPatcher
+from ...backend.model_management.lora import ModelPatcher
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.stable_diffusion.diffusers_pipeline import (
     ConditioningData,
@@ -35,7 +24,23 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
 )
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
+from ...backend.model_management import ModelPatcher
 from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
+from ..models.image import ImageCategory, ImageField, ResourceOrigin
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
+from .compel import ConditioningField
+from .controlnet_image_processors import ControlField
+from .image import ImageOutput
+from .model import ModelInfo, UNetField, VaeField
+from invokeai.app.util.controlnet_utils import prepare_control_image
+
+from diffusers.models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)

 DEFAULT_PRECISION = choose_precision(choose_torch_device())
@@ -226,6 +231,7 @@ class TextToLatentsInvocation(BaseInvocation):
             safety_checker=None,
             feature_extractor=None,
             requires_safety_checker=False,
+            precision="float16" if unet.dtype == torch.float16 else "float32",
         )

     def prep_control_data(

View File

@@ -1,8 +1,7 @@
 from typing import Literal, Optional, Union

-from pydantic import Field
+from pydantic import BaseModel, Field

-from ...version import __version__
 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
@@ -11,20 +10,18 @@ from invokeai.app.invocations.baseinvocation import (
 )
 from invokeai.app.invocations.controlnet_image_processors import ControlField
 from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
-from invokeai.app.util.model_exclude_null import BaseModelExcludeNull

-class LoRAMetadataField(BaseModelExcludeNull):
+class LoRAMetadataField(BaseModel):
     """LoRA metadata for an image generated in InvokeAI."""

     lora: LoRAModelField = Field(description="The LoRA model")
     weight: float = Field(description="The weight of the LoRA model")

-class CoreMetadata(BaseModelExcludeNull):
+class CoreMetadata(BaseModel):
     """Core generation metadata for an image generated in InvokeAI."""

-    app_version: str = Field(default=__version__, description="The version of InvokeAI used to generate this image")
     generation_mode: str = Field(
         description="The generation mode that output this image",
     )
@@ -73,7 +70,7 @@ class CoreMetadata(BaseModelExcludeNull):
     refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")

-class ImageMetadata(BaseModelExcludeNull):
+class ImageMetadata(BaseModel):
     """An image's generation metadata"""

     metadata: Optional[dict] = Field(

View File

@@ -262,103 +262,6 @@ class LoraLoaderInvocation(BaseInvocation):
return output
class SDXLLoraLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
# fmt: off
type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output"
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
clip2: Optional[ClipField] = Field(default=None, description="Tokenizer2 and text_encoder2 submodels")
# fmt: on
class SDXLLoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader"
lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
weight: float = Field(default=0.75, description="With what weight to apply lora")
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
clip2: Optional[ClipField] = Field(description="Clip2 model for applying lora")
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Lora Loader",
"tags": ["lora", "loader"],
"type_hints": {"lora": "lora_model"},
},
}
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
base_model = self.lora.base_model
lora_name = self.lora.model_name
if not context.services.model_manager.model_exists(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
):
raise Exception(f"Unknown lora name: {lora_name}!")
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
raise Exception(f'Lora "{lora_name}" already applied to unet')
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
raise Exception(f'Lora "{lora_name}" already applied to clip')
if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras):
raise Exception(f'Lora "{lora_name}" already applied to clip2')
output = SDXLLoraLoaderOutput()
if self.unet is not None:
output.unet = copy.deepcopy(self.unet)
output.unet.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
weight=self.weight,
)
)
if self.clip is not None:
output.clip = copy.deepcopy(self.clip)
output.clip.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
weight=self.weight,
)
)
if self.clip2 is not None:
output.clip2 = copy.deepcopy(self.clip2)
output.clip2.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
weight=self.weight,
)
)
return output
class VAEModelField(BaseModel):
    """Vae model field"""

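The removed SDXLLoraLoaderInvocation above repeats one pattern three times (unet, clip, clip2): guard against applying the same LoRA twice, then deep-copy the field before appending so the caller's object is never mutated in place. A minimal sketch of that pattern with stand-in types:

import copy
from dataclasses import dataclass, field

@dataclass
class LoraInfo:
    model_name: str
    weight: float

@dataclass
class UNetField:
    loras: list = field(default_factory=list)

def apply_lora(unet: UNetField, name: str, weight: float) -> UNetField:
    if any(lora.model_name == name for lora in unet.loras):
        raise Exception(f'Lora "{name}" already applied to unet')
    patched = copy.deepcopy(unet)  # never mutate the caller's copy
    patched.loras.append(LoraInfo(model_name=name, weight=weight))
    return patched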
View File

@@ -65,6 +65,7 @@ class ONNXPromptInvocation(BaseInvocation):
            **self.clip.text_encoder.dict(),
        )
        with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
+            # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
            loras = [
                (context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
                for lora in self.clip.loras
@@ -75,14 +76,18 @@ class ONNXPromptInvocation(BaseInvocation):
                    name = trigger[1:-1]
                    try:
                        ti_list.append(
-                            (
-                                name,
+                            # stack.enter_context(
+                            #     context.services.model_manager.get_model(
+                            #         model_name=name,
+                            #         base_model=self.clip.text_encoder.base_model,
+                            #         model_type=ModelType.TextualInversion,
+                            #     )
+                            # )
                            context.services.model_manager.get_model(
                                model_name=name,
                                base_model=self.clip.text_encoder.base_model,
                                model_type=ModelType.TextualInversion,
-                            ).context.model,
-                        )
+                            ).context.model
                        )
                    except Exception:
                        # print(e)

View File

@@ -5,7 +5,7 @@ from typing import List, Literal, Optional, Union
from pydantic import Field, validator

-from ...backend.model_management import ModelType, SubModelType, ModelPatcher
+from ...backend.model_management import ModelType, SubModelType
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback

from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
@@ -293,20 +293,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
        num_inference_steps = self.steps

-        def _lora_loader():
-            for lora in self.unet.loras:
-                lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}),
-                    context=context,
-                )
-                yield (lora_info.context.model, lora.weight)
-                del lora_info
-            return

        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)

        do_classifier_free_guidance = True
        cross_attention_kwargs = None
-        with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
+        with unet_info as unet:
            scheduler.set_timesteps(num_inference_steps, device=unet.device)
            timesteps = scheduler.timesteps
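_lora_loader above is a generator, so nothing is loaded until ModelPatcher iterates it inside the with-block; each LoRA is fetched, yielded, and dropped before the next one. A stripped-down sketch of that lazy pattern (the load callable is a stand-in):

def lora_loader(loras, load):
    for name, weight in loras:
        model = load(name)  # fetched only when the consumer advances the generator
        yield (model, weight)
        del model  # release before loading the next one

pairs = list(lora_loader([("detail-lora", 0.75)], load=lambda name: object()))
assert pairs[0][1] == 0.75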
@@ -553,19 +543,9 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
            context=context,
        )

-        def _lora_loader():
-            for lora in self.unet.loras:
-                lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}),
-                    context=context,
-                )
-                yield (lora_info.context.model, lora.weight)
-                del lora_info
-            return

        do_classifier_free_guidance = True
        cross_attention_kwargs = None
-        with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
+        with unet_info as unet:
            # apply denoising_start
            num_inference_steps = self.steps
            scheduler.set_timesteps(num_inference_steps, device=unet.device)

View File

@@ -25,6 +25,7 @@ class BoardImageRecordStorageBase(ABC):
    @abstractmethod
    def remove_image_from_board(
        self,
+        board_id: str,
        image_name: str,
    ) -> None:
        """Removes an image from a board."""
@@ -153,6 +154,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
    def remove_image_from_board(
        self,
+        board_id: str,
        image_name: str,
    ) -> None:
        try:
@@ -160,9 +162,9 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
            self._cursor.execute(
                """--sql
                DELETE FROM board_images
-                WHERE image_name = ?;
+                WHERE board_id = ? AND image_name = ?;
                """,
-                (image_name,),
+                (board_id, image_name),
            )
            self._conn.commit()
        except sqlite3.Error as e:

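A runnable contrast of the two WHERE clauses above, against an in-memory database; which one is wanted depends on whether an image may sit on more than one board:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE board_images (board_id TEXT, image_name TEXT)")
conn.executemany("INSERT INTO board_images VALUES (?, ?)", [("b1", "img.png"), ("b2", "img.png")])

# scoped to one board: only the ("b1", "img.png") row goes away
conn.execute("DELETE FROM board_images WHERE board_id = ? AND image_name = ?", ("b1", "img.png"))
assert conn.execute("SELECT count(*) FROM board_images").fetchone()[0] == 1

# keyed on the image alone: every remaining board association goes away
conn.execute("DELETE FROM board_images WHERE image_name = ?", ("img.png",))
assert conn.execute("SELECT count(*) FROM board_images").fetchone()[0] == 0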
View File

@@ -31,6 +31,7 @@ class BoardImagesServiceABC(ABC):
    @abstractmethod
    def remove_image_from_board(
        self,
+        board_id: str,
        image_name: str,
    ) -> None:
        """Removes an image from a board."""
@@ -92,9 +93,10 @@ class BoardImagesService(BoardImagesServiceABC):
    def remove_image_from_board(
        self,
+        board_id: str,
        image_name: str,
    ) -> None:
-        self._services.board_image_records.remove_image_from_board(image_name)
+        self._services.board_image_records.remove_image_from_board(board_id, image_name)

    def get_all_board_image_names_for_board(
        self,
View File

@@ -24,10 +24,11 @@ InvokeAI:
    sequential_guidance: false
    precision: float16
    max_cache_size: 6
-    max_vram_cache_size: 0.5
+    max_vram_cache_size: 2.7
    always_use_cpu: false
    free_gpu_mem: false
  Features:
+    restore: true
    esrgan: true
    patchmatch: true
    internet_available: true
@@ -164,7 +165,7 @@ import pydoc
import os
import sys
from argparse import ArgumentParser
-from omegaconf import OmegaConf, DictConfig, ListConfig
+from omegaconf import OmegaConf, DictConfig
from pathlib import Path
from pydantic import BaseSettings, Field, parse_obj_as
from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args
@@ -172,7 +173,6 @@ from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_ty
INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
-DEFAULT_MAX_VRAM = 0.5

class InvokeAISettings(BaseSettings):
@@ -189,12 +189,7 @@ class InvokeAISettings(BaseSettings):
        opt = parser.parse_args(argv)
        for name in self.__fields__:
            if name not in self._excluded():
-                value = getattr(opt, name)
-                if isinstance(value, ListConfig):
-                    value = list(value)
-                elif isinstance(value, DictConfig):
-                    value = dict(value)
-                setattr(self, name, value)
+                setattr(self, name, getattr(opt, name))

    def to_yaml(self) -> str:
        """
@@ -279,7 +274,7 @@ class InvokeAISettings(BaseSettings):
    @classmethod
    def _excluded(self) -> List[str]:
        # internal fields that shouldn't be exposed as command line options
-        return ["type", "initconf"]
+        return ["type", "initconf", "cached_root"]

    @classmethod
    def _excluded_from_yaml(self) -> List[str]:
@@ -287,10 +282,15 @@ class InvokeAISettings(BaseSettings):
        return [
            "type",
            "initconf",
+            "gpu_mem_reserved",
+            "max_loaded_models",
            "version",
            "from_file",
            "model",
+            "restore",
            "root",
+            "nsfw_checker",
+            "cached_root",
        ]

    class Config:
@@ -356,7 +356,7 @@ class InvokeAISettings(BaseSettings):
def _find_root() -> Path:
    venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
    if os.environ.get("INVOKEAI_ROOT"):
-        root = Path(os.environ["INVOKEAI_ROOT"])
+        root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
    elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
        root = (venv.parent).resolve()
    else:
@@ -389,17 +389,21 @@ class InvokeAIAppConfig(InvokeAISettings):
    internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
    log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
    patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
+    restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED')
    always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
    free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
+    max_loaded_models : int = Field(default=3, gt=0, description="(DEPRECATED: use max_cache_size) Maximum number of models to keep in memory for rapid switching", category='DEPRECATED')
    max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
    max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
+    gpu_mem_reserved : float = Field(default=2.75, ge=0, description="DEPRECATED: use max_vram_cache_size. Amount of VRAM reserved for model storage", category='DEPRECATED')
+    nsfw_checker : bool = Field(default=True, description="DEPRECATED: use Web settings to enable/disable", category='DEPRECATED')
    precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='auto',description='Floating point precision', category='Memory/Performance')
    sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
    xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
    tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
-    root : Path = Field(default=None, description='InvokeAI runtime root directory', category='Paths')
+    root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
    autoimport_dir : Path = Field(default='autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths')
    lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
    embedding_dir : Path = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', category='Paths')
@@ -411,7 +415,8 @@ class InvokeAIAppConfig(InvokeAISettings):
    outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
    from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
    use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
-    ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', category='Features')
+    model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')

    log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
    # note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
@@ -419,11 +424,9 @@ class InvokeAIAppConfig(InvokeAISettings):
    log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")

    version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
+    cached_root : Path = Field(default=None, description="internal use only", category="DEPRECATED")
    # fmt: on

-    class Config:
-        validate_assignment = True

    def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False):
        """
        Update settings with contents of init file, environment, and
@@ -469,12 +472,15 @@ class InvokeAIAppConfig(InvokeAISettings):
        """
        Path to the runtime root directory
        """
-        if self.root:
+        # we cache value of root to protect against it being '.' and the cwd changing
+        if self.cached_root:
+            root = self.cached_root
+        elif self.root:
            root = Path(self.root).expanduser().absolute()
        else:
-            root = self.find_root().expanduser().absolute()
-        self.root = root  # insulate ourselves from relative paths that may change
-        return root
+            root = self.find_root()
+        self.cached_root = root
+        return self.cached_root

    @property
    def root_dir(self) -> Path:

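A toy version of the caching idea on the "+" side above: resolve the root once and keep returning the same absolute path, so a root of "." cannot drift if the process later changes directory:

from pathlib import Path

class Config:
    def __init__(self, root: str = "."):
        self.root = root
        self.cached_root = None

    @property
    def root_path(self) -> Path:
        if self.cached_root is None:
            self.cached_root = Path(self.root).expanduser().absolute()
        return self.cached_root

cfg = Config(".")
assert cfg.root_path.is_absolute()
assert cfg.root_path is cfg.root_path  # same cached object on every access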
View File

@@ -289,10 +289,9 @@ class ImageService(ImageServiceABC):
    def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
        try:
            image_record = self._services.image_records.get(image_name)
-            metadata = self._services.image_records.get_metadata(image_name)

            if not image_record.session_id:
-                return ImageMetadata(metadata=metadata)
+                return ImageMetadata()

            session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
            graph = None
@@ -304,6 +303,7 @@ class ImageService(ImageServiceABC):
                self._services.logger.warn(f"Failed to parse session graph: {e}")
                graph = None

+            metadata = self._services.image_records.get_metadata(image_name)
            return ImageMetadata(graph=graph, metadata=metadata)
        except ImageRecordNotFoundException:
            self._services.logger.error("Image record not found")

View File

@@ -32,7 +32,6 @@ class InvocationServices:
    logger: "Logger"
    model_manager: "ModelManagerServiceBase"
    processor: "InvocationProcessorABC"
-    performance_statistics: "InvocationStatsServiceBase"
    queue: "InvocationQueueABC"

    def __init__(
@@ -48,7 +47,6 @@ class InvocationServices:
        logger: "Logger",
        model_manager: "ModelManagerServiceBase",
        processor: "InvocationProcessorABC",
-        performance_statistics: "InvocationStatsServiceBase",
        queue: "InvocationQueueABC",
    ):
        self.board_images = board_images
@@ -63,5 +61,4 @@ class InvocationServices:
        self.logger = logger
        self.model_manager = model_manager
        self.processor = processor
-        self.performance_statistics = performance_statistics
        self.queue = queue

View File

@@ -1,223 +0,0 @@
# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
"""Utility to collect execution time and GPU usage stats on invocations in flight"""
"""
Usage:
statistics = InvocationStatsService(graph_execution_manager)
with statistics.collect_stats(invocation, graph_execution_state.id):
... execute graphs...
statistics.log_stats()
Typical output:
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node Calls Seconds VRAM Used
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader 1 0.005s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip 1 0.004s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel 2 0.512s 0.26G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int 1 0.001s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size 1 0.001s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate 1 0.001s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise 1 0.002s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l 1 3.541s 1.93G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i 1 0.679s 0.58G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME: 4.749s
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G
The abstract base class for this class is InvocationStatsServiceBase. An implementing class which
writes to the system log is stored in InvocationServices.performance_statistics.
"""
import time
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager
from dataclasses import dataclass, field
from typing import Dict
import torch
import invokeai.backend.util.logging as logger
from ..invocations.baseinvocation import BaseInvocation
from .graph import GraphExecutionState
from .item_storage import ItemStorageABC
class InvocationStatsServiceBase(ABC):
"Abstract base class for recording node memory/time performance statistics"
@abstractmethod
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
"""
Initialize the InvocationStatsService and reset counters to zero
:param graph_execution_manager: Graph execution manager for this session
"""
pass
@abstractmethod
def collect_stats(
self,
invocation: BaseInvocation,
graph_execution_state_id: str,
) -> AbstractContextManager:
"""
Return a context object that will capture the statistics on the execution
of the invocation. Use a with: block around the part of the code that executes the invocation.
:param invocation: BaseInvocation object from the current graph.
:param graph_execution_state: GraphExecutionState object from the current session.
"""
pass
@abstractmethod
def reset_stats(self, graph_execution_state_id: str):
"""
Reset all statistics for the indicated graph
:param graph_execution_state_id
"""
pass
@abstractmethod
def reset_all_stats(self):
"""Zero all statistics"""
pass
@abstractmethod
def update_invocation_stats(
self,
graph_id: str,
invocation_type: str,
time_used: float,
vram_used: float,
):
"""
Add timing information on execution of a node. Usually
used internally.
:param graph_id: ID of the graph that is currently executing
:param invocation_type: String literal type of the node
:param time_used: Time used by node's execution (sec)
:param vram_used: Maximum VRAM used during execution (GB)
"""
pass
@abstractmethod
def log_stats(self):
"""
Write out the accumulated statistics to the log or somewhere else.
"""
pass
@dataclass
class NodeStats:
"""Class for tracking execution stats of an invocation node"""
calls: int = 0
time_used: float = 0.0 # seconds
max_vram: float = 0.0 # GB
@dataclass
class NodeLog:
"""Class for tracking node usage"""
# {node_type => NodeStats}
nodes: Dict[str, NodeStats] = field(default_factory=dict)
class InvocationStatsService(InvocationStatsServiceBase):
"""Accumulate performance information about a running graph. Collects time spent in each node,
as well as the maximum and current VRAM utilisation for CUDA systems"""
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
self.graph_execution_manager = graph_execution_manager
# {graph_id => NodeLog}
self._stats: Dict[str, NodeLog] = {}
class StatsContext:
def __init__(self, invocation: BaseInvocation, graph_id: str, collector: "InvocationStatsServiceBase"):
self.invocation = invocation
self.collector = collector
self.graph_id = graph_id
self.start_time = 0
def __enter__(self):
self.start_time = time.time()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
def __exit__(self, *args):
self.collector.update_invocation_stats(
self.graph_id,
self.invocation.type,
time.time() - self.start_time,
torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0,
)
def collect_stats(
self,
invocation: BaseInvocation,
graph_execution_state_id: str,
) -> StatsContext:
"""
Return a context object that will capture the statistics.
:param invocation: BaseInvocation object from the current graph.
:param graph_execution_state: GraphExecutionState object from the current session.
"""
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
self._stats[graph_execution_state_id] = NodeLog()
return self.StatsContext(invocation, graph_execution_state_id, self)
def reset_all_stats(self):
"""Zero all statistics"""
self._stats = {}
def reset_stats(self, graph_execution_id: str):
"""Zero the statistics for the indicated graph."""
try:
self._stats.pop(graph_execution_id)
except KeyError:
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used: float, vram_used: float):
"""
Add timing information on execution of a node. Usually
used internally.
:param graph_id: ID of the graph that is currently executing
:param invocation_type: String literal type of the node
:param time_used: Floating point seconds used by node's execution
"""
if not self._stats[graph_id].nodes.get(invocation_type):
self._stats[graph_id].nodes[invocation_type] = NodeStats()
stats = self._stats[graph_id].nodes[invocation_type]
stats.calls += 1
stats.time_used += time_used
stats.max_vram = max(stats.max_vram, vram_used)
def log_stats(self):
"""
Send the statistics to the system logger at the info level.
Stats will only be printed when the execution of the graph
is complete.
"""
completed = set()
for graph_id, node_log in self._stats.items():
current_graph_state = self.graph_execution_manager.get(graph_id)
if not current_graph_state.is_complete():
continue
total_time = 0
logger.info(f"Graph stats: {graph_id}")
logger.info("Node Calls Seconds VRAM Used")
for node_type, stats in self._stats[graph_id].nodes.items():
logger.info(f"{node_type:<20} {stats.calls:>5} {stats.time_used:7.3f}s {stats.max_vram:4.2f}G")
total_time += stats.time_used
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
if torch.cuda.is_available():
logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9))
completed.add(graph_id)
for graph_id in completed:
del self._stats[graph_id]

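The heart of the StatsContext above is a pair of CUDA calls: reset the peak-memory counter on entry, read it on exit. A minimal, CPU-safe sketch of that idiom:

import time
import torch

start = time.time()
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()

# ... run one node / one model call here ...

elapsed = time.time() - start
peak_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0
print(f"{elapsed:.3f}s, peak VRAM {peak_gb:.2f}G")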
View File

@@ -3,10 +3,9 @@
from __future__ import annotations

from abc import ABC, abstractmethod
-from logging import Logger
from pathlib import Path
from pydantic import Field
-from typing import Literal, Optional, Union, Callable, List, Tuple, TYPE_CHECKING
+from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING
from types import ModuleType

from invokeai.backend.model_management import (
@@ -194,7 +193,7 @@ class ModelManagerServiceBase(ABC):
        self,
        model_name: str,
        base_model: BaseModelType,
-        model_type: Literal[ModelType.Main, ModelType.Vae],
+        model_type: Union[ModelType.Main, ModelType.Vae],
    ) -> AddModelResult:
        """
        Convert a checkpoint file into a diffusers folder, deleting the cached
@@ -293,7 +292,7 @@ class ModelManagerService(ModelManagerServiceBase):
    def __init__(
        self,
        config: InvokeAIAppConfig,
-        logger: Logger,
+        logger: ModuleType,
    ):
        """
        Initialize with the path to the models.yaml config file.
@@ -397,7 +396,7 @@ class ModelManagerService(ModelManagerServiceBase):
            model_type,
        )

-    def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
+    def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
        """
        Given a model name returns a dict-like (OmegaConf) object describing it.
        """
@@ -417,7 +416,7 @@ class ModelManagerService(ModelManagerServiceBase):
        """
        return self.mgr.list_models(base_model, model_type)

-    def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
+    def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
        """
        Return information about the model using the same format as list_models()
        """
@@ -430,7 +429,7 @@ class ModelManagerService(ModelManagerServiceBase):
        model_type: ModelType,
        model_attributes: dict,
        clobber: bool = False,
-    ) -> AddModelResult:
+    ) -> None:
        """
        Update the named model with a dictionary of attributes. Will fail with an
        assertion error if the name already exists. Pass clobber=True to overwrite.
@@ -479,7 +478,7 @@ class ModelManagerService(ModelManagerServiceBase):
        self,
        model_name: str,
        base_model: BaseModelType,
-        model_type: Literal[ModelType.Main, ModelType.Vae],
+        model_type: Union[ModelType.Main, ModelType.Vae],
        convert_dest_directory: Optional[Path] = Field(
            default=None, description="Optional directory location for merged model"
        ),
@@ -574,9 +573,9 @@ class ModelManagerService(ModelManagerServiceBase):
            default=None, description="Base model shared by all models to be merged"
        ),
        merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
-        alpha: float = 0.5,
+        alpha: Optional[float] = 0.5,
        interp: Optional[MergeInterpolationMethod] = None,
-        force: bool = False,
+        force: Optional[bool] = False,
        merge_dest_directory: Optional[Path] = Field(
            default=None, description="Optional directory location for merged model"
        ),
@@ -634,8 +633,8 @@ class ModelManagerService(ModelManagerServiceBase):
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
-        new_name: Optional[str] = None,
-        new_base: Optional[BaseModelType] = None,
+        new_name: str = None,
+        new_base: BaseModelType = None,
    ):
        """
        Rename the indicated model. Can provide a new name and/or a new base.

View File

@@ -1,8 +0,0 @@
from pydantic import Field
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
class BoardImage(BaseModelExcludeNull):
board_id: str = Field(description="The id of the board")
image_name: str = Field(description="The name of the image")

View File

@@ -1,11 +1,10 @@
from typing import Optional, Union
from datetime import datetime

-from pydantic import Field
+from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
from invokeai.app.util.misc import get_iso_timestamp
-from invokeai.app.util.model_exclude_null import BaseModelExcludeNull

-class BoardRecord(BaseModelExcludeNull):
+class BoardRecord(BaseModel):
    """Deserialized board record."""

    board_id: str = Field(description="The unique ID of the board.")

View File

@@ -1,14 +1,13 @@
import datetime
from typing import Optional, Union

-from pydantic import Extra, Field, StrictBool, StrictStr
+from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr

from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.util.misc import get_iso_timestamp
-from invokeai.app.util.model_exclude_null import BaseModelExcludeNull

-class ImageRecord(BaseModelExcludeNull):
+class ImageRecord(BaseModel):
    """Deserialized image record without metadata."""

    image_name: str = Field(description="The unique name of the image.")
@@ -41,7 +40,7 @@ class ImageRecord(BaseModelExcludeNull):
    """The node ID that generated this image, if it is a generated image."""

-class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
+class ImageRecordChanges(BaseModel, extra=Extra.forbid):
    """A set of changes to apply to an image record.

    Only limited changes are valid:
@@ -61,7 +60,7 @@ class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
    """The image's new `is_intermediate` flag."""

-class ImageUrlsDTO(BaseModelExcludeNull):
+class ImageUrlsDTO(BaseModel):
    """The URLs for an image and its thumbnail."""

    image_name: str = Field(description="The unique name of the image.")
@@ -77,15 +76,11 @@ class ImageDTO(ImageRecord, ImageUrlsDTO):
    board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.")
    """The id of the board the image belongs to, if one exists."""

    pass

def image_record_to_dto(
-    image_record: ImageRecord,
-    image_url: str,
-    thumbnail_url: str,
-    board_id: Optional[str],
+    image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Optional[str]
) -> ImageDTO:
    """Converts an image record to an image DTO."""
    return ImageDTO(
View File

@@ -1,14 +1,13 @@
import time
import traceback
-from threading import BoundedSemaphore, Event, Thread
+from threading import Event, Thread, BoundedSemaphore

-import invokeai.backend.util.logging as logger

from ..invocations.baseinvocation import InvocationContext
-from ..models.exceptions import CanceledException
from .invocation_queue import InvocationQueueItem
-from .invocation_stats import InvocationStatsServiceBase
from .invoker import InvocationProcessorABC, Invoker
+from ..models.exceptions import CanceledException
+import invokeai.backend.util.logging as logger

class DefaultInvocationProcessor(InvocationProcessorABC):
@@ -36,8 +35,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
    def __process(self, stop_event: Event):
        try:
            self.__threadLimit.acquire()
-            statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
            while not stop_event.is_set():
                try:
                    queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
@@ -86,7 +83,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                    # Invoke
                    try:
-                        with statistics.collect_stats(invocation, graph_execution_state.id):
                        outputs = invocation.invoke(
                            InvocationContext(
                                services=self.__invoker.services,
@@ -111,13 +107,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                            source_node_id=source_node_id,
                            result=outputs.dict(),
                        )
-                        statistics.log_stats()

                    except KeyboardInterrupt:
                        pass

                    except CanceledException:
-                        statistics.reset_stats(graph_execution_state.id)
                        pass

                    except Exception as e:
@@ -139,7 +133,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                            error_type=e.__class__.__name__,
                            error=error,
                        )
-                        statistics.reset_stats(graph_execution_state.id)
                        pass

                    # Check queue to see if this is canceled, and skip if so

View File

@@ -20,6 +20,6 @@ class LocalUrlService(UrlServiceBase):
        # These paths are determined by the routes in invokeai/app/api/routers/images.py
        if thumbnail:
-            return f"{self._base_url}/images/i/{image_basename}/thumbnail"
-        return f"{self._base_url}/images/i/{image_basename}/full"
+            return f"{self._base_url}/images/{image_basename}/thumbnail"
+        return f"{self._base_url}/images/{image_basename}/full"

View File

@@ -18,5 +18,5 @@ SEED_MAX = np.iinfo(np.uint32).max

def get_random_seed():
-    rng = np.random.default_rng(seed=None)
+    rng = np.random.default_rng(seed=0)
    return int(rng.integers(0, SEED_MAX))

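The one-token difference above decides whether seeds are random at all: default_rng(seed=None) pulls fresh OS entropy on each call, while default_rng(seed=0) hands back the identical "random" seed every time:

import numpy as np

SEED_MAX = np.iinfo(np.uint32).max
a = int(np.random.default_rng(seed=0).integers(0, SEED_MAX))
b = int(np.random.default_rng(seed=0).integers(0, SEED_MAX))
assert a == b  # deterministic with a fixed seed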
View File

@@ -1,23 +0,0 @@
from typing import Any
from pydantic import BaseModel
"""
We want to exclude null values from objects that make their way to the client.
Unfortunately there is no built-in way to do this in pydantic, so we override the default
dict method to exclude them.
From https://github.com/tiangolo/fastapi/discussions/8882#discussioncomment-5154541
"""
class BaseModelExcludeNull(BaseModel):
def dict(self, *args, **kwargs) -> dict[str, Any]:
"""
Override the default dict method to exclude None values in the response
"""
kwargs.pop("exclude_none", None)
return super().dict(*args, exclude_none=True, **kwargs)
pass

View File

@@ -1,11 +1,25 @@
"""
invokeai.backend.generator.img2img descends from .generator
"""
+from typing import Optional
+
+import torch
+from accelerate.utils import set_seed
+from diffusers import logging
+
+from ..stable_diffusion import (
+    ConditioningData,
+    PostprocessingSettings,
+    StableDiffusionGeneratorPipeline,
+)
from .base import Generator

class Img2Img(Generator):
+    def __init__(self, model, precision):
+        super().__init__(model, precision)
+        self.init_latent = None  # by get_noise()
+
    def get_make_image(
        self,
        sampler,
@@ -28,4 +42,51 @@ class Img2Img(Generator):
        Returns a function returning an image derived from the prompt and the initial image
        Return value depends on the seed at the time you call it.
        """
-        raise NotImplementedError("replaced by invokeai.app.invocations.latent.LatentsToLatentsInvocation")
+        self.perlin = perlin
# noinspection PyTypeChecker
pipeline: StableDiffusionGeneratorPipeline = self.model
pipeline.scheduler = sampler
uc, c, extra_conditioning_info = conditioning
conditioning_data = ConditioningData(
uc,
c,
cfg_scale,
extra_conditioning_info,
postprocessing_settings=PostprocessingSettings(
threshold=threshold,
warmup=warmup,
h_symmetry_time_pct=h_symmetry_time_pct,
v_symmetry_time_pct=v_symmetry_time_pct,
),
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
def make_image(x_T: torch.Tensor, seed: int):
# FIXME: use x_T for initial seeded noise
# We're not at the moment because the pipeline automatically resizes init_image if
# necessary, which the x_T input might not match.
# In the meantime, reset the seed prior to generating pipeline output so we at least get the same result.
logging.set_verbosity_error() # quench safety check warnings
pipeline_output = pipeline.img2img_from_embeddings(
init_image,
strength,
steps,
conditioning_data,
noise_func=self.get_noise_like,
callback=step_callback,
seed=seed,
)
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
attention_maps_callback(pipeline_output.attention_map_saver)
return pipeline.numpy_to_pil(pipeline_output.images)[0]
return make_image
def get_noise_like(self, like: torch.Tensor):
device = like.device
x = torch.randn_like(like, device=device)
if self.perlin > 0.0:
shape = like.shape
x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(shape[3], shape[2])
return x

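get_noise_like above linearly blends gaussian noise with structured perlin noise. A self-contained sketch, with plain random noise standing in for the real get_perlin_noise:

import torch

def blend_noise(like: torch.Tensor, perlin: float) -> torch.Tensor:
    x = torch.randn_like(like)
    if perlin > 0.0:
        structured = torch.randn_like(like)  # stand-in for perlin noise of shape (H, W)
        x = (1 - perlin) * x + perlin * structured
    return x

out = blend_noise(torch.zeros(1, 4, 8, 8), perlin=0.25)
assert out.shape == (1, 4, 8, 8)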
View File

@@ -377,11 +377,3 @@ class Inpaint(Img2Img):
        )
        return corrected_result
def get_noise_like(self, like: torch.Tensor):
device = like.device
x = torch.randn_like(like, device=device)
if self.perlin > 0.0:
shape = like.shape
x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(shape[3], shape[2])
return x

View File

@@ -12,7 +12,6 @@ def check_invokeai_root(config: InvokeAIAppConfig):
        assert config.model_conf_path.exists(), f"{config.model_conf_path} not found"
        assert config.db_path.parent.exists(), f"{config.db_path.parent} not found"
        assert config.models_path.exists(), f"{config.models_path} not found"
-        if not config.ignore_missing_core_models:
        for model in [
            "CLIP-ViT-bigG-14-laion2B-39B-b160k",
            "bert-base-uncased",
@@ -33,10 +32,5 @@ def check_invokeai_root(config: InvokeAIAppConfig):
        print(
            '** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **'
        )
-        print(
-            '** (To skip this check completely, add "--ignore_missing_core_models" to your CLI args. Not installing '
-            "these core models will prevent the loading of some or all .safetensors and .ckpt files. However, you can "
-            "always come back and install these core models in the future.)"
-        )
        input("Press any key to continue...")
        sys.exit(0)

View File

@@ -10,17 +10,15 @@ import sys
import argparse
import io
import os
-import psutil
import shutil
import textwrap
-import torch
import traceback
import yaml
import warnings
from argparse import Namespace
-from enum import Enum
from pathlib import Path
from shutil import get_terminal_size
+from typing import get_type_hints
from urllib import request

import npyscreen
@@ -46,8 +44,6 @@ from invokeai.app.services.config import (
)
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute

-# TO DO - Move all the frontend code into invokeai.frontend.install
from invokeai.frontend.install.widgets import (
    SingleSelectColumns,
    CenteredButtonPress,
@@ -57,7 +53,6 @@ from invokeai.frontend.install.widgets import (
    CyclingForm,
    MIN_COLS,
    MIN_LINES,
-    WindowTooSmallException,
)
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
from invokeai.backend.install.model_install_backend import (
@@ -66,7 +61,6 @@ from invokeai.backend.install.model_install_backend import (
    ModelInstall,
)
from invokeai.backend.model_management.model_probe import ModelType, BaseModelType
-from pydantic.error_wrappers import ValidationError

warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
@@ -82,13 +76,6 @@ Default_config_file = config.model_conf_path
SD_Configs = config.legacy_conf_path

PRECISION_CHOICES = ["auto", "float16", "float32"]
-GB = 1073741824  # GB in bytes
-HAS_CUDA = torch.cuda.is_available()
-_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0, 0)
-MAX_VRAM /= GB
-MAX_RAM = psutil.virtual_memory().total / GB

INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# This is the InvokeAI initialization file, which contains command-line default values.
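The removed constants above probe the host's memory budget. A standalone sketch of the same probing (requires psutil and torch; reports total RAM and total VRAM in GiB):

import psutil
import torch

GB = 1073741824  # bytes per GiB
max_ram = psutil.virtual_memory().total / GB
if torch.cuda.is_available():
    _, total_vram_bytes = torch.cuda.mem_get_info()  # returns (free, total)
    max_vram = total_vram_bytes / GB
else:
    max_vram = 0.0
print(f"RAM: {max_ram:.1f}G  VRAM: {max_vram:.1f}G")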
@@ -99,12 +86,6 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
logger = InvokeAILogger.getLogger()

-class DummyWidgetValue(Enum):
-    zero = 0
-    true = True
-    false = False

# --------------------------------------------
def postscript(errors: None):
    if not any(errors):
@@ -395,47 +376,15 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
            max_width=80,
            scroll_exit=True,
        )
-        self.nextrely += 1
-        self.add_widget_intelligent(
-            npyscreen.TitleFixedText,
-            name="RAM cache size (GB). Make this at least large enough to hold a single full model.",
-            begin_entry_at=0,
-            editable=False,
-            color="CONTROL",
-            scroll_exit=True,
-        )
-        self.nextrely -= 1
        self.max_cache_size = self.add_widget_intelligent(
-            npyscreen.Slider,
-            value=clip(old_opts.max_cache_size, range=(3.0, MAX_RAM), step=0.5),
-            out_of=round(MAX_RAM),
-            lowest=0.0,
-            step=0.5,
-            relx=8,
+            IntTitleSlider,
+            name="Size of the RAM cache used for fast model switching (GB)",
+            value=old_opts.max_cache_size,
+            out_of=20,
+            lowest=3,
+            begin_entry_at=6,
            scroll_exit=True,
        )
-        if HAS_CUDA:
-            self.nextrely += 1
-            self.add_widget_intelligent(
-                npyscreen.TitleFixedText,
-                name="VRAM cache size (GB). Reserving a small amount of VRAM will modestly speed up the start of image generation.",
-                begin_entry_at=0,
-                editable=False,
-                color="CONTROL",
-                scroll_exit=True,
-            )
-            self.nextrely -= 1
-            self.max_vram_cache_size = self.add_widget_intelligent(
-                npyscreen.Slider,
-                value=clip(old_opts.max_vram_cache_size, range=(0, MAX_VRAM), step=0.25),
-                out_of=round(MAX_VRAM * 2) / 2,
-                lowest=0.0,
-                relx=8,
-                step=0.25,
-                scroll_exit=True,
-            )
-        else:
-            self.max_vram_cache_size = DummyWidgetValue.zero
        self.nextrely += 1
        self.outdir = self.add_widget_intelligent(
            FileBox,
@@ -452,7 +401,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
        self.autoimport_dirs = {}
        self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent(
            FileBox,
-            name="Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
+            name=f"Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
            value=str(config.root_path / config.autoimport_dir),
            select_dir=True,
            must_exist=False,
@@ -527,7 +476,6 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
            "outdir",
            "free_gpu_mem",
            "max_cache_size",
-            "max_vram_cache_size",
            "xformers_enabled",
            "always_use_cpu",
        ]:
@@ -605,16 +553,6 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
        )

-# -------------------------------------
-def clip(value: float, range: tuple[float, float], step: float) -> float:
-    minimum, maximum = range
-    if value < minimum:
-        value = minimum
-    if value > maximum:
-        value = maximum
-    return round(value / step) * step

# -------------------------------------
def initialize_rootdir(root: Path, yes_to_all: bool = False):
    logger.info("Initializing InvokeAI runtime directory")
@@ -654,13 +592,13 @@ def maybe_create_models_yaml(root: Path):

# -------------------------------------
def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace):
-    # parse_args() will read from init file if present
    invokeai_opts = default_startup_options(initfile)
    invokeai_opts.root = program_opts.root

-    if not set_min_terminal_size(MIN_COLS, MIN_LINES):
-        raise WindowTooSmallException(
-            "Could not increase terminal size. Try running again with a larger window or smaller font size."
-        )
+    # The third argument is needed in the Windows 11 environment to
+    # launch a console window running this program.
+    set_min_terminal_size(MIN_COLS, MIN_LINES)
# the install-models application spawns a subprocess to install # the install-models application spawns a subprocess to install
# models, and will crash unless this is set before running. # models, and will crash unless this is set before running.
@@ -716,13 +654,10 @@ def migrate_init_file(legacy_format: Path):
old = legacy_parser.parse_args([f"@{str(legacy_format)}"]) old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
new = InvokeAIAppConfig.get_config() new = InvokeAIAppConfig.get_config()
fields = [x for x, y in InvokeAIAppConfig.__fields__.items() if y.field_info.extra.get("category") != "DEPRECATED"] fields = list(get_type_hints(InvokeAIAppConfig).keys())
for attr in fields: for attr in fields:
if hasattr(old, attr): if hasattr(old, attr):
try:
setattr(new, attr, getattr(old, attr)) setattr(new, attr, getattr(old, attr))
except ValidationError as e:
print(f"* Ignoring incompatible value for field {attr}:\n {str(e)}")
# a few places where the field names have changed and we have to # a few places where the field names have changed and we have to
# manually add in the new names/values # manually add in the new names/values
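The try/except on the "-" side makes the migration tolerant: a legacy value the new schema rejects is reported and skipped instead of aborting the whole loop. A runnable sketch with stand-in models (pydantic v1 style, matching this codebase):

from pydantic import BaseModel, ValidationError, conint

class Old(BaseModel):
    steps: int = -1  # legacy config allowed this

class New(BaseModel):
    steps: conint(gt=0) = 30

    class Config:
        validate_assignment = True

old, new = Old(), New()
for attr in new.__fields__:
    if hasattr(old, attr):
        try:
            setattr(new, attr, getattr(old, attr))
        except ValidationError as e:
            print(f"* Ignoring incompatible value for field {attr}:\n  {str(e)}")
assert new.steps == 30  # the bad legacy value was skipped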
@@ -842,7 +777,6 @@ def main():
        models_to_download = default_user_selections(opt)
        new_init_file = config.root_path / "invokeai.yaml"

        if opt.yes_to_all:
            write_default_options(opt, new_init_file)
            init_options = Namespace(precision="float32" if opt.full_precision else "float16")
@@ -868,8 +802,6 @@ def main():
        postscript(errors=errors)
        if not opt.yes_to_all:
            input("Press any key to continue...")
-    except WindowTooSmallException as e:
-        logger.error(str(e))
    except KeyboardInterrupt:
        print("\nGoodbye! Come back soon.")

View File

@@ -591,6 +591,7 @@ script, which will perform a full upgrade in place.""",
# TODO: revisit - don't rely on invokeai.yaml to exist yet! # TODO: revisit - don't rely on invokeai.yaml to exist yet!
dest_is_setup = (dest_root / "models/core").exists() and (dest_root / "databases").exists() dest_is_setup = (dest_root / "models/core").exists() and (dest_root / "databases").exists()
if not dest_is_setup: if not dest_is_setup:
import invokeai.frontend.install.invokeai_configure
from invokeai.backend.install.invokeai_configure import initialize_rootdir from invokeai.backend.install.invokeai_configure import initialize_rootdir
initialize_rootdir(dest_root, True) initialize_rootdir(dest_root, True)

View File

@@ -13,7 +13,6 @@ import requests
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from diffusers import logging as dlogging from diffusers import logging as dlogging
import onnx import onnx
import torch
from huggingface_hub import hf_hub_url, HfFolder, HfApi from huggingface_hub import hf_hub_url, HfFolder, HfApi
from omegaconf import OmegaConf from omegaconf import OmegaConf
from tqdm import tqdm from tqdm import tqdm
@@ -24,7 +23,6 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
from invokeai.backend.util import download_with_resume from invokeai.backend.util import download_with_resume
from invokeai.backend.util.devices import torch_dtype, choose_torch_device
from ..util.logging import InvokeAILogger from ..util.logging import InvokeAILogger
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
@@ -101,9 +99,9 @@ class ModelInstall(object):
def __init__( def __init__(
self, self,
config: InvokeAIAppConfig, config: InvokeAIAppConfig,
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
model_manager: Optional[ModelManager] = None, model_manager: ModelManager = None,
access_token: Optional[str] = None, access_token: str = None,
): ):
self.config = config self.config = config
self.mgr = model_manager or ModelManager(config.model_conf_path) self.mgr = model_manager or ModelManager(config.model_conf_path)
@@ -305,7 +303,7 @@ class ModelInstall(object):
with TemporaryDirectory(dir=self.config.models_path) as staging: with TemporaryDirectory(dir=self.config.models_path) as staging:
staging = Path(staging) staging = Path(staging)
if "model_index.json" in files: if "model_index.json" in files and "unet/model.onnx" not in files:
location = self._download_hf_pipeline(repo_id, staging) # pipeline location = self._download_hf_pipeline(repo_id, staging) # pipeline
elif "unet/model.onnx" in files: elif "unet/model.onnx" in files:
location = self._download_hf_model(repo_id, files, staging) location = self._download_hf_model(repo_id, files, staging)
@@ -418,25 +416,15 @@ class ModelInstall(object):
does a save_pretrained() to the indicated staging area. does a save_pretrained() to the indicated staging area.
""" """
_, name = repo_id.split("/") _, name = repo_id.split("/")
precision = torch_dtype(choose_torch_device()) revisions = ["fp16", "main"] if self.config.precision == "float16" else ["main"]
variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"]
model = None model = None
for variant in variants: for revision in revisions:
try: try:
model = DiffusionPipeline.from_pretrained( model = DiffusionPipeline.from_pretrained(repo_id, revision=revision, safety_checker=None)
repo_id, except: # most errors are due to fp16 not being present. Fix this to catch other errors
variant=variant, pass
torch_dtype=precision,
safety_checker=None,
)
except Exception as e: # most errors are due to fp16 not being present. Fix this to catch other errors
if "fp16" not in str(e):
print(e)
if model: if model:
break break
if not model: if not model:
logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.") logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.")
return None return None
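# --- Editor's sketch (not part of this diff): the variant-fallback pattern
# in _download_hf_pipeline() above. Prefer the fp16 variant when running at
# half precision, fall back to the full-precision weights, and only surface
# errors unrelated to a missing fp16 variant. Requires the diffusers
# package; the repo id in the usage comment is just an example.
import torch
from diffusers import DiffusionPipeline

def download_pipeline(repo_id: str, precision: torch.dtype):
    variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"]
    for variant in variants:
        try:
            return DiffusionPipeline.from_pretrained(
                repo_id, variant=variant, torch_dtype=precision, safety_checker=None
            )
        except Exception as e:
            if "fp16" not in str(e):  # something other than a missing variant
                print(e)
    return None

# pipe = download_pipeline("runwayml/stable-diffusion-v1-5", torch.float16)
# --- End editor's sketch -------------------------------------------------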

View File

@@ -13,4 +13,3 @@ from .models import (
DuplicateModelException, DuplicateModelException,
) )
from .model_merge import ModelMerger, MergeInterpolationMethod from .model_merge import ModelMerger, MergeInterpolationMethod
from .lora import ModelPatcher

View File

@@ -20,6 +20,424 @@ from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer from transformers import CLIPTextModel, CLIPTokenizer
# TODO: rename and split this file
class LoRALayerBase:
# rank: Optional[int]
# alpha: Optional[float]
# bias: Optional[torch.Tensor]
# layer_key: str
# @property
# def scale(self):
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
def __init__(
self,
layer_key: str,
values: dict,
):
if "alpha" in values:
self.alpha = values["alpha"].item()
else:
self.alpha = None
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
self.bias = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
tuple(values["bias_size"]),
)
else:
self.bias = None
self.rank = None # set in layer implementation
self.layer_key = layer_key
def forward(
self,
module: torch.nn.Module,
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
multiplier: float,
):
if type(module) == torch.nn.Conv2d:
op = torch.nn.functional.conv2d
extra_args = dict(
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)
else:
op = torch.nn.functional.linear
extra_args = {}
weight = self.get_weight()
bias = self.bias if self.bias is not None else 0
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
return (
op(
*input_h,
(weight + bias).view(module.weight.shape),
None,
**extra_args,
)
* multiplier
* scale
)
def get_weight(self):
raise NotImplementedError()
def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
# up: torch.Tensor
# mid: Optional[torch.Tensor]
# down: torch.Tensor
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
if "lora_mid.weight" in values:
self.mid = values["lora_mid.weight"]
else:
self.mid = None
self.rank = self.down.shape[0]
def get_weight(self):
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.up, self.mid, self.down]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
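# --- Editor's sketch (not part of this diff): LoRA weight reconstruction.
# A LoRALayer stores a low-rank factorization of the weight delta; the
# effective update is up @ down, scaled by alpha / rank as in get_weight()
# and forward() above. Shapes below are arbitrary examples.
import torch

rank, n_in, n_out = 4, 320, 320
down = torch.randn(rank, n_in)           # values["lora_down.weight"]
up = torch.randn(n_out, rank)            # values["lora_up.weight"]
alpha = 4.0
delta_w = (up @ down) * (alpha / rank)   # same shape as the base weight
assert delta_w.shape == (n_out, n_in)
# --- End editor's sketch -------------------------------------------------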
class LoHALayer(LoRALayerBase):
# w1_a: torch.Tensor
# w1_b: torch.Tensor
# w2_a: torch.Tensor
# w2_b: torch.Tensor
# t1: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.w1_a = values["hada_w1_a"]
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]
if "hada_t1" in values:
self.t1 = values["hada_t1"]
else:
self.t1 = None
if "hada_t2" in values:
self.t2 = values["hada_t2"]
else:
self.t2 = None
self.rank = self.w1_b.shape[0]
def get_weight(self):
if self.t1 is None:
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else:
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
weight = rebuild1 * rebuild2
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
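# --- Editor's sketch (not part of this diff): LoHA reconstruction.
# A LoHA layer composes two low-rank products with an elementwise
# (Hadamard) product, matching the "t1 is None" branch of get_weight()
# above. Shapes are arbitrary examples.
import torch

rank, n_in, n_out = 4, 64, 64
w1_a, w1_b = torch.randn(n_out, rank), torch.randn(rank, n_in)
w2_a, w2_b = torch.randn(n_out, rank), torch.randn(rank, n_in)
delta_w = (w1_a @ w1_b) * (w2_a @ w2_b)  # Hadamard of two rank-4 products
assert delta_w.shape == (n_out, n_in)
# --- End editor's sketch -------------------------------------------------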
class LoKRLayer(LoRALayerBase):
# w1: Optional[torch.Tensor] = None
# w1_a: Optional[torch.Tensor] = None
# w1_b: Optional[torch.Tensor] = None
# w2: Optional[torch.Tensor] = None
# w2_a: Optional[torch.Tensor] = None
# w2_b: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
if "lokr_w1" in values:
self.w1 = values["lokr_w1"]
self.w1_a = None
self.w1_b = None
else:
self.w1 = None
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]
if "lokr_w2" in values:
self.w2 = values["lokr_w2"]
self.w2_a = None
self.w2_b = None
else:
self.w2 = None
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]
if "lokr_t2" in values:
self.t2 = values["lokr_t2"]
else:
self.t2 = None
if "lokr_w1_b" in values:
self.rank = values["lokr_w1_b"].shape[0]
elif "lokr_w2_b" in values:
self.rank = values["lokr_w2_b"].shape[0]
else:
self.rank = None # unscaled
def get_weight(self):
w1 = self.w1
if w1 is None:
w1 = self.w1_a @ self.w1_b
w2 = self.w2
if w2 is None:
if self.t2 is None:
w2 = self.w2_a @ self.w2_b
else:
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
weight = torch.kron(w1, w2)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
if self.w1 is not None:
self.w1 = self.w1.to(device=device, dtype=dtype)
else:
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
else:
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
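# --- Editor's sketch (not part of this diff): LoKR reconstruction.
# A LoKR layer builds the weight delta as a Kronecker product of two small
# factors, so compact w1 and w2 expand into a much larger effective weight,
# matching the torch.kron() call in get_weight() above.
import torch

w1 = torch.randn(4, 4)          # values["lokr_w1"] (or w1_a @ w1_b)
w2 = torch.randn(80, 80)        # values["lokr_w2"] (or w2_a @ w2_b)
delta_w = torch.kron(w1, w2)    # (4*80, 4*80) == (320, 320)
assert delta_w.shape == (320, 320)
# --- End editor's sketch -------------------------------------------------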
class LoRAModel: # (torch.nn.Module):
_name: str
layers: Dict[str, LoRALayer]
_device: torch.device
_dtype: torch.dtype
def __init__(
self,
name: str,
layers: Dict[str, LoRALayer],
device: torch.device,
dtype: torch.dtype,
):
self._name = name
self._device = device or torch.device("cpu")
self._dtype = dtype or torch.float32
self.layers = layers
@property
def name(self):
return self._name
@property
def device(self):
return self._device
@property
def dtype(self):
return self._dtype
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> LoRAModel:
# TODO: try revert if exception?
for key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
self._device = device
self._dtype = dtype
def calc_size(self) -> int:
model_size = 0
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or torch.device("cpu")
dtype = dtype or torch.float32
if isinstance(file_path, str):
file_path = Path(file_path)
model = cls(
device=device,
dtype=dtype,
name=file_path.stem, # TODO:
layers=dict(),
)
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(state_dict)
for layer_key, values in state_dict.items():
# lora and locon
if "lora_down.weight" in values:
layer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_b" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1_b" in values or "lokr_w1" in values:
layer = LoKRLayer(layer_key, values)
else:
# TODO: diff/ia3/... format
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}")
return
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype)
model.layers[layer_key] = layer
return model
@staticmethod
def _group_state(state_dict: dict):
state_dict_groupped = dict()
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = dict()
state_dict_groupped[stem][leaf] = value
return state_dict_groupped
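# --- Editor's sketch (not part of this diff): grouping a flat state dict.
# _group_state() above splits each "layer.leaf" key once, collecting all
# tensors belonging to a layer under one sub-dict so the layer type can be
# probed by key names. Toy values stand in for tensors.
flat = {
    "lora_unet_a.lora_down.weight": 1,
    "lora_unet_a.lora_up.weight": 2,
    "lora_unet_a.alpha": 3,
}
grouped = {}
for key, value in flat.items():
    stem, leaf = key.split(".", 1)
    grouped.setdefault(stem, {})[leaf] = value
assert grouped["lora_unet_a"]["lora_down.weight"] == 1
# --- End editor's sketch -------------------------------------------------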
""" """
loras = [ loras = [
(lora_model1, 0.7), (lora_model1, 0.7),
@@ -98,26 +516,6 @@ class ModelPatcher:
with cls.apply_lora(text_encoder, loras, "lora_te_"): with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield yield
@classmethod
@contextmanager
def apply_sdxl_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModel, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
yield
@classmethod
@contextmanager
def apply_sdxl_lora_text_encoder2(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModel, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
yield
@classmethod @classmethod
@contextmanager @contextmanager
def apply_lora( def apply_lora(
@@ -143,7 +541,7 @@ class ModelPatcher:
# with torch.autocast(device_type="cpu"): # with torch.autocast(device_type="cpu"):
layer.to(dtype=torch.float32) layer.to(dtype=torch.float32)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
layer_weight = layer.get_weight(original_weights[module_key]) * lora_weight * layer_scale layer_weight = layer.get_weight() * lora_weight * layer_scale
if module.weight.shape != layer_weight.shape: if module.weight.shape != layer_weight.shape:
# TODO: debug on lycoris # TODO: debug on lycoris
@@ -164,7 +562,7 @@ class ModelPatcher:
cls, cls,
tokenizer: CLIPTokenizer, tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel, text_encoder: CLIPTextModel,
ti_list: List[Tuple[str, Any]], ti_list: List[Any],
) -> Tuple[CLIPTokenizer, TextualInversionManager]: ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
init_tokens_count = None init_tokens_count = None
new_tokens_added = None new_tokens_added = None
@@ -174,27 +572,27 @@ class ModelPatcher:
ti_manager = TextualInversionManager(ti_tokenizer) ti_manager = TextualInversionManager(ti_tokenizer)
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
def _get_trigger(ti_name, index): def _get_trigger(ti, index):
trigger = ti_name trigger = ti.name
if index > 0: if index > 0:
trigger += f"-!pad-{i}" trigger += f"-!pad-{i}"
return f"<{trigger}>" return f"<{trigger}>"
# modify tokenizer # modify tokenizer
new_tokens_added = 0 new_tokens_added = 0
for ti_name, ti in ti_list: for ti in ti_list:
for i in range(ti.embedding.shape[0]): for i in range(ti.embedding.shape[0]):
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i)) new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
# modify text_encoder # modify text_encoder
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added) text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
model_embeddings = text_encoder.get_input_embeddings() model_embeddings = text_encoder.get_input_embeddings()
for ti_name, ti in ti_list: for ti in ti_list:
ti_tokens = [] ti_tokens = []
for i in range(ti.embedding.shape[0]): for i in range(ti.embedding.shape[0]):
embedding = ti.embedding[i] embedding = ti.embedding[i]
trigger = _get_trigger(ti_name, i) trigger = _get_trigger(ti, i)
token_id = ti_tokenizer.convert_tokens_to_ids(trigger) token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
if token_id == ti_tokenizer.unk_token_id: if token_id == ti_tokenizer.unk_token_id:
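# --- Editor's sketch (not part of this diff): trigger-token naming.
# Each textual inversion with an N-vector embedding registers N tokens:
# "<name>" for the first vector and "<name-!pad-K>" for the rest, which is
# what _get_trigger() above produces for the tokenizer.
def triggers(ti_name: str, n_vectors: int) -> list[str]:
    out = []
    for index in range(n_vectors):
        trigger = ti_name if index == 0 else f"{ti_name}-!pad-{index}"
        out.append(f"<{trigger}>")
    return out

assert triggers("easynegative", 3) == [
    "<easynegative>", "<easynegative-!pad-1>", "<easynegative-!pad-2>"
]
# --- End editor's sketch -------------------------------------------------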
@@ -239,6 +637,7 @@ class ModelPatcher:
class TextualInversionModel: class TextualInversionModel:
name: str
embedding: torch.Tensor # [n, 768]|[n, 1280] embedding: torch.Tensor # [n, 768]|[n, 1280]
@classmethod @classmethod
@@ -252,6 +651,10 @@ class TextualInversionModel:
file_path = Path(file_path) file_path = Path(file_path)
result = cls() # TODO: result = cls() # TODO:
if file_path.name == "learned_embeds.bin":
result.name = file_path.parent.name
else:
result.name = file_path.stem
if file_path.suffix == ".safetensors": if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu") state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
@@ -361,8 +764,7 @@ class ONNXModelPatcher:
layer.to(dtype=torch.float32) layer.to(dtype=torch.float32)
layer_key = layer_key.replace(prefix, "") layer_key = layer_key.replace(prefix, "")
# TODO: rewrite to pass original tensor weight (required by ia3) layer_weight = layer.get_weight().detach().cpu().numpy() * lora_weight
layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight
if layer_key in blended_loras: if layer_key in blended_loras:
blended_loras[layer_key] += layer_weight blended_loras[layer_key] += layer_weight
else: else:
@@ -429,7 +831,7 @@ class ONNXModelPatcher:
cls, cls,
tokenizer: CLIPTokenizer, tokenizer: CLIPTokenizer,
text_encoder: IAIOnnxRuntimeModel, text_encoder: IAIOnnxRuntimeModel,
ti_list: List[Tuple[str, Any]], ti_list: List[Any],
) -> Tuple[CLIPTokenizer, TextualInversionManager]: ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
from .models.base import IAIOnnxRuntimeModel from .models.base import IAIOnnxRuntimeModel
@@ -442,17 +844,17 @@ class ONNXModelPatcher:
ti_tokenizer = copy.deepcopy(tokenizer) ti_tokenizer = copy.deepcopy(tokenizer)
ti_manager = TextualInversionManager(ti_tokenizer) ti_manager = TextualInversionManager(ti_tokenizer)
def _get_trigger(ti_name, index): def _get_trigger(ti, index):
trigger = ti_name trigger = ti.name
if index > 0: if index > 0:
trigger += f"-!pad-{i}" trigger += f"-!pad-{i}"
return f"<{trigger}>" return f"<{trigger}>"
# modify tokenizer # modify tokenizer
new_tokens_added = 0 new_tokens_added = 0
for ti_name, ti in ti_list: for ti in ti_list:
for i in range(ti.embedding.shape[0]): for i in range(ti.embedding.shape[0]):
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i)) new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
# modify text_encoder # modify text_encoder
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"] orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
@@ -462,11 +864,11 @@ class ONNXModelPatcher:
axis=0, axis=0,
) )
for ti_name, ti in ti_list: for ti in ti_list:
ti_tokens = [] ti_tokens = []
for i in range(ti.embedding.shape[0]): for i in range(ti.embedding.shape[0]):
embedding = ti.embedding[i].detach().numpy() embedding = ti.embedding[i].detach().numpy()
trigger = _get_trigger(ti_name, i) trigger = _get_trigger(ti, i)
token_id = ti_tokenizer.convert_tokens_to_ids(trigger) token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
if token_id == ti_tokenizer.unk_token_id: if token_id == ti_tokenizer.unk_token_id:

View File

@@ -28,6 +28,8 @@ import torch
import logging import logging
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
from .lora import LoRAModel, TextualInversionModel
from .models import BaseModelType, ModelType, SubModelType, ModelBase from .models import BaseModelType, ModelType, SubModelType, ModelBase
# Maximum size of the cache, in gigs # Maximum size of the cache, in gigs

View File

@@ -228,19 +228,19 @@ the root is the InvokeAI ROOTDIR.
""" """
from __future__ import annotations from __future__ import annotations
import hashlib
import os import os
import hashlib
import textwrap import textwrap
import types import yaml
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
from shutil import rmtree, move from shutil import rmtree, move
from typing import Optional, List, Literal, Tuple, Union, Dict, Set, Callable
import torch import torch
import yaml
from omegaconf import OmegaConf from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig from omegaconf.dictconfig import DictConfig
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
@@ -259,7 +259,6 @@ from .models import (
ModelNotFoundException, ModelNotFoundException,
InvalidModelException, InvalidModelException,
DuplicateModelException, DuplicateModelException,
ModelBase,
) )
# We are only starting to number the config file with release 3. # We are only starting to number the config file with release 3.
@@ -362,7 +361,7 @@ class ModelManager(object):
if model_key.startswith("_"): if model_key.startswith("_"):
continue continue
model_name, base_model, model_type = self.parse_key(model_key) model_name, base_model, model_type = self.parse_key(model_key)
model_class = self._get_implementation(base_model, model_type) model_class = MODEL_CLASSES[base_model][model_type]
# alias for config file # alias for config file
model_config["model_format"] = model_config.pop("format") model_config["model_format"] = model_config.pop("format")
self.models[model_key] = model_class.create_config(**model_config) self.models[model_key] = model_class.create_config(**model_config)
@@ -382,24 +381,18 @@ class ModelManager(object):
# causing otherwise unreferenced models to be removed from memory # causing otherwise unreferenced models to be removed from memory
self._read_models() self._read_models()
def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool: def model_exists(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> bool:
""" """
Given a model name, returns True if it is a valid identifier. Given a model name, returns True if it is a valid
identifier.
:param model_name: symbolic name of the model in models.yaml
:param model_type: ModelType enum indicating the type of model to return
:param base_model: BaseModelType enum indicating the base model used by this model
:param rescan: if True, rescan the models directory for newly added model files before answering
""" """
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
exists = model_key in self.models return model_key in self.models
# if model not found try to find it (maybe file just pasted)
if rescan and not exists:
self.scan_models_directory(base_model=base_model, model_type=model_type)
exists = self.model_exists(model_name, base_model, model_type, rescan=False)
return exists
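# --- Editor's sketch (not part of this diff): the rescan-on-miss pattern.
# model_exists() above first checks the in-memory registry, then rescans
# the models directory once to pick up files added since startup. A
# generic form of that pattern, with hypothetical names:
def exists_with_rescan(registry: dict, key: str, rescan_fn) -> bool:
    if key in registry:
        return True
    rescan_fn()  # may register newly discovered entries
    return key in registry
# --- End editor's sketch -------------------------------------------------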
@classmethod @classmethod
def create_key( def create_key(
@@ -450,32 +443,39 @@ class ModelManager(object):
:param model_name: symbolic name of the model in models.yaml :param model_name: symbolic name of the model in models.yaml
:param model_type: ModelType enum indicating the type of model to return :param model_type: ModelType enum indicating the type of model to return
:param base_model: BaseModelType enum indicating the base model used by this model :param base_model: BaseModelType enum indicating the base model used by this model
:param submodel_type: a SubModelType enum indicating the portion of :param submode_typel: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae) the model to retrieve (e.g. ModelType.Vae)
""" """
model_class = MODEL_CLASSES[base_model][model_type]
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
if not self.model_exists(model_name, base_model, model_type, rescan=True): # if model not found try to find it (maybe file just pasted)
if model_key not in self.models:
self.scan_models_directory(base_model=base_model, model_type=model_type)
if model_key not in self.models:
raise ModelNotFoundException(f"Model not found - {model_key}") raise ModelNotFoundException(f"Model not found - {model_key}")
model_config = self._get_model_config(base_model, model_name, model_type) model_config = self.models[model_key]
model_path = self.resolve_model_path(model_config.path)
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
if is_submodel_override:
model_type = submodel_type
submodel_type = None
model_class = self._get_implementation(base_model, model_type)
if not model_path.exists(): if not model_path.exists():
if model_class.save_to_config: if model_class.save_to_config:
self.models[model_key].error = ModelError.NotFound self.models[model_key].error = ModelError.NotFound
raise Exception(f'Files for model "{model_key}" not found at {model_path}') raise Exception(f'Files for model "{model_key}" not found')
else: else:
self.models.pop(model_key, None) self.models.pop(model_key, None)
raise ModelNotFoundException(f'Files for model "{model_key}" not found at {model_path}') raise ModelNotFoundException(f"Model not found - {model_key}")
# vae/movq override
# TODO:
if submodel_type is not None and hasattr(model_config, submodel_type):
override_path = getattr(model_config, submodel_type)
if override_path:
model_path = self.resolve_path(override_path)
model_type = submodel_type
submodel_type = None
model_class = MODEL_CLASSES[base_model][model_type]
# TODO: path # TODO: path
# TODO: is it accurate to use path as id # TODO: is it accurate to use path as id
@@ -513,61 +513,12 @@ class ModelManager(object):
_cache=self.cache, _cache=self.cache,
) )
def _get_model_path(
self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
) -> (Path, bool):
"""Extract a model's filesystem path from its config.
:return: The fully qualified Path of the module (or submodule).
"""
model_path = model_config.path
is_submodel_override = False
# Does the config explicitly override the submodel?
if submodel_type is not None and hasattr(model_config, submodel_type):
submodel_path = getattr(model_config, submodel_type)
if submodel_path is not None and len(submodel_path) > 0:
model_path = getattr(model_config, submodel_type)
is_submodel_override = True
model_path = self.resolve_model_path(model_path)
return model_path, is_submodel_override
def _get_model_config(self, base_model: BaseModelType, model_name: str, model_type: ModelType) -> ModelConfigBase:
"""Get a model's config object."""
model_key = self.create_key(model_name, base_model, model_type)
try:
model_config = self.models[model_key]
except KeyError:
raise ModelNotFoundException(f"Model not found - {model_key}")
return model_config
def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
"""Get the concrete implementation class for a specific model type."""
model_class = MODEL_CLASSES[base_model][model_type]
return model_class
def _instantiate(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel_type: Optional[SubModelType] = None,
) -> ModelBase:
"""Make a new instance of this model, without loading it."""
model_config = self._get_model_config(base_model, model_name, model_type)
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
# FIXME: do non-overriden submodels get the right class?
constructor = self._get_implementation(base_model, model_type)
instance = constructor(model_path, base_model, model_type)
return instance
def model_info( def model_info(
self, self,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
) -> Union[dict, None]: ) -> dict:
""" """
Given a model name returns the OmegaConf (dict-like) object describing it. Given a model name returns the OmegaConf (dict-like) object describing it.
""" """
@@ -589,16 +540,13 @@ class ModelManager(object):
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
) -> Union[dict, None]: ) -> dict:
""" """
Returns a dict describing one installed model, using Returns a dict describing one installed model, using
the combined format of the list_models() method. the combined format of the list_models() method.
""" """
models = self.list_models(base_model, model_type, model_name) models = self.list_models(base_model, model_type, model_name)
if len(models) >= 1: return models[0] if models else None
return models[0]
else:
return None
def list_models( def list_models(
self, self,
@@ -612,7 +560,7 @@ class ModelManager(object):
model_keys = ( model_keys = (
[self.create_key(model_name, base_model, model_type)] [self.create_key(model_name, base_model, model_type)]
if model_name and base_model and model_type if model_name
else sorted(self.models, key=str.casefold) else sorted(self.models, key=str.casefold)
) )
models = [] models = []
@@ -648,7 +596,7 @@ class ModelManager(object):
Print a table of models and their descriptions. This needs to be redone Print a table of models and their descriptions. This needs to be redone
""" """
# TODO: redo # TODO: redo
for model_dict in self.list_models(): for model_type, model_dict in self.list_models().items():
for model_name, model_info in model_dict.items(): for model_name, model_info in model_dict.items():
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}' line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
print(line) print(line)
@@ -710,7 +658,7 @@ class ModelManager(object):
if path := model_attributes.get("path"): if path := model_attributes.get("path"):
model_attributes["path"] = str(self.relative_model_path(Path(path))) model_attributes["path"] = str(self.relative_model_path(Path(path)))
model_class = self._get_implementation(base_model, model_type) model_class = MODEL_CLASSES[base_model][model_type]
model_config = model_class.create_config(**model_attributes) model_config = model_class.create_config(**model_attributes)
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
@@ -751,8 +699,8 @@ class ModelManager(object):
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
new_name: Optional[str] = None, new_name: str = None,
new_base: Optional[BaseModelType] = None, new_base: BaseModelType = None,
): ):
""" """
Rename or rebase a model. Rename or rebase a model.
@@ -805,7 +753,7 @@ class ModelManager(object):
self, self,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: Literal[ModelType.Main, ModelType.Vae], model_type: Union[ModelType.Main, ModelType.Vae],
dest_directory: Optional[Path] = None, dest_directory: Optional[Path] = None,
) -> AddModelResult: ) -> AddModelResult:
""" """
@@ -819,10 +767,6 @@ class ModelManager(object):
This will raise a ValueError unless the model is a checkpoint. This will raise a ValueError unless the model is a checkpoint.
""" """
info = self.model_info(model_name, base_model, model_type) info = self.model_info(model_name, base_model, model_type)
if info is None:
raise FileNotFoundError(f"model not found: {model_name}")
if info["model_format"] != "checkpoint": if info["model_format"] != "checkpoint":
raise ValueError(f"not a checkpoint format model: {model_name}") raise ValueError(f"not a checkpoint format model: {model_name}")
@@ -892,7 +836,7 @@ class ModelManager(object):
return search_folder, found_models return search_folder, found_models
def commit(self, conf_file: Optional[Path] = None) -> None: def commit(self, conf_file: Path = None) -> None:
""" """
Write current configuration out to the indicated file. Write current configuration out to the indicated file.
""" """
@@ -901,7 +845,7 @@ class ModelManager(object):
for model_key, model_config in self.models.items(): for model_key, model_config in self.models.items():
model_name, base_model, model_type = self.parse_key(model_key) model_name, base_model, model_type = self.parse_key(model_key)
model_class = self._get_implementation(base_model, model_type) model_class = MODEL_CLASSES[base_model][model_type]
if model_class.save_to_config: if model_class.save_to_config:
# TODO: or exclude_unset better fits here? # TODO: or exclude_unset better fits here?
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"}) data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
@@ -959,7 +903,7 @@ class ModelManager(object):
model_path = self.resolve_model_path(model_config.path).absolute() model_path = self.resolve_model_path(model_config.path).absolute()
if not model_path.exists(): if not model_path.exists():
model_class = self._get_implementation(cur_base_model, cur_model_type) model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
if model_class.save_to_config: if model_class.save_to_config:
model_config.error = ModelError.NotFound model_config.error = ModelError.NotFound
self.models.pop(model_key, None) self.models.pop(model_key, None)
@@ -975,7 +919,7 @@ class ModelManager(object):
for cur_model_type in ModelType: for cur_model_type in ModelType:
if model_type is not None and cur_model_type != model_type: if model_type is not None and cur_model_type != model_type:
continue continue
model_class = self._get_implementation(cur_base_model, cur_model_type) model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value)) models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value))
if not models_dir.exists(): if not models_dir.exists():
@@ -991,9 +935,7 @@ class ModelManager(object):
raise DuplicateModelException(f"Model with key {model_key} added twice") raise DuplicateModelException(f"Model with key {model_key} added twice")
model_path = self.relative_model_path(model_path) model_path = self.relative_model_path(model_path)
model_config: ModelConfigBase = model_class.probe_config( model_config: ModelConfigBase = model_class.probe_config(str(model_path))
str(model_path), model_base=cur_base_model
)
self.models[model_key] = model_config self.models[model_key] = model_config
new_models_found = True new_models_found = True
except DuplicateModelException as e: except DuplicateModelException as e:
@@ -1041,7 +983,7 @@ class ModelManager(object):
# LS: hacky # LS: hacky
# Patch in the SD VAE from core so that it is available for use by the UI # Patch in the SD VAE from core so that it is available for use by the UI
try: try:
self.heuristic_import({str(self.resolve_model_path("core/convert/sd-vae-ft-mse"))}) self.heuristic_import({self.resolve_model_path("core/convert/sd-vae-ft-mse")})
except: except:
pass pass
@@ -1069,7 +1011,7 @@ class ModelManager(object):
def heuristic_import( def heuristic_import(
self, self,
items_to_import: Set[str], items_to_import: Set[str],
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
) -> Dict[str, AddModelResult]: ) -> Dict[str, AddModelResult]:
"""Import a list of paths, repo_ids or URLs. Returns the set of """Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items. successfully imported items.

View File

@@ -33,7 +33,7 @@ class ModelMerger(object):
self, self,
model_paths: List[Path], model_paths: List[Path],
alpha: float = 0.5, alpha: float = 0.5,
interp: Optional[MergeInterpolationMethod] = None, interp: MergeInterpolationMethod = None,
force: bool = False, force: bool = False,
**kwargs, **kwargs,
) -> DiffusionPipeline: ) -> DiffusionPipeline:
@@ -73,7 +73,7 @@ class ModelMerger(object):
base_model: Union[BaseModelType, str], base_model: Union[BaseModelType, str],
merged_model_name: str, merged_model_name: str,
alpha: float = 0.5, alpha: float = 0.5,
interp: Optional[MergeInterpolationMethod] = None, interp: MergeInterpolationMethod = None,
force: bool = False, force: bool = False,
merge_dest_directory: Optional[Path] = None, merge_dest_directory: Optional[Path] = None,
**kwargs, **kwargs,
@@ -122,7 +122,7 @@ class ModelMerger(object):
dump_path.mkdir(parents=True, exist_ok=True) dump_path.mkdir(parents=True, exist_ok=True)
dump_path = dump_path / merged_model_name dump_path = dump_path / merged_model_name
merged_pipe.save_pretrained(dump_path, safe_serialization=True) merged_pipe.save_pretrained(dump_path, safe_serialization=1)
attributes = dict( attributes = dict(
path=str(dump_path), path=str(dump_path),
description=f"Merge of models {', '.join(model_names)}", description=f"Merge of models {', '.join(model_names)}",

View File

@@ -17,7 +17,6 @@ from .models import (
SilenceWarnings, SilenceWarnings,
InvalidModelException, InvalidModelException,
) )
from .util import lora_token_vector_length
from .models.base import read_checkpoint_meta from .models.base import read_checkpoint_meta
@@ -316,16 +315,21 @@ class LoRACheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType: def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint checkpoint = self.checkpoint
token_vector_length = lora_token_vector_length(checkpoint) key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
if token_vector_length == 768: lora_token_vector_length = (
checkpoint[key1].shape[1]
if key1 in checkpoint
else checkpoint[key2].shape[0]
if key2 in checkpoint
else 768
)
if lora_token_vector_length == 768:
return BaseModelType.StableDiffusion1 return BaseModelType.StableDiffusion1
elif token_vector_length == 1024: elif lora_token_vector_length == 1024:
return BaseModelType.StableDiffusion2 return BaseModelType.StableDiffusion2
elif token_vector_length == 2048:
return BaseModelType.StableDiffusionXL
else: else:
raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}") return None
class TextualInversionCheckpointProbe(CheckpointProbeBase): class TextualInversionCheckpointProbe(CheckpointProbeBase):
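# --- Editor's sketch (not part of this diff): probing a LoRA's base model.
# The hunk above derives the base type from the text-encoder token vector
# length found in the checkpoint: 768 -> SD 1.x, 1024 -> SD 2.x, 2048 ->
# SDXL. A table-driven form, with informal string labels:
def base_from_token_vector_length(length: int) -> str:
    table = {768: "sd-1", 1024: "sd-2", 2048: "sdxl"}
    if length not in table:
        raise ValueError(f"Unknown LoRA token vector length: {length}")
    return table[length]

assert base_from_token_vector_length(768) == "sd-1"
# --- End editor's sketch -------------------------------------------------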

View File

@@ -292,9 +292,8 @@ class DiffusersModel(ModelBase):
) )
break break
except Exception as e: except Exception as e:
if not str(e).startswith("Error no file"): # print("====ERR LOAD====")
print("====ERR LOAD====") # print(f"{variant}: {e}")
print(f"{variant}: {e}")
pass pass
else: else:
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model") raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")

View File

@@ -1,9 +1,7 @@
import os import os
import torch import torch
from enum import Enum from enum import Enum
from typing import Optional, Dict, Union, Literal, Any from typing import Optional, Union, Literal
from pathlib import Path
from safetensors.torch import load_file
from .base import ( from .base import (
ModelBase, ModelBase,
ModelConfigBase, ModelConfigBase,
@@ -15,6 +13,9 @@ from .base import (
ModelNotFoundException, ModelNotFoundException,
) )
# TODO: naming
from ..lora import LoRAModel as LoRAModelRaw
class LoRAModelFormat(str, Enum): class LoRAModelFormat(str, Enum):
LyCORIS = "lycoris" LyCORIS = "lycoris"
@@ -49,7 +50,6 @@ class LoRAModel(ModelBase):
model = LoRAModelRaw.from_checkpoint( model = LoRAModelRaw.from_checkpoint(
file_path=self.model_path, file_path=self.model_path,
dtype=torch_dtype, dtype=torch_dtype,
base_model=self.base_model,
) )
self.model_size = model.calc_size() self.model_size = model.calc_size()
@@ -87,591 +87,3 @@ class LoRAModel(ModelBase):
raise NotImplementedError("Diffusers lora not supported") raise NotImplementedError("Diffusers lora not supported")
else: else:
return model_path return model_path
class LoRALayerBase:
# rank: Optional[int]
# alpha: Optional[float]
# bias: Optional[torch.Tensor]
# layer_key: str
# @property
# def scale(self):
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
def __init__(
self,
layer_key: str,
values: dict,
):
if "alpha" in values:
self.alpha = values["alpha"].item()
else:
self.alpha = None
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
self.bias = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
tuple(values["bias_size"]),
)
else:
self.bias = None
self.rank = None # set in layer implementation
self.layer_key = layer_key
def get_weight(self, orig_weight: torch.Tensor):
raise NotImplementedError()
def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
# up: torch.Tensor
# mid: Optional[torch.Tensor]
# down: torch.Tensor
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
if "lora_mid.weight" in values:
self.mid = values["lora_mid.weight"]
else:
self.mid = None
self.rank = self.down.shape[0]
def get_weight(self, orig_weight: torch.Tensor):
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.up, self.mid, self.down]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
class LoHALayer(LoRALayerBase):
# w1_a: torch.Tensor
# w1_b: torch.Tensor
# w2_a: torch.Tensor
# w2_b: torch.Tensor
# t1: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.w1_a = values["hada_w1_a"]
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]
if "hada_t1" in values:
self.t1 = values["hada_t1"]
else:
self.t1 = None
if "hada_t2" in values:
self.t2 = values["hada_t2"]
else:
self.t2 = None
self.rank = self.w1_b.shape[0]
def get_weight(self, orig_weight: torch.Tensor):
if self.t1 is None:
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else:
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
weight = rebuild1 * rebuild2
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class LoKRLayer(LoRALayerBase):
# w1: Optional[torch.Tensor] = None
# w1_a: Optional[torch.Tensor] = None
# w1_b: Optional[torch.Tensor] = None
# w2: Optional[torch.Tensor] = None
# w2_a: Optional[torch.Tensor] = None
# w2_b: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
if "lokr_w1" in values:
self.w1 = values["lokr_w1"]
self.w1_a = None
self.w1_b = None
else:
self.w1 = None
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]
if "lokr_w2" in values:
self.w2 = values["lokr_w2"]
self.w2_a = None
self.w2_b = None
else:
self.w2 = None
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]
if "lokr_t2" in values:
self.t2 = values["lokr_t2"]
else:
self.t2 = None
if "lokr_w1_b" in values:
self.rank = values["lokr_w1_b"].shape[0]
elif "lokr_w2_b" in values:
self.rank = values["lokr_w2_b"].shape[0]
else:
self.rank = None # unscaled
def get_weight(self, orig_weight: torch.Tensor):
w1 = self.w1
if w1 is None:
w1 = self.w1_a @ self.w1_b
w2 = self.w2
if w2 is None:
if self.t2 is None:
w2 = self.w2_a @ self.w2_b
else:
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
weight = torch.kron(w1, w2)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
if self.w1 is not None:
self.w1 = self.w1.to(device=device, dtype=dtype)
else:
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
else:
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class FullLayer(LoRALayerBase):
# weight: torch.Tensor
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.weight = values["diff"]
if len(values.keys()) > 1:
_keys = list(values.keys())
_keys.remove("diff")
raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
self.rank = None # unscaled
def get_weight(self, orig_weight: torch.Tensor):
return self.weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
class IA3Layer(LoRALayerBase):
# weight: torch.Tensor
# on_input: torch.Tensor
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.weight = values["weight"]
self.on_input = values["on_input"]
self.rank = None # unscaled
def get_weight(self, orig_weight: torch.Tensor):
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
return orig_weight * weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
model_size += self.on_input.nelement() * self.on_input.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
self.on_input = self.on_input.to(device=device, dtype=dtype)
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
class LoRAModelRaw: # (torch.nn.Module):
_name: str
layers: Dict[str, LoRALayer]
_device: torch.device
_dtype: torch.dtype
def __init__(
self,
name: str,
layers: Dict[str, LoRALayer],
device: torch.device,
dtype: torch.dtype,
):
self._name = name
self._device = device or torch.device("cpu")
self._dtype = dtype or torch.float32
self.layers = layers
@property
def name(self):
return self._name
@property
def device(self):
return self._device
@property
def dtype(self):
return self._dtype
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
# TODO: try revert if exception?
for key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
self._device = device
self._dtype = dtype
def calc_size(self) -> int:
model_size = 0
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size
@classmethod
def _convert_sdxl_compvis_keys(cls, state_dict):
new_state_dict = dict()
for full_key, value in state_dict.items():
if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
continue # clip same
if not full_key.startswith("lora_unet_"):
raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}")
src_key = full_key.replace("lora_unet_", "")
try:
dst_key = None
while "_" in src_key:
if src_key in SDXL_UNET_COMPVIS_MAP:
dst_key = SDXL_UNET_COMPVIS_MAP[src_key]
break
src_key = "_".join(src_key.split("_")[:-1])
if dst_key is None:
raise Exception(f"Unknown sdxl lora key - {full_key}")
new_key = full_key.replace(src_key, dst_key)
except:
print(SDXL_UNET_COMPVIS_MAP)
raise
new_state_dict[new_key] = value
return new_state_dict
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
base_model: Optional[BaseModelType] = None,
):
device = device or torch.device("cpu")
dtype = dtype or torch.float32
if isinstance(file_path, str):
file_path = Path(file_path)
model = cls(
device=device,
dtype=dtype,
name=file_path.stem, # TODO:
layers=dict(),
)
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(state_dict)
if base_model == BaseModelType.StableDiffusionXL:
state_dict = cls._convert_sdxl_compvis_keys(state_dict)
for layer_key, values in state_dict.items():
# lora and locon
if "lora_down.weight" in values:
layer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_b" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1_b" in values or "lokr_w1" in values:
layer = LoKRLayer(layer_key, values)
# diff
elif "diff" in values:
layer = FullLayer(layer_key, values)
# ia3
elif "weight" in values and "on_input" in values:
layer = IA3Layer(layer_key, values)
else:
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
raise Exception("Unknown lora format!")
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype)
model.layers[layer_key] = layer
return model
@staticmethod
def _group_state(state_dict: dict):
state_dict_groupped = dict()
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = dict()
state_dict_groupped[stem][leaf] = value
return state_dict_groupped
# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def make_sdxl_unet_conversion_map():
unet_conversion_map_layer = []
for i in range(3): # num_blocks is 3 in sdxl
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
# if i > 0: commentout for sdxl
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0.", "norm1."),
("in_layers.2.", "conv1."),
("out_layers.0.", "norm2."),
("out_layers.3.", "conv2."),
("emb_layers.1.", "time_emb_proj."),
("skip_connection.", "conv_shortcut."),
]
unet_conversion_map = []
for sd, hf in unet_conversion_map_layer:
if "resnets" in hf:
for sd_res, hf_res in unet_conversion_map_resnet:
unet_conversion_map.append((sd + sd_res, hf + hf_res))
else:
unet_conversion_map.append((sd, hf))
for j in range(2):
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
sd_time_embed_prefix = f"time_embed.{j*2}."
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
for j in range(2):
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
sd_label_embed_prefix = f"label_emb.0.{j*2}."
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
unet_conversion_map.append(("out.0.", "conv_norm_out."))
unet_conversion_map.append(("out.2.", "conv_out."))
return unet_conversion_map
SDXL_UNET_COMPVIS_MAP = {
f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_")
for sd, hf in make_sdxl_unet_conversion_map()
}
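# --- Editor's sketch (not part of this diff): longest-prefix key lookup.
# _convert_sdxl_compvis_keys() above trims "_"-separated segments off the
# end of a key until it hits an entry in SDXL_UNET_COMPVIS_MAP. A generic
# form of that loop, exercised against a toy one-entry mapping:
def longest_prefix_match(key: str, mapping: dict) -> str | None:
    while "_" in key:
        if key in mapping:
            return mapping[key]
        key = "_".join(key.split("_")[:-1])
    return mapping.get(key)

toy_map = {"input_blocks_1_0": "down_blocks_0_resnets_0"}
assert longest_prefix_match("input_blocks_1_0_in_layers_0", toy_map) == "down_blocks_0_resnets_0"
# --- End editor's sketch -------------------------------------------------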

View File

@@ -80,10 +80,8 @@ class StableDiffusionXLModel(DiffusersModel):
raise Exception("Unkown stable diffusion 2.* model format") raise Exception("Unkown stable diffusion 2.* model format")
if ckpt_config_path is None: if ckpt_config_path is None:
# avoid circular import # TO DO: implement picking
from .stable_diffusion import _select_ckpt_config pass
ckpt_config_path = _select_ckpt_config(kwargs.get("model_base", BaseModelType.StableDiffusionXL), variant)
return cls.create_config( return cls.create_config(
path=path, path=path,

View File

@@ -4,7 +4,6 @@ from enum import Enum
 from pydantic import Field
 from pathlib import Path
 from typing import Literal, Optional, Union
-from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
 from .base import (
     ModelConfigBase,
     BaseModelType,
@@ -264,8 +263,6 @@ def _convert_ckpt_and_cache(
     weights = app_config.models_path / model_config.path
     config_file = app_config.root_path / model_config.config
     output_path = Path(output_path)
-    variant = model_config.variant
-    pipeline_class = StableDiffusionInpaintPipeline if variant == "inpaint" else StableDiffusionPipeline

     # return cached version if it exists
     if output_path.exists():
@@ -292,7 +289,6 @@ def _convert_ckpt_and_cache(
         original_config_file=config_file,
         extract_ema=True,
         scan_needed=True,
-        pipeline_class=pipeline_class,
         from_safetensors=weights.suffix == ".safetensors",
         precision=torch_dtype(choose_torch_device()),
         **kwargs,

View File

@@ -1,14 +1,9 @@
 import os
-import torch
-import safetensors
 from enum import Enum
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Union, Literal
+
+import safetensors
+import torch
-from diffusers.utils import is_safetensors_available
-from omegaconf import OmegaConf
-from invokeai.app.services.config import InvokeAIAppConfig
 from .base import (
     ModelBase,
     ModelConfigBase,
@@ -23,6 +18,9 @@ from .base import (
     InvalidModelException,
     ModelNotFoundException,
 )
+from invokeai.app.services.config import InvokeAIAppConfig
+from diffusers.utils import is_safetensors_available
+from omegaconf import OmegaConf


 class VaeModelFormat(str, Enum):
@@ -82,7 +80,7 @@ class VaeModel(ModelBase):
     @classmethod
     def detect_format(cls, path: str):
         if not os.path.exists(path):
-            raise ModelNotFoundException(f"Does not exist as local file: {path}")
+            raise ModelNotFoundException()
         if os.path.isdir(path):
             if os.path.exists(os.path.join(path, "config.json")):

View File

@@ -1,75 +0,0 @@
# Copyright (c) 2023 The InvokeAI Development Team
"""Utilities used by the Model Manager"""
def lora_token_vector_length(checkpoint: dict) -> int:
"""
Given a checkpoint in memory, return the lora token vector length
:param checkpoint: The checkpoint
"""
def _get_shape_1(key, tensor, checkpoint):
lora_token_vector_length = None
if "." not in key:
return lora_token_vector_length # wrong key format
model_key, lora_key = key.split(".", 1)
# check lora/locon
if lora_key == "lora_down.weight":
lora_token_vector_length = tensor.shape[1]
        # check loha (don't worry about hada_t1/hada_t2 as they are used only in 4d shapes)
elif lora_key in ["hada_w1_b", "hada_w2_b"]:
lora_token_vector_length = tensor.shape[1]
        # check lokr (don't worry about lokr_t2 as it is used only in 4d shapes)
elif "lokr_" in lora_key:
if model_key + ".lokr_w1" in checkpoint:
_lokr_w1 = checkpoint[model_key + ".lokr_w1"]
elif model_key + "lokr_w1_b" in checkpoint:
_lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
else:
return lora_token_vector_length # unknown format
if model_key + ".lokr_w2" in checkpoint:
_lokr_w2 = checkpoint[model_key + ".lokr_w2"]
elif model_key + "lokr_w2_b" in checkpoint:
_lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
else:
return lora_token_vector_length # unknown format
lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]
elif lora_key == "diff":
lora_token_vector_length = tensor.shape[1]
# ia3 can be detected only by shape[0] in text encoder
elif lora_key == "weight" and "lora_unet_" not in model_key:
lora_token_vector_length = tensor.shape[0]
return lora_token_vector_length
lora_token_vector_length = None
lora_te1_length = None
lora_te2_length = None
for key, tensor in checkpoint.items():
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
elif key.startswith("lora_te") and "_self_attn_" in key:
tmp_length = _get_shape_1(key, tensor, checkpoint)
if key.startswith("lora_te_"):
lora_token_vector_length = tmp_length
elif key.startswith("lora_te1_"):
lora_te1_length = tmp_length
elif key.startswith("lora_te2_"):
lora_te2_length = tmp_length
if lora_te1_length is not None and lora_te2_length is not None:
lora_token_vector_length = lora_te1_length + lora_te2_length
if lora_token_vector_length is not None:
break
return lora_token_vector_length
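A minimal usage sketch (not part of the diff; the checkpoint path is hypothetical):

# Load a LoRA state dict and probe which base model it was trained against.
from safetensors.torch import load_file

checkpoint = load_file("/path/to/some_lora.safetensors")  # hypothetical path
length = lora_token_vector_length(checkpoint)
# Typically 768 for SD 1.x, 1024 for SD 2.x, and 2048 (768 + 1280) for SDXL's
# two text encoders.
print(length)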

View File

@@ -4,21 +4,25 @@ import dataclasses
 import inspect
 import math
 import secrets
+from collections.abc import Sequence
 from dataclasses import dataclass, field
 from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union

-from pydantic import Field
-import PIL.Image
 import einops
+import PIL.Image
+import numpy as np
+from accelerate.utils import set_seed
 import psutil
 import torch
 import torchvision.transforms as T
-from accelerate.utils import set_seed
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.models.controlnet import ControlNetModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
     StableDiffusionPipeline,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
     StableDiffusionImg2ImgPipeline,
 )
@@ -27,20 +31,21 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
 )
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
+from diffusers.utils import PIL_INTERPOLATION
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.outputs import BaseOutput
+from pydantic import Field
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from typing_extensions import ParamSpec

 from invokeai.app.services.config import InvokeAIAppConfig
+from ..util import CPU_DEVICE, normalize_device
 from .diffusion import (
     AttentionMapSaver,
     InvokeAIDiffuserComponent,
     PostprocessingSettings,
 )
-from ..util import normalize_device
+from .offloading import FullyLoadedModelGroup, ModelGroup


 @dataclass
@@ -284,6 +289,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
+
+    _model_group: ModelGroup

     ID_LENGTH = 8

     def __init__(
@@ -296,7 +303,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         safety_checker: Optional[StableDiffusionSafetyChecker],
         feature_extractor: Optional[CLIPFeatureExtractor],
         requires_safety_checker: bool = False,
+        precision: str = "float32",
         control_model: ControlNetModel = None,
+        execution_device: Optional[torch.device] = None,
     ):
         super().__init__(
             vae,
@@ -321,6 +330,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             # control_model=control_model,
         )
         self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
+        self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device)
+        self._model_group.install(*self._submodels)

         self.control_model = control_model

     def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
@@ -356,6 +368,72 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         else:
             self.disable_attention_slicing()
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
# overridden method; types match the superclass.
if torch_device is None:
return self
self._model_group.set_device(torch.device(torch_device))
self._model_group.ready()
@property
def device(self) -> torch.device:
return self._model_group.execution_device
@property
def _submodels(self) -> Sequence[torch.nn.Module]:
module_names, _, _ = self.extract_init_dict(dict(self.config))
submodels = []
for name in module_names.keys():
if hasattr(self, name):
value = getattr(self, name)
else:
value = getattr(self.config, name)
if isinstance(value, torch.nn.Module):
submodels.append(value)
return submodels
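        # (For a standard SD pipeline this collects every torch.nn.Module named in
        # the config, e.g. vae, text_encoder, unet, and the optional safety_checker.)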
def image_from_embeddings(
self,
latents: torch.Tensor,
num_inference_steps: int,
conditioning_data: ConditioningData,
*,
noise: torch.Tensor,
callback: Callable[[PipelineIntermediateState], None] = None,
run_id=None,
) -> InvokeAIStableDiffusionPipelineOutput:
r"""
Function invoked when calling the pipeline for generation.
:param conditioning_data:
:param latents: Pre-generated un-noised latents, to be used as inputs for
image generation. Can be used to tweak the same generation with different prompts.
:param num_inference_steps: The number of denoising steps. More denoising steps usually lead to a higher quality
image at the expense of slower inference.
:param noise: Noise to add to the latents, sampled from a Gaussian distribution.
:param callback:
:param run_id:
"""
result_latents, result_attention_map_saver = self.latents_from_embeddings(
latents,
num_inference_steps,
conditioning_data,
noise=noise,
run_id=run_id,
callback=callback,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
with torch.inference_mode():
image = self.decode_latents(result_latents)
output = InvokeAIStableDiffusionPipelineOutput(
images=image,
nsfw_content_detected=[],
attention_map_saver=result_attention_map_saver,
)
return self.check_for_safety(output, dtype=conditioning_data.dtype)
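        # (Note: the result is routed through check_for_safety(), defined further
        # down in this diff, so NSFW screening runs on every txt2img decode.)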
     def latents_from_embeddings(
         self,
         latents: torch.Tensor,
@@ -372,7 +450,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if self.scheduler.config.get("cpu_only", False):
             scheduler_device = torch.device("cpu")
         else:
-            scheduler_device = self.unet.device
+            scheduler_device = self._model_group.device_for(self.unet)

         if timesteps is None:
             self.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
@@ -426,7 +504,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             (batch_size,),
             timesteps[0],
             dtype=timesteps.dtype,
-            device=self.unet.device,
+            device=self._model_group.device_for(self.unet),
         )
         latents = self.scheduler.add_noise(latents, noise, batched_t)
@@ -622,6 +700,79 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             **kwargs,
         ).sample
def img2img_from_embeddings(
self,
init_image: Union[torch.FloatTensor, PIL.Image.Image],
strength: float,
num_inference_steps: int,
conditioning_data: ConditioningData,
*,
callback: Callable[[PipelineIntermediateState], None] = None,
run_id=None,
noise_func=None,
seed=None,
) -> InvokeAIStableDiffusionPipelineOutput:
if isinstance(init_image, PIL.Image.Image):
init_image = image_resized_to_grid_as_tensor(init_image.convert("RGB"))
if init_image.dim() == 3:
init_image = einops.rearrange(init_image, "c h w -> 1 c h w")
# 6. Prepare latent variables
initial_latents = self.non_noised_latents_from_image(
init_image,
device=self._model_group.device_for(self.unet),
dtype=self.unet.dtype,
)
if seed is not None:
set_seed(seed)
noise = noise_func(initial_latents)
return self.img2img_from_latents_and_embeddings(
initial_latents,
num_inference_steps,
conditioning_data,
strength,
noise,
run_id,
callback,
)
def img2img_from_latents_and_embeddings(
self,
initial_latents,
num_inference_steps,
conditioning_data: ConditioningData,
strength,
noise: torch.Tensor,
run_id=None,
callback=None,
) -> InvokeAIStableDiffusionPipelineOutput:
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
result_latents, result_attention_maps = self.latents_from_embeddings(
latents=initial_latents
if strength < 1.0
else torch.zeros_like(initial_latents, device=initial_latents.device, dtype=initial_latents.dtype),
num_inference_steps=num_inference_steps,
conditioning_data=conditioning_data,
timesteps=timesteps,
noise=noise,
run_id=run_id,
callback=callback,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
with torch.inference_mode():
image = self.decode_latents(result_latents)
output = InvokeAIStableDiffusionPipelineOutput(
images=image,
nsfw_content_detected=[],
attention_map_saver=result_attention_maps,
)
return self.check_for_safety(output, dtype=conditioning_data.dtype)
     def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device=None) -> (torch.Tensor, int):
         img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
         assert img2img_pipeline.scheduler is self.scheduler
@@ -629,7 +780,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if self.scheduler.config.get("cpu_only", False):
             scheduler_device = torch.device("cpu")
         else:
-            scheduler_device = self.unet.device
+            scheduler_device = self._model_group.device_for(self.unet)

         img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
         timesteps, adjusted_steps = img2img_pipeline.get_timesteps(
@@ -655,7 +806,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         noise_func=None,
         seed=None,
     ) -> InvokeAIStableDiffusionPipelineOutput:
-        device = self.unet.device
+        device = self._model_group.device_for(self.unet)
         latents_dtype = self.unet.dtype

         if isinstance(init_image, PIL.Image.Image):
@@ -726,17 +877,42 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             nsfw_content_detected=[],
             attention_map_saver=result_attention_maps,
         )
-        return output
+        return self.check_for_safety(output, dtype=conditioning_data.dtype)

     def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
         init_image = init_image.to(device=device, dtype=dtype)
         with torch.inference_mode():
+            self._model_group.load(self.vae)
             init_latent_dist = self.vae.encode(init_image).latent_dist
             init_latents = init_latent_dist.sample().to(dtype=dtype)  # FIXME: uses torch.randn. make reproducible!

         init_latents = 0.18215 * init_latents
         return init_latents
def check_for_safety(self, output, dtype):
with torch.inference_mode():
screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype)
screened_attention_map_saver = None
if has_nsfw_concept is None or not has_nsfw_concept:
screened_attention_map_saver = output.attention_map_saver
return InvokeAIStableDiffusionPipelineOutput(
screened_images,
has_nsfw_concept,
# block the attention maps if NSFW content is detected
attention_map_saver=screened_attention_map_saver,
)
def run_safety_checker(self, image, device=None, dtype=None):
# overriding to use the model group for device info instead of requiring the caller to know.
if self.safety_checker is not None:
device = self._model_group.device_for(self.safety_checker)
return super().run_safety_checker(image, device, dtype)
def decode_latents(self, latents):
# Explicit call to get the vae loaded, since `decode` isn't the forward method.
self._model_group.load(self.vae)
return super().decode_latents(latents)
     def debug_latents(self, latents, msg):
         from invokeai.backend.image_util import debug_image

View File

@@ -78,9 +78,10 @@ class InvokeAIDiffuserComponent:
         self.cross_attention_control_context = None
         self.sequential_guidance = config.sequential_guidance

+    @classmethod
     @contextmanager
     def custom_attention_context(
-        self,
+        cls,
         unet: UNet2DConditionModel,  # note: also may futz with the text encoder depending on requested LoRAs
         extra_conditioning_info: Optional[ExtraConditioningInfo],
         step_count: int,
@@ -90,19 +91,18 @@ class InvokeAIDiffuserComponent:
         old_attn_processors = unet.attn_processors

         # Load lora conditions into the model
         if extra_conditioning_info.wants_cross_attention_control:
-            self.cross_attention_control_context = Context(
+            cross_attention_control_context = Context(
                 arguments=extra_conditioning_info.cross_attention_control_args,
                 step_count=step_count,
             )
             setup_cross_attention_control_attention_processors(
                 unet,
-                self.cross_attention_control_context,
+                cross_attention_control_context,
             )

         try:
             yield None
         finally:
-            self.cross_attention_control_context = None
             if old_attn_processors is not None:
                 unet.set_attn_processor(old_attn_processors)
         # TODO resuscitate attention map saving
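Usage-wise, the context manager is now invoked on the class rather than on an instance. A hedged sketch, not code from the diff; the unet and the extra conditioning info are assumed to be in scope:

# Install cross-attention-control processors for the duration of denoising.
with InvokeAIDiffuserComponent.custom_attention_context(
    unet,                     # a loaded UNet2DConditionModel (assumed)
    extra_conditioning_info,  # ExtraConditioningInfo for the current prompts (assumed)
    step_count=20,
):
    ...  # run the denoising loop; the original attn processors are restored on exit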

View File

@@ -0,0 +1,253 @@
from __future__ import annotations
import warnings
import weakref
from abc import ABCMeta, abstractmethod
from collections.abc import MutableMapping
from typing import Callable, Union
import torch
from accelerate.utils import send_to_device
from torch.utils.hooks import RemovableHandle
OFFLOAD_DEVICE = torch.device("cpu")
class _NoModel:
"""Symbol that indicates no model is loaded.
(We can't weakref.ref(None), so this was my best idea at the time to come up with something
type-checkable.)
"""
def __bool__(self):
return False
def to(self, device: torch.device):
pass
def __repr__(self):
return "<NO MODEL>"
NO_MODEL = _NoModel()
class ModelGroup(metaclass=ABCMeta):
"""
A group of models.
The use case I had in mind when writing this is the sub-models used by a DiffusionPipeline,
e.g. its text encoder, U-net, VAE, etc.
Those models are :py:class:`diffusers.ModelMixin`, but "model" is interchangeable with
:py:class:`torch.nn.Module` here.
"""
def __init__(self, execution_device: torch.device):
self.execution_device = execution_device
@abstractmethod
def install(self, *models: torch.nn.Module):
"""Add models to this group."""
pass
@abstractmethod
def uninstall(self, models: torch.nn.Module):
"""Remove models from this group."""
pass
@abstractmethod
def uninstall_all(self):
"""Remove all models from this group."""
@abstractmethod
def load(self, model: torch.nn.Module):
"""Load this model to the execution device."""
pass
@abstractmethod
def offload_current(self):
"""Offload the current model(s) from the execution device."""
pass
@abstractmethod
def ready(self):
"""Ready this group for use."""
pass
@abstractmethod
def set_device(self, device: torch.device):
"""Change which device models from this group will execute on."""
pass
@abstractmethod
def device_for(self, model) -> torch.device:
"""Get the device the given model will execute on.
The model should already be a member of this group.
"""
pass
@abstractmethod
def __contains__(self, model):
"""Check if the model is a member of this group."""
pass
def __repr__(self) -> str:
return f"<{self.__class__.__name__} object at {id(self):x}: " f"device={self.execution_device} >"
class LazilyLoadedModelGroup(ModelGroup):
"""
Only one model from this group is loaded on the GPU at a time.
Running the forward method of a model will displace the previously-loaded model,
offloading it to CPU.
If you call other methods on the model, e.g. ``model.encode(x)`` instead of ``model(x)``,
    you will need to explicitly load it with :py:meth:`.load(model)`.
This implementation relies on pytorch forward-pre-hooks, and it will copy forward arguments
to the appropriate execution device, as long as they are positional arguments and not keyword
arguments. (I didn't make the rules; that's the way the pytorch 1.13 API works for hooks.)
"""
_hooks: MutableMapping[torch.nn.Module, RemovableHandle]
_current_model_ref: Callable[[], Union[torch.nn.Module, _NoModel]]
def __init__(self, execution_device: torch.device):
super().__init__(execution_device)
self._hooks = weakref.WeakKeyDictionary()
self._current_model_ref = weakref.ref(NO_MODEL)
def install(self, *models: torch.nn.Module):
for model in models:
self._hooks[model] = model.register_forward_pre_hook(self._pre_hook)
def uninstall(self, *models: torch.nn.Module):
for model in models:
hook = self._hooks.pop(model)
hook.remove()
if self.is_current_model(model):
# no longer hooked by this object, so don't claim to manage it
self.clear_current_model()
def uninstall_all(self):
self.uninstall(*self._hooks.keys())
def _pre_hook(self, module: torch.nn.Module, forward_input):
self.load(module)
if len(forward_input) == 0:
warnings.warn(
f"Hook for {module.__class__.__name__} got no input. " f"Inputs must be positional, not keywords.",
stacklevel=3,
)
return send_to_device(forward_input, self.execution_device)
def load(self, module):
if not self.is_current_model(module):
self.offload_current()
self._load(module)
def offload_current(self):
module = self._current_model_ref()
if module is not NO_MODEL:
module.to(OFFLOAD_DEVICE)
self.clear_current_model()
def _load(self, module: torch.nn.Module) -> torch.nn.Module:
assert self.is_empty(), f"A model is already loaded: {self._current_model_ref()}"
module = module.to(self.execution_device)
self.set_current_model(module)
return module
def is_current_model(self, model: torch.nn.Module) -> bool:
"""Is the given model the one currently loaded on the execution device?"""
return self._current_model_ref() is model
def is_empty(self):
"""Are none of this group's models loaded on the execution device?"""
return self._current_model_ref() is NO_MODEL
def set_current_model(self, value):
self._current_model_ref = weakref.ref(value)
def clear_current_model(self):
self._current_model_ref = weakref.ref(NO_MODEL)
def set_device(self, device: torch.device):
if device == self.execution_device:
return
self.execution_device = device
current = self._current_model_ref()
if current is not NO_MODEL:
current.to(device)
def device_for(self, model):
if model not in self:
raise KeyError(f"This does not manage this model {type(model).__name__}", model)
return self.execution_device # this implementation only dispatches to one device
def ready(self):
pass # always ready to load on-demand
def __contains__(self, model):
return model in self._hooks
def __repr__(self) -> str:
return (
f"<{self.__class__.__name__} object at {id(self):x}: "
f"current_model={type(self._current_model_ref()).__name__} >"
)
class FullyLoadedModelGroup(ModelGroup):
"""
A group of models without any implicit loading or unloading.
:py:meth:`.ready` loads _all_ the models to the execution device at once.
"""
_models: weakref.WeakSet
def __init__(self, execution_device: torch.device):
super().__init__(execution_device)
self._models = weakref.WeakSet()
def install(self, *models: torch.nn.Module):
for model in models:
self._models.add(model)
model.to(self.execution_device)
def uninstall(self, *models: torch.nn.Module):
for model in models:
self._models.remove(model)
def uninstall_all(self):
self.uninstall(*self._models)
def load(self, model):
model.to(self.execution_device)
def offload_current(self):
for model in self._models:
model.to(OFFLOAD_DEVICE)
def ready(self):
for model in self._models:
self.load(model)
def set_device(self, device: torch.device):
self.execution_device = device
for model in self._models:
if model.device != OFFLOAD_DEVICE:
model.to(device)
def device_for(self, model):
if model not in self:
raise KeyError("This does not manage this model f{type(model).__name__}", model)
return self.execution_device # this implementation only dispatches to one device
def __contains__(self, model):
return model in self._models
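A small usage sketch of the two group types (illustrative only, with toy modules standing in for pipeline submodels):

import torch

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
group = LazilyLoadedModelGroup(device)
encoder = torch.nn.Linear(4, 4)
decoder = torch.nn.Linear(4, 4)
group.install(encoder, decoder)

x = torch.ones(1, 4)
y = encoder(x)  # the forward pre-hook loads `encoder` and copies `x` to the device
z = decoder(y)  # loading `decoder` displaces `encoder` back to the offload device

group.uninstall_all()

FullyLoadedModelGroup looks the same from the caller's side but keeps every installed model resident at once, which matches how StableDiffusionGeneratorPipeline uses it earlier in this diff.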

View File

@@ -1,8 +1,6 @@
 from __future__ import annotations

 from contextlib import nullcontext
-from packaging import version
-import platform

 import torch
 from torch import autocast
@@ -32,7 +30,7 @@ def choose_precision(device: torch.device) -> str:
         device_name = torch.cuda.get_device_name(device)
         if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
             return "float16"
-    elif device.type == "mps" and version.parse(platform.mac_ver()[0]) < version.parse("14.0.0"):
+    elif device.type == "mps":
         return "float16"
     return "float32"

View File

@@ -1,3 +1,6 @@
 """
 Initialization file for invokeai.frontend.config
 """
+from .invokeai_configure import main as invokeai_configure
+from .invokeai_update import main as invokeai_update
+from .model_install import main as invokeai_model_install
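With those re-exports in place, the entry points become importable straight from the package (sketch, assuming this file is invokeai/frontend/config/__init__.py as its docstring suggests):

from invokeai.frontend.config import (
    invokeai_configure,
    invokeai_model_install,
    invokeai_update,
)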

View File

@@ -1,795 +0,0 @@
# Copyright (c) 2023 - The InvokeAI Team
# Primary Author: David Lovell (github @f412design, discord @techjedi)
# co-author, minor tweaks - Lincoln Stein
# pylint: disable=line-too-long
# pylint: disable=broad-exception-caught
"""Script to import images into the new database system for 3.0.0"""
import os
import datetime
import shutil
import locale
import sqlite3
import json
import glob
import re
import uuid
import yaml
import PIL
import PIL.ImageOps
import PIL.PngImagePlugin
from pathlib import Path
from prompt_toolkit import prompt
from prompt_toolkit.shortcuts import message_dialog
from prompt_toolkit.completion import PathCompleter
from prompt_toolkit.key_binding import KeyBindings
from invokeai.app.services.config import InvokeAIAppConfig
app_config = InvokeAIAppConfig.get_config()
bindings = KeyBindings()
@bindings.add("c-c")
def _(event):
raise KeyboardInterrupt
# release notes
# "Use All" with size dimensions not selectable in the UI will not load dimensions
class Config:
"""Configuration loader."""
def __init__(self):
pass
TIMESTAMP_STRING = datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
INVOKE_DIRNAME = "invokeai"
YAML_FILENAME = "invokeai.yaml"
DATABASE_FILENAME = "invokeai.db"
database_path = None
database_backup_dir = None
outputs_path = None
thumbnail_path = None
def find_and_load(self):
"""find the yaml config file and load"""
root = app_config.root_path
if not self.confirm_and_load(os.path.abspath(root)):
print("\r\nSpecify custom database and outputs paths:")
self.confirm_and_load_from_user()
self.database_backup_dir = os.path.join(os.path.dirname(self.database_path), "backup")
self.thumbnail_path = os.path.join(self.outputs_path, "thumbnails")
def confirm_and_load(self, invoke_root):
"""Validates a yaml path exists, confirms the user wants to use it and loads config."""
yaml_path = os.path.join(invoke_root, self.YAML_FILENAME)
if os.path.exists(yaml_path):
db_dir, outdir = self.load_paths_from_yaml(yaml_path)
if os.path.isabs(db_dir):
database_path = os.path.join(db_dir, self.DATABASE_FILENAME)
else:
database_path = os.path.join(invoke_root, db_dir, self.DATABASE_FILENAME)
if os.path.isabs(outdir):
outputs_path = os.path.join(outdir, "images")
else:
outputs_path = os.path.join(invoke_root, outdir, "images")
db_exists = os.path.exists(database_path)
outdir_exists = os.path.exists(outputs_path)
text = f"Found {self.YAML_FILENAME} file at {yaml_path}:"
text += f"\n Database : {database_path}"
text += f"\n Outputs : {outputs_path}"
text += "\n\nUse these paths for import (yes) or choose different ones (no) [Yn]: "
if db_exists and outdir_exists:
if (prompt(text).strip() or "Y").upper().startswith("Y"):
self.database_path = database_path
self.outputs_path = outputs_path
return True
else:
return False
else:
print(" Invalid: One or more paths in this config did not exist and cannot be used.")
else:
message_dialog(
title="Path not found",
text=f"Auto-discovery of configuration failed! Could not find ({yaml_path}), Custom paths can be specified.",
).run()
return False
def confirm_and_load_from_user(self):
default = ""
while True:
database_path = os.path.expanduser(
prompt(
"Database: Specify absolute path to the database to import into: ",
completer=PathCompleter(
expanduser=True, file_filter=lambda x: Path(x).is_dir() or x.endswith((".db"))
),
default=default,
)
)
if database_path.endswith(".db") and os.path.isabs(database_path) and os.path.exists(database_path):
break
default = database_path + "/" if Path(database_path).is_dir() else database_path
default = ""
while True:
outputs_path = os.path.expanduser(
prompt(
"Outputs: Specify absolute path to outputs/images directory to import into: ",
completer=PathCompleter(expanduser=True, only_directories=True),
default=default,
)
)
if outputs_path.endswith("images") and os.path.isabs(outputs_path) and os.path.exists(outputs_path):
break
default = outputs_path + "/" if Path(outputs_path).is_dir() else outputs_path
self.database_path = database_path
self.outputs_path = outputs_path
return
def load_paths_from_yaml(self, yaml_path):
"""Load an Invoke AI yaml file and get the database and outputs paths."""
try:
with open(yaml_path, "rt", encoding=locale.getpreferredencoding()) as file:
yamlinfo = yaml.safe_load(file)
db_dir = yamlinfo.get("InvokeAI", {}).get("Paths", {}).get("db_dir", None)
outdir = yamlinfo.get("InvokeAI", {}).get("Paths", {}).get("outdir", None)
return db_dir, outdir
except Exception:
print(f"Failed to load paths from yaml file! {yaml_path}!")
return None, None
class ImportStats:
"""DTO for tracking work progress."""
def __init__(self):
pass
time_start = datetime.datetime.utcnow()
count_source_files = 0
count_skipped_file_exists = 0
count_skipped_db_exists = 0
count_imported = 0
count_imported_by_version = {}
count_file_errors = 0
@staticmethod
def get_elapsed_time_string():
"""Get a friendly time string for the time elapsed since processing start."""
time_now = datetime.datetime.utcnow()
total_seconds = (time_now - ImportStats.time_start).total_seconds()
hours = int((total_seconds) / 3600)
minutes = int(((total_seconds) % 3600) / 60)
seconds = total_seconds % 60
out_str = f"{hours} hour(s) -" if hours > 0 else ""
out_str += f"{minutes} minute(s) -" if minutes > 0 else ""
out_str += f"{seconds:.2f} second(s)"
return out_str
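        # Worked example (illustrative): 3723.5 elapsed seconds gives hours=1,
        # minutes=2, seconds=3.5, so the result is "1 hour(s) -2 minute(s) -3.50 second(s)".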
class InvokeAIMetadata:
"""DTO for core Invoke AI generation properties parsed from metadata."""
def __init__(self):
pass
def __str__(self):
formatted_str = f"{self.generation_mode}~{self.steps}~{self.cfg_scale}~{self.model_name}~{self.scheduler}~{self.seed}~{self.width}~{self.height}~{self.rand_device}~{self.strength}~{self.init_image}"
formatted_str += f"\r\npositive_prompt: {self.positive_prompt}"
formatted_str += f"\r\nnegative_prompt: {self.negative_prompt}"
return formatted_str
generation_mode = None
steps = None
cfg_scale = None
model_name = None
scheduler = None
seed = None
width = None
height = None
rand_device = None
strength = None
init_image = None
positive_prompt = None
negative_prompt = None
imported_app_version = None
def to_json(self):
"""Convert the active instance to json format."""
prop_dict = {}
prop_dict["generation_mode"] = self.generation_mode
        # don't render prompt nodes if neither is set, to avoid the UI thinking it can set them;
        # if at least one exists, render them both, but use an empty string instead of None if one of them is empty.
        # This allows the empty field to actually be cleared by the UI instead of keeping the previous value.
if self.positive_prompt or self.negative_prompt:
prop_dict["positive_prompt"] = "" if self.positive_prompt is None else self.positive_prompt
prop_dict["negative_prompt"] = "" if self.negative_prompt is None else self.negative_prompt
prop_dict["width"] = self.width
prop_dict["height"] = self.height
        # only render seed if it has a value, to avoid the UI thinking it can set this and then erroring
if self.seed:
prop_dict["seed"] = self.seed
prop_dict["rand_device"] = self.rand_device
prop_dict["cfg_scale"] = self.cfg_scale
prop_dict["steps"] = self.steps
prop_dict["scheduler"] = self.scheduler
prop_dict["clip_skip"] = 0
prop_dict["model"] = {}
prop_dict["model"]["model_name"] = self.model_name
prop_dict["model"]["base_model"] = None
prop_dict["controlnets"] = []
prop_dict["loras"] = []
prop_dict["vae"] = None
prop_dict["strength"] = self.strength
prop_dict["init_image"] = self.init_image
prop_dict["positive_style_prompt"] = None
prop_dict["negative_style_prompt"] = None
prop_dict["refiner_model"] = None
prop_dict["refiner_cfg_scale"] = None
prop_dict["refiner_steps"] = None
prop_dict["refiner_scheduler"] = None
prop_dict["refiner_aesthetic_store"] = None
prop_dict["refiner_start"] = None
prop_dict["imported_app_version"] = self.imported_app_version
return json.dumps(prop_dict)
class InvokeAIMetadataParser:
"""Parses strings with json data to find Invoke AI core metadata properties."""
def __init__(self):
pass
def parse_meta_tag_dream(self, dream_string):
"""Take as input an png metadata json node for the 'dream' field variant from prior to 1.15"""
props = InvokeAIMetadata()
props.imported_app_version = "pre1.15"
seed_match = re.search("-S\\s*(\\d+)", dream_string)
if seed_match is not None:
try:
props.seed = int(seed_match[1])
except ValueError:
props.seed = None
raw_prompt = re.sub("(-S\\s*\\d+)", "", dream_string)
else:
raw_prompt = dream_string
pos_prompt, neg_prompt = self.split_prompt(raw_prompt)
props.positive_prompt = pos_prompt
props.negative_prompt = neg_prompt
return props
def parse_meta_tag_sd_metadata(self, tag_value):
"""Take as input an png metadata json node for the 'sd-metadata' field variant from 1.15 through 2.3.5 post 2"""
props = InvokeAIMetadata()
props.imported_app_version = tag_value.get("app_version")
props.model_name = tag_value.get("model_weights")
img_node = tag_value.get("image")
if img_node is not None:
props.generation_mode = img_node.get("type")
props.width = img_node.get("width")
props.height = img_node.get("height")
props.seed = img_node.get("seed")
props.rand_device = "cuda" # hardcoded since all generations pre 3.0 used cuda random noise instead of cpu
props.cfg_scale = img_node.get("cfg_scale")
props.steps = img_node.get("steps")
props.scheduler = self.map_scheduler(img_node.get("sampler"))
props.strength = img_node.get("strength")
if props.strength is None:
props.strength = img_node.get("strength_steps") # try second name for this property
props.init_image = img_node.get("init_image_path")
if props.init_image is None: # try second name for this property
props.init_image = img_node.get("init_img")
# remove the path info from init_image so if we move the init image, it will be correctly relative in the new location
if props.init_image is not None:
props.init_image = os.path.basename(props.init_image)
raw_prompt = img_node.get("prompt")
if isinstance(raw_prompt, list):
raw_prompt = raw_prompt[0].get("prompt")
props.positive_prompt, props.negative_prompt = self.split_prompt(raw_prompt)
return props
def parse_meta_tag_invokeai(self, tag_value):
"""Take as input an png metadata json node for the 'invokeai' field variant from 3.0.0 beta 1 through 5"""
props = InvokeAIMetadata()
props.imported_app_version = "3.0.0 or later"
props.generation_mode = tag_value.get("type")
if props.generation_mode is not None:
props.generation_mode = props.generation_mode.replace("t2l", "txt2img").replace("l2l", "img2img")
props.width = tag_value.get("width")
props.height = tag_value.get("height")
props.seed = tag_value.get("seed")
props.cfg_scale = tag_value.get("cfg_scale")
props.steps = tag_value.get("steps")
props.scheduler = tag_value.get("scheduler")
props.strength = tag_value.get("strength")
props.positive_prompt = tag_value.get("positive_conditioning")
props.negative_prompt = tag_value.get("negative_conditioning")
return props
def map_scheduler(self, old_scheduler):
"""Convert the legacy sampler names to matching 3.0 schedulers"""
if old_scheduler is None:
return None
match (old_scheduler):
case "ddim":
return "ddim"
case "plms":
return "pnmd"
case "k_lms":
return "lms"
case "k_dpm_2":
return "kdpm_2"
case "k_dpm_2_a":
return "kdpm_2_a"
case "dpmpp_2":
return "dpmpp_2s"
case "k_dpmpp_2":
return "dpmpp_2m"
case "k_dpmpp_2_a":
                return None  # invalid; in 2.3.x, selecting this sampler would just fall back to the last run, or to plms in a new session
case "k_euler":
return "euler"
case "k_euler_a":
return "euler_a"
case "k_heun":
return "heun"
return None
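        # Examples (illustrative): map_scheduler("k_euler_a") -> "euler_a",
        # map_scheduler("plms") -> "pndm", map_scheduler("k_dpmpp_2_a") -> None (dropped).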
def split_prompt(self, raw_prompt: str):
"""Split the unified prompt strings by extracting all negative prompt blocks out into the negative prompt."""
if raw_prompt is None:
return "", ""
raw_prompt_search = raw_prompt.replace("\r", "").replace("\n", "")
matches = re.findall(r"\[(.+?)\]", raw_prompt_search)
if len(matches) > 0:
negative_prompt = ""
if len(matches) == 1:
negative_prompt = matches[0].strip().strip(",")
else:
for match in matches:
negative_prompt += f"({match.strip().strip(',')})"
positive_prompt = re.sub(r"(\[.+?\])", "", raw_prompt_search).strip()
else:
positive_prompt = raw_prompt_search.strip()
negative_prompt = ""
return positive_prompt, negative_prompt
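        # Worked examples (illustrative):
        #   split_prompt("a cat [ugly, blurry] on a mat") -> ("a cat  on a mat", "ugly, blurry")
        #   split_prompt("a [ugly] cat [blurry]") -> ("a  cat", "(ugly)(blurry)")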
class DatabaseMapper:
"""Class to abstract database functionality."""
def __init__(self, database_path, database_backup_dir):
self.database_path = database_path
self.database_backup_dir = database_backup_dir
self.connection = None
self.cursor = None
def connect(self):
"""Open connection to the database."""
self.connection = sqlite3.connect(self.database_path)
self.cursor = self.connection.cursor()
def get_board_names(self):
"""Get a list of the current board names from the database."""
sql_get_board_name = "SELECT board_name FROM boards"
self.cursor.execute(sql_get_board_name)
rows = self.cursor.fetchall()
return [row[0] for row in rows]
def does_image_exist(self, image_name):
"""Check database if a image name already exists and return a boolean."""
sql_get_image_by_name = f"SELECT image_name FROM images WHERE image_name='{image_name}'"
self.cursor.execute(sql_get_image_by_name)
rows = self.cursor.fetchall()
        return len(rows) > 0
def add_new_image_to_database(self, filename, width, height, metadata, modified_date_string):
"""Add an image to the database."""
sql_add_image = f"""INSERT INTO images (image_name, image_origin, image_category, width, height, session_id, node_id, metadata, is_intermediate, created_at, updated_at)
VALUES ('{filename}', 'internal', 'general', {width}, {height}, null, null, '{metadata}', 0, '{modified_date_string}', '{modified_date_string}')"""
self.cursor.execute(sql_add_image)
self.connection.commit()
def get_board_id_with_create(self, board_name):
"""Get the board id for supplied name, and create the board if one does not exist."""
sql_find_board = f"SELECT board_id FROM boards WHERE board_name='{board_name}' COLLATE NOCASE"
self.cursor.execute(sql_find_board)
rows = self.cursor.fetchall()
if len(rows) > 0:
return rows[0][0]
else:
board_date_string = datetime.datetime.utcnow().date().isoformat()
new_board_id = str(uuid.uuid4())
sql_insert_board = f"INSERT INTO boards (board_id, board_name, created_at, updated_at) VALUES ('{new_board_id}', '{board_name}', '{board_date_string}', '{board_date_string}')"
self.cursor.execute(sql_insert_board)
self.connection.commit()
return new_board_id
def add_image_to_board(self, filename, board_id):
"""Add an image mapping to a board."""
add_datetime_str = datetime.datetime.utcnow().isoformat()
sql_add_image_to_board = f"""INSERT INTO board_images (board_id, image_name, created_at, updated_at)
VALUES ('{board_id}', '{filename}', '{add_datetime_str}', '{add_datetime_str}')"""
self.cursor.execute(sql_add_image_to_board)
self.connection.commit()
def disconnect(self):
"""Disconnect from the db, cleaning up connections and cursors."""
if self.cursor is not None:
self.cursor.close()
if self.connection is not None:
self.connection.close()
def backup(self, timestamp_string):
"""Take a backup of the database."""
if not os.path.exists(self.database_backup_dir):
print(f"Database backup directory {self.database_backup_dir} does not exist -> creating...", end="")
os.makedirs(self.database_backup_dir)
print("Done!")
database_backup_path = os.path.join(self.database_backup_dir, f"backup-{timestamp_string}-invokeai.db")
print(f"Making DB Backup at {database_backup_path}...", end="")
shutil.copy2(self.database_path, database_backup_path)
print("Done!")
class MediaImportProcessor:
"""Containing class for script functionality."""
def __init__(self):
pass
board_name_id_map = {}
def get_import_file_list(self):
"""Ask the user for the import folder and scan for the list of files to return."""
while True:
default = ""
while True:
import_dir = os.path.expanduser(
prompt(
"Inputs: Specify absolute path containing InvokeAI .png images to import: ",
completer=PathCompleter(expanduser=True, only_directories=True),
default=default,
)
)
if len(import_dir) > 0 and Path(import_dir).is_dir():
break
default = import_dir
            no_recurse = (
                (prompt("Include files from subfolders recursively [yN]? ").strip() or "N").upper().startswith("N")
            )
            if no_recurse:
                is_recurse = False
                matching_file_list = glob.glob(import_dir + "/*.png", recursive=False)
            else:
                is_recurse = True
                matching_file_list = glob.glob(import_dir + "/**/*.png", recursive=True)
if len(matching_file_list) > 0:
return import_dir, is_recurse, matching_file_list
else:
print(f"The specific path {import_dir} exists, but does not contain .png files!")
def get_file_details(self, filepath):
"""Retrieve the embedded metedata fields and dimensions from an image file."""
with PIL.Image.open(filepath) as img:
img.load()
png_width, png_height = img.size
img_info = img.info
return img_info, png_width, png_height
def select_board_option(self, board_names, timestamp_string):
"""Allow the user to choose how a board is selected for imported files."""
while True:
print("\r\nOptions for board selection for imported images:")
print(f"1) Select an existing board name. (found {len(board_names)})")
print("2) Specify a board name to create/add to.")
print("3) Create/add to board named 'IMPORT'.")
print(
f"4) Create/add to board named 'IMPORT' with the current datetime string appended (.e.g IMPORT_{timestamp_string})."
)
print(
"5) Create/add to board named 'IMPORT' with a the original file app_version appended (.e.g IMPORT_2.2.5)."
)
input_option = input("Specify desired board option: ")
match (input_option):
case "1":
if len(board_names) < 1:
print("\r\nThere are no existing board names to choose from. Select another option!")
continue
board_name = self.select_item_from_list(
board_names, "board name", True, "Cancel, go back and choose a different board option."
)
if board_name is not None:
return board_name
case "2":
while True:
board_name = input("Specify new/existing board name: ")
if board_name:
return board_name
case "3":
return "IMPORT"
case "4":
return f"IMPORT_{timestamp_string}"
case "5":
return "IMPORT_APPVERSION"
def select_item_from_list(self, items, entity_name, allow_cancel, cancel_string):
"""A general function to render a list of items to select in the console, prompt the user for a selection and ensure a valid entry is selected."""
print(f"Select a {entity_name.lower()} from the following list:")
index = 1
for item in items:
print(f"{index}) {item}")
index += 1
if allow_cancel:
print(f"{index}) {cancel_string}")
while True:
try:
option_number = int(input("Specify number of selection: "))
except ValueError:
continue
if allow_cancel and option_number == index:
return None
if option_number >= 1 and option_number <= len(items):
return items[option_number - 1]
def import_image(self, filepath: str, board_name_option: str, db_mapper: DatabaseMapper, config: Config):
"""Import a single file by its path"""
parser = InvokeAIMetadataParser()
file_name = os.path.basename(filepath)
file_destination_path = os.path.join(config.outputs_path, file_name)
print("===============================================================================")
print(f"Importing {filepath}")
# check destination to see if the file was previously imported
if os.path.exists(file_destination_path):
print("File already exists in the destination, skipping!")
ImportStats.count_skipped_file_exists += 1
return
# check if file name is already referenced in the database
if db_mapper.does_image_exist(file_name):
print("A reference to a file with this name already exists in the database, skipping!")
ImportStats.count_skipped_db_exists += 1
return
# load image info and dimensions
img_info, png_width, png_height = self.get_file_details(filepath)
# parse metadata
destination_needs_meta_update = True
log_version_note = "(Unknown)"
if "invokeai_metadata" in img_info:
# for the latest, we will just re-emit the same json, no need to parse/modify
converted_field = None
latest_json_string = img_info.get("invokeai_metadata")
log_version_note = "3.0.0+"
destination_needs_meta_update = False
else:
if "sd-metadata" in img_info:
converted_field = parser.parse_meta_tag_sd_metadata(json.loads(img_info.get("sd-metadata")))
elif "invokeai" in img_info:
converted_field = parser.parse_meta_tag_invokeai(json.loads(img_info.get("invokeai")))
elif "dream" in img_info:
converted_field = parser.parse_meta_tag_dream(img_info.get("dream"))
elif "Dream" in img_info:
converted_field = parser.parse_meta_tag_dream(img_info.get("Dream"))
else:
converted_field = InvokeAIMetadata()
destination_needs_meta_update = False
print("File does not have metadata from known Invoke AI versions, add only, no update!")
            # use the loaded image dimensions if the metadata didn't have them
if converted_field.width is None:
converted_field.width = png_width
if converted_field.height is None:
converted_field.height = png_height
log_version_note = converted_field.imported_app_version if converted_field else "NoVersion"
log_version_note = log_version_note or "NoVersion"
latest_json_string = converted_field.to_json()
print(f"From Invoke AI Version {log_version_note} with dimensions {png_width} x {png_height}.")
        # if metadata needs an update, update the metadata and copy in one shot
if destination_needs_meta_update:
print("Updating metadata while copying...", end="")
self.update_file_metadata_while_copying(
filepath, file_destination_path, "invokeai_metadata", latest_json_string
)
print("Done!")
else:
print("No metadata update necessary, copying only...", end="")
shutil.copy2(filepath, file_destination_path)
print("Done!")
# create thumbnail
print("Creating thumbnail...", end="")
thumbnail_path = os.path.join(config.thumbnail_path, os.path.splitext(file_name)[0]) + ".webp"
thumbnail_size = 256, 256
with PIL.Image.open(filepath) as source_image:
source_image.thumbnail(thumbnail_size)
source_image.save(thumbnail_path, "webp")
print("Done!")
# finalize the dynamic board name if there is an APPVERSION token in it.
if converted_field is not None:
board_name = board_name_option.replace("APPVERSION", converted_field.imported_app_version or "NoVersion")
else:
board_name = board_name_option.replace("APPVERSION", "Latest")
        # maintain a map of already created/looked-up ids to avoid DB queries
print("Finding/Creating board...", end="")
if board_name in self.board_name_id_map:
board_id = self.board_name_id_map[board_name]
else:
board_id = db_mapper.get_board_id_with_create(board_name)
self.board_name_id_map[board_name] = board_id
print("Done!")
# add image to db
print("Adding image to database......", end="")
modified_time = datetime.datetime.utcfromtimestamp(os.path.getmtime(filepath))
db_mapper.add_new_image_to_database(file_name, png_width, png_height, latest_json_string, modified_time)
print("Done!")
# add image to board
print("Adding image to board......", end="")
db_mapper.add_image_to_board(file_name, board_id)
print("Done!")
ImportStats.count_imported += 1
if log_version_note in ImportStats.count_imported_by_version:
ImportStats.count_imported_by_version[log_version_note] += 1
else:
ImportStats.count_imported_by_version[log_version_note] = 1
def update_file_metadata_while_copying(self, filepath, file_destination_path, tag_name, tag_value):
"""Perform a metadata update with save to a new destination which accomplishes a copy while updating metadata."""
with PIL.Image.open(filepath) as target_image:
existing_img_info = target_image.info
metadata = PIL.PngImagePlugin.PngInfo()
# re-add any existing invoke ai tags unless they are the one we are trying to add
for key in existing_img_info:
if key != tag_name and key in ("dream", "Dream", "sd-metadata", "invokeai", "invokeai_metadata"):
metadata.add_text(key, existing_img_info[key])
metadata.add_text(tag_name, tag_value)
target_image.save(file_destination_path, pnginfo=metadata)
def process(self):
"""Begin main processing."""
print("===============================================================================")
print("This script will import images generated by earlier versions of")
print("InvokeAI into the currently installed root directory:")
print(f" {app_config.root_path}")
print("If this is not what you want to do, type ctrl-C now to cancel.")
# load config
print("===============================================================================")
print("= Configuration & Settings")
config = Config()
config.find_and_load()
db_mapper = DatabaseMapper(config.database_path, config.database_backup_dir)
db_mapper.connect()
import_dir, is_recurse, import_file_list = self.get_import_file_list()
ImportStats.count_source_files = len(import_file_list)
board_names = db_mapper.get_board_names()
board_name_option = self.select_board_option(board_names, config.TIMESTAMP_STRING)
print("\r\n===============================================================================")
print("= Import Settings Confirmation")
print()
print(f"Database File Path : {config.database_path}")
print(f"Outputs/Images Directory : {config.outputs_path}")
print(f"Import Image Source Directory : {import_dir}")
print(f" Recurse Source SubDirectories : {'Yes' if is_recurse else 'No'}")
print(f"Count of .png file(s) found : {len(import_file_list)}")
print(f"Board name option specified : {board_name_option}")
print(f"Database backup will be taken at : {config.database_backup_dir}")
print("\r\nNotes about the import process:")
print("- Source image files will not be modified, only copied to the outputs directory.")
print("- If the same file name already exists in the destination, the file will be skipped.")
print("- If the same file name already has a record in the database, the file will be skipped.")
print("- Invoke AI metadata tags will be updated/written into the imported copy only.")
print(
"- On the imported copy, only Invoke AI known tags (latest and legacy) will be retained (dream, sd-metadata, invokeai, invokeai_metadata)"
)
print(
"- A property 'imported_app_version' will be added to metadata that can be viewed in the UI's metadata viewer."
)
print(
"- The new 3.x InvokeAI outputs folder structure is flat so recursively found source imges will all be placed into the single outputs/images folder."
)
while True:
should_continue = prompt("\nDo you wish to continue with the import [Yn] ? ").lower() or "y"
if should_continue == "n":
print("\r\nCancelling Import")
return
elif should_continue == "y":
print()
break
db_mapper.backup(config.TIMESTAMP_STRING)
print()
ImportStats.time_start = datetime.datetime.utcnow()
for filepath in import_file_list:
try:
self.import_image(filepath, board_name_option, db_mapper, config)
except sqlite3.Error as sql_ex:
print(f"A database related exception was found processing {filepath}, will continue to next file. ")
print("Exception detail:")
print(sql_ex)
ImportStats.count_file_errors += 1
except Exception as ex:
print(f"Exception processing {filepath}, will continue to next file. ")
print("Exception detail:")
print(ex)
ImportStats.count_file_errors += 1
print("\r\n===============================================================================")
print(f"= Import Complete - Elpased Time: {ImportStats.get_elapsed_time_string()}")
print()
print(f"Source File(s) : {ImportStats.count_source_files}")
print(f"Total Imported : {ImportStats.count_imported}")
print(f"Skipped b/c file already exists on disk : {ImportStats.count_skipped_file_exists}")
print(f"Skipped b/c file already exists in db : {ImportStats.count_skipped_db_exists}")
print(f"Errors during import : {ImportStats.count_file_errors}")
if ImportStats.count_imported > 0:
print("\r\nBreakdown of imported files by version:")
for key, version in ImportStats.count_imported_by_version.items():
print(f" {key:20} : {version}")
def main():
try:
processor = MediaImportProcessor()
processor.process()
except KeyboardInterrupt:
print("\r\n\r\nUser cancelled execution.")
if __name__ == "__main__":
main()

View File

@@ -1,4 +1,4 @@
 """
 Wrapper for invokeai.backend.configure.invokeai_configure
 """
-from ...backend.install.invokeai_configure import main as invokeai_configure
+from ...backend.install.invokeai_configure import main

View File

@@ -28,6 +28,7 @@ from npyscreen import widget
 from invokeai.backend.util.logging import InvokeAILogger
 from invokeai.backend.install.model_install_backend import (
+    ModelInstallList,
     InstallSelections,
     ModelInstall,
     SchedulerPredictionType,
@@ -40,12 +41,12 @@ from invokeai.frontend.install.widgets import (
     SingleSelectColumns,
     TextBox,
     BufferBox,
+    FileBox,
     set_min_terminal_size,
     select_stable_diffusion_config_file,
     CyclingForm,
     MIN_COLS,
     MIN_LINES,
-    WindowTooSmallException,
 )
 from invokeai.app.services.config import InvokeAIAppConfig
@@ -155,7 +156,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
             BufferBox,
             name="Log Messages",
             editable=False,
-            max_height=6,
+            max_height=15,
         )

         self.nextrely += 1
@@ -692,11 +693,7 @@ def select_and_download_models(opt: Namespace):
     # needed to support the probe() method running under a subprocess
     torch.multiprocessing.set_start_method("spawn")

-    if not set_min_terminal_size(MIN_COLS, MIN_LINES):
-        raise WindowTooSmallException(
-            "Could not increase terminal size. Try running again with a larger window or smaller font size."
-        )
+    set_min_terminal_size(MIN_COLS, MIN_LINES)

     installApp = AddModelApplication(opt)
     try:
         installApp.run()
@@ -790,8 +787,6 @@ def main():
         curses.echo()
         curses.endwin()
         logger.info("Goodbye! Come back soon.")
-    except WindowTooSmallException as e:
-        logger.error(str(e))
     except widget.NotEnoughSpaceForWidget as e:
         if str(e).startswith("Height of 1 allocated"):
             logger.error("Insufficient vertical space for the interface. Please make your window taller and try again")

View File

@@ -21,19 +21,13 @@ MIN_COLS = 130
MIN_LINES = 38 MIN_LINES = 38
class WindowTooSmallException(Exception):
pass
# ------------------------------------- # -------------------------------------
def set_terminal_size(columns: int, lines: int) -> bool: def set_terminal_size(columns: int, lines: int):
OS = platform.uname().system
screen_ok = False
while not screen_ok:
ts = get_terminal_size() ts = get_terminal_size()
width = max(columns, ts.columns) width = max(columns, ts.columns)
height = max(lines, ts.lines) height = max(lines, ts.lines)
OS = platform.uname().system
if OS == "Windows": if OS == "Windows":
pass pass
# not working reliably - ask user to adjust the window # not working reliably - ask user to adjust the window
@@ -43,18 +37,15 @@ def set_terminal_size(columns: int, lines: int) -> bool:
# check whether it worked.... # check whether it worked....
ts = get_terminal_size() ts = get_terminal_size()
if ts.columns < columns or ts.lines < lines: pause = False
print( if ts.columns < columns:
f"\033[1mThis window is too small for the interface. InvokeAI requires {columns}x{lines} (w x h) characters, but window is {ts.columns}x{ts.lines}\033[0m" print("\033[1mThis window is too narrow for the user interface.\033[0m")
) pause = True
resp = input( if ts.lines < lines:
"Maximize the window and/or decrease the font size then press any key to continue. Type [Q] to give up.." print("\033[1mThis window is too short for the user interface.\033[0m")
) pause = True
if resp.upper().startswith("Q"): if pause:
break input("Maximize the window then press any key to continue..")
else:
screen_ok = True
return screen_ok
def _set_terminal_size_powershell(width: int, height: int): def _set_terminal_size_powershell(width: int, height: int):
@@ -89,14 +80,14 @@ def _set_terminal_size_unix(width: int, height: int):
sys.stdout.flush() sys.stdout.flush()
def set_min_terminal_size(min_cols: int, min_lines: int) -> bool: def set_min_terminal_size(min_cols: int, min_lines: int):
# make sure there's enough room for the ui # make sure there's enough room for the ui
term_cols, term_lines = get_terminal_size() term_cols, term_lines = get_terminal_size()
if term_cols >= min_cols and term_lines >= min_lines: if term_cols >= min_cols and term_lines >= min_lines:
return True return
cols = max(term_cols, min_cols) cols = max(term_cols, min_cols)
lines = max(term_lines, min_lines) lines = max(term_lines, min_lines)
return set_terminal_size(cols, lines) set_terminal_size(cols, lines)
class IntSlider(npyscreen.Slider): class IntSlider(npyscreen.Slider):
@@ -173,7 +164,7 @@ class FloatSlider(npyscreen.Slider):
class FloatTitleSlider(npyscreen.TitleText): class FloatTitleSlider(npyscreen.TitleText):
_entry_type = npyscreen.Slider _entry_type = FloatSlider
class SelectColumnBase: class SelectColumnBase:

View File

@@ -382,7 +382,6 @@ def run_cli(args: Namespace):
def main(): def main():
args = _parse_args() args = _parse_args()
if args.root_dir:
config.parse_args(["--root", str(args.root_dir)]) config.parse_args(["--root", str(args.root_dir)])
try: try:

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,4 +1,4 @@
import{B as m,g7 as Je,A as y,a5 as Ka,g8 as Xa,af as va,aj as d,g9 as b,ga as t,gb as Ya,gc as h,gd as ua,ge as Ja,gf as Qa,aL as Za,gg as et,ad as rt,gh as at}from"./index-815faab3.js";import{s as fa,n as o,t as tt,o as ha,p as ot,q as ma,v as ga,w as ya,x as it,y as Sa,z as pa,A as xr,B as nt,D as lt,E as st,F as xa,G as $a,H as ka,J as dt,K as _a,L as ct,M as bt,N as vt,O as ut,Q as wa,R as ft,S as ht,T as mt,U as gt,V as yt,W as St,e as pt,X as xt}from"./menu-e9f8a36e.js";var za=String.raw,Ca=za` import{A as m,f$ as Je,z as y,a4 as Ka,g0 as Xa,af as va,aj as d,g1 as b,g2 as t,g3 as Ya,g4 as h,g5 as ua,g6 as Ja,g7 as Qa,aI as Za,g8 as et,ad as rt,g9 as at}from"./index-18f2f740.js";import{s as fa,n as o,t as tt,o as ha,p as ot,q as ma,v as ga,w as ya,x as it,y as Sa,z as pa,A as xr,B as nt,D as lt,E as st,F as xa,G as $a,H as ka,J as dt,K as _a,L as ct,M as bt,N as vt,O as ut,Q as wa,R as ft,S as ht,T as mt,U as gt,V as yt,W as St,e as pt,X as xt}from"./MantineProvider-b20a2267.js";var za=String.raw,Ca=za`
:root, :root,
:host { :host {
--chakra-vh: 100vh; --chakra-vh: 100vh;

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -12,7 +12,7 @@
margin: 0; margin: 0;
} }
</style> </style>
<script type="module" crossorigin src="./assets/index-815faab3.js"></script> <script type="module" crossorigin src="./assets/index-18f2f740.js"></script>
</head> </head>
<body dir="ltr"> <body dir="ltr">

View File

@@ -124,8 +124,7 @@
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.", "deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
"deleteImagePermanent": "Deleted images cannot be restored.", "deleteImagePermanent": "Deleted images cannot be restored.",
"images": "Images", "images": "Images",
"assets": "Assets", "assets": "Assets"
"autoAssignBoardOnClick": "Auto-Assign Board on Click"
}, },
"hotkeys": { "hotkeys": {
"keyboardShortcuts": "Keyboard Shortcuts", "keyboardShortcuts": "Keyboard Shortcuts",

View File

@@ -23,7 +23,7 @@
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"", "dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"", "dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
"build": "yarn run lint && vite build", "build": "yarn run lint && vite build",
"typegen": "node scripts/typegen.js", "typegen": "npx ts-node scripts/typegen.ts",
"preview": "vite preview", "preview": "vite preview",
"lint:madge": "madge --circular src/main.tsx", "lint:madge": "madge --circular src/main.tsx",
"lint:eslint": "eslint --max-warnings=0 .", "lint:eslint": "eslint --max-warnings=0 .",

View File

@@ -124,8 +124,7 @@
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.", "deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
"deleteImagePermanent": "Deleted images cannot be restored.", "deleteImagePermanent": "Deleted images cannot be restored.",
"images": "Images", "images": "Images",
"assets": "Assets", "assets": "Assets"
"autoAssignBoardOnClick": "Auto-Assign Board on Click"
}, },
"hotkeys": { "hotkeys": {
"keyboardShortcuts": "Keyboard Shortcuts", "keyboardShortcuts": "Keyboard Shortcuts",

View File

@@ -4,9 +4,8 @@ import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/ap
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { PartialAppConfig } from 'app/types/invokeai'; import { PartialAppConfig } from 'app/types/invokeai';
import ImageUploader from 'common/components/ImageUploader'; import ImageUploader from 'common/components/ImageUploader';
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import GalleryDrawer from 'features/gallery/components/GalleryPanel'; import GalleryDrawer from 'features/gallery/components/GalleryPanel';
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
import SiteHeader from 'features/system/components/SiteHeader'; import SiteHeader from 'features/system/components/SiteHeader';
import { configChanged } from 'features/system/store/configSlice'; import { configChanged } from 'features/system/store/configSlice';
import { languageSelector } from 'features/system/store/systemSelectors'; import { languageSelector } from 'features/system/store/systemSelectors';
@@ -17,6 +16,7 @@ import ParametersDrawer from 'features/ui/components/ParametersDrawer';
import i18n from 'i18n'; import i18n from 'i18n';
import { size } from 'lodash-es'; import { size } from 'lodash-es';
import { ReactNode, memo, useEffect } from 'react'; import { ReactNode, memo, useEffect } from 'react';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import GlobalHotkeys from './GlobalHotkeys'; import GlobalHotkeys from './GlobalHotkeys';
import Toaster from './Toaster'; import Toaster from './Toaster';
@@ -84,7 +84,7 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
</Portal> </Portal>
</Grid> </Grid>
<DeleteImageModal /> <DeleteImageModal />
<ChangeBoardModal /> <UpdateImageBoardModal />
<Toaster /> <Toaster />
<GlobalHotkeys /> <GlobalHotkeys />
</> </>

View File

@@ -58,7 +58,7 @@ const DragPreview = (props: OverlayDragImageProps) => {
); );
} }
if (props.dragData.payloadType === 'IMAGE_DTOS') { if (props.dragData.payloadType === 'IMAGE_NAMES') {
return ( return (
<Flex <Flex
sx={{ sx={{
@@ -71,7 +71,7 @@ const DragPreview = (props: OverlayDragImageProps) => {
...STYLES, ...STYLES,
}} }}
> >
<Heading>{props.dragData.payload.imageDTOs.length}</Heading> <Heading>{props.dragData.payload.image_names.length}</Heading>
<Heading size="sm">Images</Heading> <Heading size="sm">Images</Heading>
</Flex> </Flex>
); );

View File

@@ -18,32 +18,27 @@ import {
DragStartEvent, DragStartEvent,
TypesafeDraggableData, TypesafeDraggableData,
} from './typesafeDnd'; } from './typesafeDnd';
import { logger } from 'app/logging/logger';
type ImageDndContextProps = PropsWithChildren; type ImageDndContextProps = PropsWithChildren;
const ImageDndContext = (props: ImageDndContextProps) => { const ImageDndContext = (props: ImageDndContextProps) => {
const [activeDragData, setActiveDragData] = const [activeDragData, setActiveDragData] =
useState<TypesafeDraggableData | null>(null); useState<TypesafeDraggableData | null>(null);
const log = logger('images');
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const handleDragStart = useCallback( const handleDragStart = useCallback((event: DragStartEvent) => {
(event: DragStartEvent) => { console.log('dragStart', event.active.data.current);
log.trace({ dragData: event.active.data.current }, 'Drag started');
const activeData = event.active.data.current; const activeData = event.active.data.current;
if (!activeData) { if (!activeData) {
return; return;
} }
setActiveDragData(activeData); setActiveDragData(activeData);
}, }, []);
[log]
);
const handleDragEnd = useCallback( const handleDragEnd = useCallback(
(event: DragEndEvent) => { (event: DragEndEvent) => {
log.trace({ dragData: event.active.data.current }, 'Drag ended'); console.log('dragEnd', event.active.data.current);
const overData = event.over?.data.current; const overData = event.over?.data.current;
if (!activeDragData || !overData) { if (!activeDragData || !overData) {
return; return;
@@ -51,7 +46,7 @@ const ImageDndContext = (props: ImageDndContextProps) => {
dispatch(dndDropped({ overData, activeData: activeDragData })); dispatch(dndDropped({ overData, activeData: activeDragData }));
setActiveDragData(null); setActiveDragData(null);
}, },
[activeDragData, dispatch, log] [activeDragData, dispatch]
); );
const mouseSensor = useSensor(MouseSensor, { const mouseSensor = useSensor(MouseSensor, {
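The hunk above swaps raw console.log calls for a namespaced logger with a trace level and adds the logger to the callback dependency arrays. A minimal sketch of that pattern, with a stand-in logger factory (the real app/logging/logger API may differ):

// Sketch: a namespaced logger with a trace level replacing raw
// console.log calls. The factory below is a stand-in, not the
// actual app/logging/logger implementation.
type LogFn = (context: Record<string, unknown>, message: string) => void;

const logger = (namespace: string): { trace: LogFn } => ({
  trace: (context, message) =>
    console.debug(`[${namespace}] ${message}`, context),
});

const log = logger('images');
log.trace({ dragData: { id: 'abc' } }, 'Drag started');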

View File

@@ -11,6 +11,7 @@ import {
useDraggable as useOriginalDraggable, useDraggable as useOriginalDraggable,
useDroppable as useOriginalDroppable, useDroppable as useOriginalDroppable,
} from '@dnd-kit/core'; } from '@dnd-kit/core';
import { BoardId } from 'features/gallery/store/types';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
type BaseDropData = { type BaseDropData = {
@@ -53,13 +54,9 @@ export type AddToBatchDropData = BaseDropData & {
actionType: 'ADD_TO_BATCH'; actionType: 'ADD_TO_BATCH';
}; };
export type AddToBoardDropData = BaseDropData & { export type MoveBoardDropData = BaseDropData & {
actionType: 'ADD_TO_BOARD'; actionType: 'MOVE_BOARD';
context: { boardId: string }; context: { boardId: BoardId };
};
export type RemoveFromBoardDropData = BaseDropData & {
actionType: 'REMOVE_FROM_BOARD';
}; };
export type TypesafeDroppableData = export type TypesafeDroppableData =
@@ -70,8 +67,7 @@ export type TypesafeDroppableData =
| NodesImageDropData | NodesImageDropData
| AddToBatchDropData | AddToBatchDropData
| NodesMultiImageDropData | NodesMultiImageDropData
| AddToBoardDropData | MoveBoardDropData;
| RemoveFromBoardDropData;
type BaseDragData = { type BaseDragData = {
id: string; id: string;
@@ -82,12 +78,14 @@ export type ImageDraggableData = BaseDragData & {
payload: { imageDTO: ImageDTO }; payload: { imageDTO: ImageDTO };
}; };
export type ImageDTOsDraggableData = BaseDragData & { export type ImageNamesDraggableData = BaseDragData & {
payloadType: 'IMAGE_DTOS'; payloadType: 'IMAGE_NAMES';
payload: { imageDTOs: ImageDTO[] }; payload: { image_names: string[] };
}; };
export type TypesafeDraggableData = ImageDraggableData | ImageDTOsDraggableData; export type TypesafeDraggableData =
| ImageDraggableData
| ImageNamesDraggableData;
interface UseDroppableTypesafeArguments interface UseDroppableTypesafeArguments
extends Omit<UseDroppableArguments, 'data'> { extends Omit<UseDroppableArguments, 'data'> {
@@ -158,39 +156,14 @@ export const isValidDrop = (
case 'SET_NODES_IMAGE': case 'SET_NODES_IMAGE':
return payloadType === 'IMAGE_DTO'; return payloadType === 'IMAGE_DTO';
case 'SET_MULTI_NODES_IMAGE': case 'SET_MULTI_NODES_IMAGE':
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS'; return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
case 'ADD_TO_BATCH': case 'ADD_TO_BATCH':
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS'; return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
case 'ADD_TO_BOARD': { case 'MOVE_BOARD': {
// If the board is the same, don't allow the drop // If the board is the same, don't allow the drop
// Check the payload types // Check the payload types
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS'; const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
if (!isPayloadValid) {
return false;
}
// Check if the image's board is the board we are dragging onto
if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = active.data.current.payload;
const currentBoard = imageDTO.board_id ?? 'none';
const destinationBoard = overData.context.boardId;
return currentBoard !== destinationBoard;
}
if (payloadType === 'IMAGE_DTOS') {
// TODO (multi-select)
return true;
}
return false;
}
case 'REMOVE_FROM_BOARD': {
// If the board is the same, don't allow the drop
// Check the payload types
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
if (!isPayloadValid) { if (!isPayloadValid) {
return false; return false;
} }
@@ -199,16 +172,20 @@ export const isValidDrop = (
if (payloadType === 'IMAGE_DTO') { if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = active.data.current.payload; const { imageDTO } = active.data.current.payload;
const currentBoard = imageDTO.board_id; const currentBoard = imageDTO.board_id;
const destinationBoard = overData.context.boardId;
return currentBoard !== 'none'; const isSameBoard = currentBoard === destinationBoard;
const isDestinationValid = !currentBoard ? destinationBoard : true;
return !isSameBoard && isDestinationValid;
} }
if (payloadType === 'IMAGE_DTOS') { if (payloadType === 'IMAGE_NAMES') {
// TODO (multi-select) // TODO (multi-select)
return true; return false;
} }
return false; return true;
} }
default: default:
return false; return false;
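The hunk above replaces the MOVE_BOARD drop action with explicit ADD_TO_BOARD / REMOVE_FROM_BOARD actions and switches the multi-select payload from image names to full DTOs. A minimal, self-contained sketch of the resulting validation logic follows; the types are simplified stand-ins, not the real InvokeAI definitions. Note that checks like `payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS'` in the diff evaluate the second operand as a bare truthy string, so the sketch writes both comparisons out explicitly:

// Simplified stand-in types for the discriminated unions above.
type ImageDTO = { image_name: string; board_id?: string };

type DraggableData =
  | { payloadType: 'IMAGE_DTO'; payload: { imageDTO: ImageDTO } }
  | { payloadType: 'IMAGE_DTOS'; payload: { imageDTOs: ImageDTO[] } };

type DroppableData =
  | { actionType: 'ADD_TO_BOARD'; context: { boardId: string } }
  | { actionType: 'REMOVE_FROM_BOARD' };

const isValidDrop = (active: DraggableData, over: DroppableData): boolean => {
  // Explicit membership check; `=== 'A' || 'B'` would always be truthy.
  const payloadOk =
    active.payloadType === 'IMAGE_DTO' || active.payloadType === 'IMAGE_DTOS';
  if (!payloadOk) {
    return false;
  }
  if (over.actionType === 'ADD_TO_BOARD') {
    if (active.payloadType === 'IMAGE_DTO') {
      // Disallow dropping an image onto the board it is already on.
      const currentBoard = active.payload.imageDTO.board_id ?? 'none';
      return currentBoard !== over.context.boardId;
    }
    return true; // multi-select: allowed, per the TODO in the source
  }
  // REMOVE_FROM_BOARD: only meaningful if the image is on a board.
  if (active.payloadType === 'IMAGE_DTO') {
    return (active.payload.imageDTO.board_id ?? 'none') !== 'none';
  }
  return true;
};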

View File

@@ -1,6 +1,4 @@
import { Middleware } from '@reduxjs/toolkit';
import { store } from 'app/store/store'; import { store } from 'app/store/store';
import { PartialAppConfig } from 'app/types/invokeai';
import React, { import React, {
lazy, lazy,
memo, memo,
@@ -9,11 +7,16 @@ import React, {
useEffect, useEffect,
} from 'react'; } from 'react';
import { Provider } from 'react-redux'; import { Provider } from 'react-redux';
import { PartialAppConfig } from 'app/types/invokeai';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares'; import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
import { $authToken, $baseUrl, $projectId } from 'services/api/client';
import { socketMiddleware } from 'services/events/middleware';
import Loading from '../../common/components/Loading/Loading'; import Loading from '../../common/components/Loading/Loading';
import { Middleware } from '@reduxjs/toolkit';
import { $authToken, $baseUrl } from 'services/api/client';
import { socketMiddleware } from 'services/events/middleware';
import '../../i18n'; import '../../i18n';
import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext';
import ImageDndContext from './ImageDnd/ImageDndContext'; import ImageDndContext from './ImageDnd/ImageDndContext';
const App = lazy(() => import('./App')); const App = lazy(() => import('./App'));
@@ -34,7 +37,6 @@ const InvokeAIUI = ({
config, config,
headerComponent, headerComponent,
middleware, middleware,
projectId,
}: Props) => { }: Props) => {
useEffect(() => { useEffect(() => {
// configure API client token // configure API client token
@@ -47,11 +49,6 @@ const InvokeAIUI = ({
$baseUrl.set(apiUrl); $baseUrl.set(apiUrl);
} }
// configure API client project header
if (projectId) {
$projectId.set(projectId);
}
// reset dynamically added middlewares // reset dynamically added middlewares
resetMiddlewares(); resetMiddlewares();
@@ -71,9 +68,8 @@ const InvokeAIUI = ({
// Reset the API client token and base url on unmount // Reset the API client token and base url on unmount
$baseUrl.set(undefined); $baseUrl.set(undefined);
$authToken.set(undefined); $authToken.set(undefined);
$projectId.set(undefined);
}; };
}, [apiUrl, token, middleware, projectId]); }, [apiUrl, token, middleware]);
return ( return (
<React.StrictMode> <React.StrictMode>
@@ -81,7 +77,9 @@ const InvokeAIUI = ({
<React.Suspense fallback={<Loading />}> <React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider> <ThemeLocaleProvider>
<ImageDndContext> <ImageDndContext>
<AddImageToBoardContextProvider>
<App config={config} headerComponent={headerComponent} /> <App config={config} headerComponent={headerComponent} />
</AddImageToBoardContextProvider>
</ImageDndContext> </ImageDndContext>
</ThemeLocaleProvider> </ThemeLocaleProvider>
</React.Suspense> </React.Suspense>
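The $authToken / $baseUrl / $projectId values referenced above behave like writable atoms that the API client reads when building requests. A rough sketch assuming the nanostores library; the actual store implementation in services/api/client may differ:

// Sketch: writable client-config atoms, set on mount and cleared on
// unmount by the InvokeAIUI effect shown above.
import { atom } from 'nanostores';

export const $authToken = atom<string | undefined>(undefined);
export const $baseUrl = atom<string | undefined>(undefined);
export const $projectId = atom<string | undefined>(undefined);

// Configure, read, then reset:
$projectId.set('my-project');
console.log($projectId.get()); // 'my-project'
$projectId.set(undefined);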

View File

@@ -0,0 +1,91 @@
import { useDisclosure } from '@chakra-ui/react';
import { PropsWithChildren, createContext, useCallback, useState } from 'react';
import { ImageDTO } from 'services/api/types';
import { imagesApi } from 'services/api/endpoints/images';
import { useAppDispatch } from '../store/storeHooks';
export type ImageUsage = {
isInitialImage: boolean;
isCanvasImage: boolean;
isNodesImage: boolean;
isControlNetImage: boolean;
};
type AddImageToBoardContextValue = {
/**
* Whether the move image dialog is open.
*/
isOpen: boolean;
/**
* Closes the move image dialog.
*/
onClose: () => void;
/**
* The image pending movement
*/
image?: ImageDTO;
onClickAddToBoard: (image: ImageDTO) => void;
handleAddToBoard: (boardId: string) => void;
};
export const AddImageToBoardContext =
createContext<AddImageToBoardContextValue>({
isOpen: false,
onClose: () => undefined,
onClickAddToBoard: () => undefined,
handleAddToBoard: () => undefined,
});
type Props = PropsWithChildren;
export const AddImageToBoardContextProvider = (props: Props) => {
const [imageToMove, setImageToMove] = useState<ImageDTO>();
const { isOpen, onOpen, onClose } = useDisclosure();
const dispatch = useAppDispatch();
// Clean up after deleting or dismissing the modal
const closeAndClearImageToDelete = useCallback(() => {
setImageToMove(undefined);
onClose();
}, [onClose]);
const onClickAddToBoard = useCallback(
(image?: ImageDTO) => {
if (!image) {
return;
}
setImageToMove(image);
onOpen();
},
[setImageToMove, onOpen]
);
const handleAddToBoard = useCallback(
(boardId: string) => {
if (imageToMove) {
dispatch(
imagesApi.endpoints.addImageToBoard.initiate({
imageDTO: imageToMove,
board_id: boardId,
})
);
closeAndClearImageToDelete();
}
},
[dispatch, closeAndClearImageToDelete, imageToMove]
);
return (
<AddImageToBoardContext.Provider
value={{
isOpen,
image: imageToMove,
onClose: closeAndClearImageToDelete,
onClickAddToBoard,
handleAddToBoard,
}}
>
{props.children}
</AddImageToBoardContext.Provider>
);
};
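A hypothetical consumer of this new context, sketched to show the intended flow; the component and button markup are illustrative, not from the codebase. Clicking calls onClickAddToBoard, which stores the image and opens the dialog; the dialog later calls handleAddToBoard(boardId) to fire the addImageToBoard mutation:

// Hypothetical consumer of AddImageToBoardContext.
import { useContext } from 'react';
import { ImageDTO } from 'services/api/types';
import { AddImageToBoardContext } from '../contexts/AddImageToBoardContext';

const AddToBoardMenuItem = ({ imageDTO }: { imageDTO: ImageDTO }) => {
  const { onClickAddToBoard } = useContext(AddImageToBoardContext);
  // Opens the move-image dialog for this image.
  return (
    <button onClick={() => onClickAddToBoard(imageDTO)}>Add to board</button>
  );
};

export default AddToBoardMenuItem;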

View File

@@ -0,0 +1,8 @@
import { createContext } from 'react';
type VoidFunc = () => void;
type ImageUploaderTriggerContextType = VoidFunc | null;
export const ImageUploaderTriggerContext =
createContext<ImageUploaderTriggerContextType>(null);

View File

@@ -23,6 +23,6 @@ const serializationDenylist: {
}; };
export const serialize: SerializeFunction = (data, key) => { export const serialize: SerializeFunction = (data, key) => {
const result = omit(data, serializationDenylist[key] ?? []); const result = omit(data, serializationDenylist[key]);
return JSON.stringify(result); return JSON.stringify(result);
}; };
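The `?? []` added above guards the case where a persisted slice has no denylist entry, so lodash omit receives an explicit empty path list rather than undefined. A self-contained sketch, with slice names as assumptions:

// Sketch: serialize a slice for persistence, omitting denylisted keys;
// slices without a denylist entry are serialized whole.
import { omit } from 'lodash-es';

const serializationDenylist: Record<string, string[]> = {
  gallery: ['selection'],
};

const serialize = (data: object, key: string): string => {
  const result = omit(data, serializationDenylist[key] ?? []);
  return JSON.stringify(result);
};

// e.g. serialize({ selection: [], limit: 20 }, 'gallery') -> '{"limit":20}'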

View File

@@ -27,8 +27,7 @@ import {
addImageDeletedFulfilledListener, addImageDeletedFulfilledListener,
addImageDeletedPendingListener, addImageDeletedPendingListener,
addImageDeletedRejectedListener, addImageDeletedRejectedListener,
addRequestedSingleImageDeletionListener, addRequestedImageDeletionListener,
addRequestedMultipleImageDeletionListener,
} from './listeners/imageDeleted'; } from './listeners/imageDeleted';
import { addImageDroppedListener } from './listeners/imageDropped'; import { addImageDroppedListener } from './listeners/imageDropped';
import { import {
@@ -112,8 +111,7 @@ addImageUploadedRejectedListener();
addInitialImageSelectedListener(); addInitialImageSelectedListener();
// Image deleted // Image deleted
addRequestedSingleImageDeletionListener(); addRequestedImageDeletionListener();
addRequestedMultipleImageDeletionListener();
addImageDeletedPendingListener(); addImageDeletedPendingListener();
addImageDeletedFulfilledListener(); addImageDeletedFulfilledListener();
addImageDeletedRejectedListener(); addImageDeletedRejectedListener();

View File

@@ -1,10 +1,12 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import { imageSelected } from 'features/gallery/store/gallerySlice'; import { imageSelected } from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types'; import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { imagesApi } from 'services/api/endpoints/images'; import {
ImageCache,
getListImagesUrl,
imagesApi,
} from 'services/api/endpoints/images';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { getListImagesUrl, imagesAdapter } from 'services/api/util';
import { ImageCache } from 'services/api/types';
export const appStarted = createAction('app/appStarted'); export const appStarted = createAction('app/appStarted');
@@ -32,8 +34,7 @@ export const addFirstListImagesListener = () => {
if (data.ids.length > 0) { if (data.ids.length > 0) {
// Select the first image // Select the first image
const firstImage = imagesAdapter.getSelectors().selectAll(data)[0]; dispatch(imageSelected(data.ids[0] as string));
dispatch(imageSelected(firstImage ?? null));
} }
}, },
}); });

View File

@@ -18,9 +18,7 @@ export const addAppConfigReceivedListener = () => {
const infillMethod = getState().generation.infillMethod; const infillMethod = getState().generation.infillMethod;
if (!infill_methods.includes(infillMethod)) { if (!infill_methods.includes(infillMethod)) {
// if there is no infill method, set it to the first one dispatch(setInfillMethod(infill_methods[0]));
// if there is no first one... god help us
dispatch(setInfillMethod(infill_methods[0] as string));
} }
if (!nsfw_methods.includes('nsfw_checker')) { if (!nsfw_methods.includes('nsfw_checker')) {

View File

@@ -1,14 +1,14 @@
import { resetCanvas } from 'features/canvas/store/canvasSlice'; import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice'; import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
import { getImageUsage } from 'features/deleteImageModal/store/selectors'; import { getImageUsage } from 'features/imageDeletion/store/imageDeletionSelectors';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice'; import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { boardsApi } from '../../../../../services/api/endpoints/boards';
export const addDeleteBoardAndImagesFulfilledListener = () => { export const addDeleteBoardAndImagesFulfilledListener = () => {
startAppListening({ startAppListening({
matcher: imagesApi.endpoints.deleteBoardAndImages.matchFulfilled, matcher: boardsApi.endpoints.deleteBoardAndImages.matchFulfilled,
effect: async (action, { dispatch, getState }) => { effect: async (action, { dispatch, getState }) => {
const { deleted_images } = action.payload; const { deleted_images } = action.payload;

View File

@@ -10,7 +10,6 @@ import {
} from 'features/gallery/store/types'; } from 'features/gallery/store/types';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { imagesSelectors } from 'services/api/util';
export const addBoardIdSelectedListener = () => { export const addBoardIdSelectedListener = () => {
startAppListening({ startAppListening({
@@ -53,9 +52,8 @@ export const addBoardIdSelectedListener = () => {
queryArgs queryArgs
)(getState()); )(getState());
if (boardImagesData) { if (boardImagesData?.ids.length) {
const firstImage = imagesSelectors.selectAll(boardImagesData)[0]; dispatch(imageSelected((boardImagesData.ids[0] as string) ?? null));
dispatch(imageSelected(firstImage ?? null));
} else { } else {
// board has no images - deselect // board has no images - deselect
dispatch(imageSelected(null)); dispatch(imageSelected(null));

View File

@@ -26,8 +26,6 @@ export const addCanvasSavedToGalleryListener = () => {
return; return;
} }
const { autoAddBoardId } = state.gallery;
dispatch( dispatch(
imagesApi.endpoints.uploadImage.initiate({ imagesApi.endpoints.uploadImage.initiate({
file: new File([blob], 'savedCanvas.png', { file: new File([blob], 'savedCanvas.png', {
@@ -35,7 +33,7 @@ export const addCanvasSavedToGalleryListener = () => {
}), }),
image_category: 'general', image_category: 'general',
is_intermediate: false, is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId, board_id: state.gallery.autoAddBoardId,
crop_visible: true, crop_visible: true,
postUploadAction: { postUploadAction: {
type: 'TOAST', type: 'TOAST',

View File

@@ -31,20 +31,15 @@ const predicate: AnyListenerPredicate<RootState> = (
// do not process if the user just disabled auto-config // do not process if the user just disabled auto-config
if ( if (
prevState.controlNet.controlNets[action.payload.controlNetId] prevState.controlNet.controlNets[action.payload.controlNetId]
?.shouldAutoConfig === true .shouldAutoConfig === true
) { ) {
return false; return false;
} }
} }
const cn = state.controlNet.controlNets[action.payload.controlNetId]; const { controlImage, processorType, shouldAutoConfig } =
state.controlNet.controlNets[action.payload.controlNetId];
if (!cn) {
// something is wrong, the controlNet should exist
return false;
}
const { controlImage, processorType, shouldAutoConfig } = cn;
if (controlNetModelChanged.match(action) && !shouldAutoConfig) { if (controlNetModelChanged.match(action) && !shouldAutoConfig) {
// do not process if the action is a model change but the processor settings are dirty // do not process if the action is a model change but the processor settings are dirty
return false; return false;

View File

@@ -17,7 +17,7 @@ export const addControlNetImageProcessedListener = () => {
const { controlNetId } = action.payload; const { controlNetId } = action.payload;
const controlNet = getState().controlNet.controlNets[controlNetId]; const controlNet = getState().controlNet.controlNets[controlNetId];
if (!controlNet?.controlImage) { if (!controlNet.controlImage) {
log.error('Unable to process ControlNet image'); log.error('Unable to process ControlNet image');
return; return;
} }

View File

@@ -1,72 +1,57 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { resetCanvas } from 'features/canvas/store/canvasSlice'; import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice'; import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors'; import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice'; import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imageDeletionConfirmed } from 'features/imageDeletion/store/actions';
import { isModalOpenChanged } from 'features/imageDeletion/store/imageDeletionSlice';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice'; import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { clamp } from 'lodash-es'; import { clamp } from 'lodash-es';
import { api } from 'services/api'; import { api } from 'services/api';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
import { imagesAdapter } from 'services/api/util';
import { startAppListening } from '..'; import { startAppListening } from '..';
export const addRequestedSingleImageDeletionListener = () => { /**
* Called when the user requests an image deletion
*/
export const addRequestedImageDeletionListener = () => {
startAppListening({ startAppListening({
actionCreator: imageDeletionConfirmed, actionCreator: imageDeletionConfirmed,
effect: async (action, { dispatch, getState, condition }) => { effect: async (action, { dispatch, getState, condition }) => {
const { imageDTOs, imagesUsage } = action.payload; const { imageDTO, imageUsage } = action.payload;
if (imageDTOs.length !== 1 || imagesUsage.length !== 1) {
// handle multiples in separate listener
return;
}
const imageDTO = imageDTOs[0];
const imageUsage = imagesUsage[0];
if (!imageDTO || !imageUsage) {
// satisfy noUncheckedIndexedAccess
return;
}
dispatch(isModalOpenChanged(false)); dispatch(isModalOpenChanged(false));
const state = getState();
const lastSelectedImage =
state.gallery.selection[state.gallery.selection.length - 1]?.image_name;
if (imageDTO && imageDTO?.image_name === lastSelectedImage) {
const { image_name } = imageDTO; const { image_name } = imageDTO;
const state = getState();
const lastSelectedImage =
state.gallery.selection[state.gallery.selection.length - 1];
if (lastSelectedImage === image_name) {
const baseQueryArgs = selectListImagesBaseQueryArgs(state); const baseQueryArgs = selectListImagesBaseQueryArgs(state);
const { data } = const { data } =
imagesApi.endpoints.listImages.select(baseQueryArgs)(state); imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
const cachedImageDTOs = data const ids = data?.ids ?? [];
? imagesAdapter.getSelectors().selectAll(data)
: [];
const deletedImageIndex = cachedImageDTOs.findIndex( const deletedImageIndex = ids.findIndex(
(i) => i.image_name === image_name (result) => result.toString() === image_name
); );
const filteredImageDTOs = cachedImageDTOs.filter( const filteredIds = ids.filter((id) => id.toString() !== image_name);
(i) => i.image_name !== image_name
);
const newSelectedImageIndex = clamp( const newSelectedImageIndex = clamp(
deletedImageIndex, deletedImageIndex,
0, 0,
filteredImageDTOs.length - 1 filteredIds.length - 1
); );
const newSelectedImageDTO = filteredImageDTOs[newSelectedImageIndex]; const newSelectedImageId = filteredIds[newSelectedImageIndex];
if (newSelectedImageDTO) { if (newSelectedImageId) {
dispatch(imageSelected(newSelectedImageDTO)); dispatch(imageSelected(newSelectedImageId as string));
} else { } else {
dispatch(imageSelected(null)); dispatch(imageSelected(null));
} }
@@ -112,66 +97,6 @@ export const addRequestedSingleImageDeletionListener = () => {
}); });
}; };
/**
* Called when the user requests an image deletion
*/
export const addRequestedMultipleImageDeletionListener = () => {
startAppListening({
actionCreator: imageDeletionConfirmed,
effect: async (action, { dispatch, getState }) => {
const { imageDTOs, imagesUsage } = action.payload;
if (imageDTOs.length < 1 || imagesUsage.length < 1) {
// handle singles in separate listener
return;
}
try {
// Delete from server
await dispatch(
imagesApi.endpoints.deleteImages.initiate({ imageDTOs })
).unwrap();
const state = getState();
const baseQueryArgs = selectListImagesBaseQueryArgs(state);
const { data } =
imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
const newSelectedImageDTO = data
? imagesAdapter.getSelectors().selectAll(data)[0]
: undefined;
if (newSelectedImageDTO) {
dispatch(imageSelected(newSelectedImageDTO));
} else {
dispatch(imageSelected(null));
}
dispatch(isModalOpenChanged(false));
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
if (imagesUsage.some((i) => i.isCanvasImage)) {
dispatch(resetCanvas());
}
if (imagesUsage.some((i) => i.isControlNetImage)) {
dispatch(controlNetReset());
}
if (imagesUsage.some((i) => i.isInitialImage)) {
dispatch(clearInitialImage());
}
if (imagesUsage.some((i) => i.isNodesImage)) {
dispatch(nodeEditorReset());
}
} catch {
// no-op
}
},
});
};
/** /**
* Called when the actual delete request is sent to the server * Called when the actual delete request is sent to the server
*/ */
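The single-image listener above picks the next selection by clamping the deleted image's index into the filtered list, so the neighbour at the same position (or the new last image) becomes selected. A distilled sketch of that logic, with a simplified DTO type:

// Sketch: choose the image to select after a deletion.
import { clamp } from 'lodash-es';

type ImageDTO = { image_name: string };

const selectAfterDelete = (
  images: ImageDTO[],
  deletedName: string
): ImageDTO | null => {
  const deletedIndex = images.findIndex((i) => i.image_name === deletedName);
  const remaining = images.filter((i) => i.image_name !== deletedName);
  const nextIndex = clamp(deletedIndex, 0, remaining.length - 1);
  return remaining[nextIndex] ?? null;
};

// e.g. deleting 'b' from [a, b, c] selects 'c'; deleting 'c' selects 'b'.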

View File

@@ -6,7 +6,10 @@ import {
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { imageSelected } from 'features/gallery/store/gallerySlice'; import {
imageSelected,
imagesAddedToBatch,
} from 'features/gallery/store/gallerySlice';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
@@ -24,32 +27,19 @@ export const addImageDroppedListener = () => {
const log = logger('images'); const log = logger('images');
const { activeData, overData } = action.payload; const { activeData, overData } = action.payload;
if (activeData.payloadType === 'IMAGE_DTO') { log.debug({ activeData, overData }, 'Image or selection dropped');
log.debug({ activeData, overData }, 'Image dropped');
} else if (activeData.payloadType === 'IMAGE_DTOS') {
log.debug(
{ activeData, overData },
`Images (${activeData.payload.imageDTOs.length}) dropped`
);
} else {
log.debug({ activeData, overData }, `Unknown payload dropped`);
}
/** // set current image
* Image dropped on current image
*/
if ( if (
overData.actionType === 'SET_CURRENT_IMAGE' && overData.actionType === 'SET_CURRENT_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' && activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO activeData.payload.imageDTO
) { ) {
dispatch(imageSelected(activeData.payload.imageDTO)); dispatch(imageSelected(activeData.payload.imageDTO.image_name));
return; return;
} }
/** // set initial image
* Image dropped on initial image
*/
if ( if (
overData.actionType === 'SET_INITIAL_IMAGE' && overData.actionType === 'SET_INITIAL_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' && activeData.payloadType === 'IMAGE_DTO' &&
@@ -59,9 +49,27 @@ export const addImageDroppedListener = () => {
return; return;
} }
/** // add image to batch
* Image dropped on ControlNet if (
*/ overData.actionType === 'ADD_TO_BATCH' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
dispatch(imagesAddedToBatch([activeData.payload.imageDTO.image_name]));
return;
}
// add multiple images to batch
if (
overData.actionType === 'ADD_TO_BATCH' &&
activeData.payloadType === 'IMAGE_NAMES'
) {
dispatch(imagesAddedToBatch(activeData.payload.image_names));
return;
}
// set control image
if ( if (
overData.actionType === 'SET_CONTROLNET_IMAGE' && overData.actionType === 'SET_CONTROLNET_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' && activeData.payloadType === 'IMAGE_DTO' &&
@@ -77,9 +85,7 @@ export const addImageDroppedListener = () => {
return; return;
} }
/** // set canvas image
* Image dropped on Canvas
*/
if ( if (
overData.actionType === 'SET_CANVAS_INITIAL_IMAGE' && overData.actionType === 'SET_CANVAS_INITIAL_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' && activeData.payloadType === 'IMAGE_DTO' &&
@@ -89,9 +95,7 @@ export const addImageDroppedListener = () => {
return; return;
} }
/** // set nodes image
* Image dropped on node image field
*/
if ( if (
overData.actionType === 'SET_NODES_IMAGE' && overData.actionType === 'SET_NODES_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' && activeData.payloadType === 'IMAGE_DTO' &&
@@ -108,36 +112,61 @@ export const addImageDroppedListener = () => {
return; return;
} }
/** // set multiple nodes images (single image handler)
* TODO if (
* Image selection dropped on node image collection field overData.actionType === 'SET_MULTI_NODES_IMAGE' &&
*/ activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { fieldName, nodeId } = overData.context;
dispatch(
fieldValueChanged({
nodeId,
fieldName,
value: [activeData.payload.imageDTO],
})
);
return;
}
// // set multiple nodes images (multiple images handler)
// if ( // if (
// overData.actionType === 'SET_MULTI_NODES_IMAGE' && // overData.actionType === 'SET_MULTI_NODES_IMAGE' &&
// activeData.payloadType === 'IMAGE_DTO' && // activeData.payloadType === 'IMAGE_NAMES'
// activeData.payload.imageDTO
// ) { // ) {
// const { fieldName, nodeId } = overData.context; // const { fieldName, nodeId } = overData.context;
// dispatch( // dispatch(
// fieldValueChanged({ // imageCollectionFieldValueChanged({
// nodeId, // nodeId,
// fieldName, // fieldName,
// value: [activeData.payload.imageDTO], // value: activeData.payload.image_names.map((image_name) => ({
// image_name,
// })),
// }) // })
// ); // );
// return; // return;
// } // }
/** // add image to board
* Image dropped on user board
*/
if ( if (
overData.actionType === 'ADD_TO_BOARD' && overData.actionType === 'MOVE_BOARD' &&
activeData.payloadType === 'IMAGE_DTO' && activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO activeData.payload.imageDTO
) { ) {
const { imageDTO } = activeData.payload; const { imageDTO } = activeData.payload;
const { boardId } = overData.context; const { boardId } = overData.context;
// image was dropped on the "NoBoardBoard"
if (!boardId) {
dispatch(
imagesApi.endpoints.removeImageFromBoard.initiate({
imageDTO,
})
);
return;
}
// image was dropped on a user board
dispatch( dispatch(
imagesApi.endpoints.addImageToBoard.initiate({ imagesApi.endpoints.addImageToBoard.initiate({
imageDTO, imageDTO,
@@ -147,58 +176,67 @@ export const addImageDroppedListener = () => {
return; return;
} }
/** // // add gallery selection to board
* Image dropped on 'none' board // if (
*/ // overData.actionType === 'MOVE_BOARD' &&
if ( // activeData.payloadType === 'IMAGE_NAMES' &&
overData.actionType === 'REMOVE_FROM_BOARD' && // overData.context.boardId
activeData.payloadType === 'IMAGE_DTO' && // ) {
activeData.payload.imageDTO // console.log('adding gallery selection to board');
) { // const board_id = overData.context.boardId;
const { imageDTO } = activeData.payload; // dispatch(
dispatch( // boardImagesApi.endpoints.addManyBoardImages.initiate({
imagesApi.endpoints.removeImageFromBoard.initiate({ // board_id,
imageDTO, // image_names: activeData.payload.image_names,
}) // })
); // );
return; // return;
} // }
/** // // remove gallery selection from board
* Multiple images dropped on user board // if (
*/ // overData.actionType === 'MOVE_BOARD' &&
if ( // activeData.payloadType === 'IMAGE_NAMES' &&
overData.actionType === 'ADD_TO_BOARD' && // overData.context.boardId === null
activeData.payloadType === 'IMAGE_DTOS' && // ) {
activeData.payload.imageDTOs // console.log('removing gallery selection to board');
) { // dispatch(
const { imageDTOs } = activeData.payload; // boardImagesApi.endpoints.deleteManyBoardImages.initiate({
const { boardId } = overData.context; // image_names: activeData.payload.image_names,
dispatch( // })
imagesApi.endpoints.addImagesToBoard.initiate({ // );
imageDTOs, // return;
board_id: boardId, // }
})
);
return;
}
/** // // add batch selection to board
* Multiple images dropped on 'none' board // if (
*/ // overData.actionType === 'MOVE_BOARD' &&
if ( // activeData.payloadType === 'IMAGE_NAMES' &&
overData.actionType === 'REMOVE_FROM_BOARD' && // overData.context.boardId
activeData.payloadType === 'IMAGE_DTOS' && // ) {
activeData.payload.imageDTOs // const board_id = overData.context.boardId;
) { // dispatch(
const { imageDTOs } = activeData.payload; // boardImagesApi.endpoints.addManyBoardImages.initiate({
dispatch( // board_id,
imagesApi.endpoints.removeImagesFromBoard.initiate({ // image_names: activeData.payload.image_names,
imageDTOs, // })
}) // );
); // return;
return; // }
}
// // remove batch selection from board
// if (
// overData.actionType === 'MOVE_BOARD' &&
// activeData.payloadType === 'IMAGE_NAMES' &&
// overData.context.boardId === null
// ) {
// dispatch(
// boardImagesApi.endpoints.deleteManyBoardImages.initiate({
// image_names: activeData.payload.image_names,
// })
// );
// return;
// }
}, },
}); });
}; };
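Each branch in this listener guards on the (actionType, payloadType) pair before dispatching, which is what keeps the handlers type-safe. A compact sketch of that routing shape, with simplified stand-in types:

// Sketch: route a drop by checking both discriminants before acting.
type DropTarget =
  | { actionType: 'SET_CURRENT_IMAGE' }
  | { actionType: 'ADD_TO_BOARD'; boardId: string };

type DragPayload =
  | { payloadType: 'IMAGE_DTO'; imageName: string }
  | { payloadType: 'IMAGE_DTOS'; imageNames: string[] };

const routeDrop = (over: DropTarget, active: DragPayload): string => {
  if (
    over.actionType === 'SET_CURRENT_IMAGE' &&
    active.payloadType === 'IMAGE_DTO'
  ) {
    return `select ${active.imageName}`;
  }
  if (
    over.actionType === 'ADD_TO_BOARD' &&
    active.payloadType === 'IMAGE_DTOS'
  ) {
    return `add ${active.imageNames.length} images to board ${over.boardId}`;
  }
  return 'ignored';
};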

View File

@@ -1,32 +1,37 @@
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions'; import { imageDeletionConfirmed } from 'features/imageDeletion/store/actions';
import { selectImageUsage } from 'features/deleteImageModal/store/selectors'; import { selectImageUsage } from 'features/imageDeletion/store/imageDeletionSelectors';
import { import {
imagesToDeleteSelected, imageToDeleteSelected,
isModalOpenChanged, isModalOpenChanged,
} from 'features/deleteImageModal/store/slice'; } from 'features/imageDeletion/store/imageDeletionSlice';
import { startAppListening } from '..'; import { startAppListening } from '..';
export const addImageToDeleteSelectedListener = () => { export const addImageToDeleteSelectedListener = () => {
startAppListening({ startAppListening({
actionCreator: imagesToDeleteSelected, actionCreator: imageToDeleteSelected,
effect: async (action, { dispatch, getState }) => { effect: async (action, { dispatch, getState }) => {
const imageDTOs = action.payload; const imageDTO = action.payload;
const state = getState(); const state = getState();
const { shouldConfirmOnDelete } = state.system; const { shouldConfirmOnDelete } = state.system;
const imagesUsage = selectImageUsage(getState()); const imageUsage = selectImageUsage(getState());
if (!imageUsage) {
// should never happen
return;
}
const isImageInUse = const isImageInUse =
imagesUsage.some((i) => i.isCanvasImage) || imageUsage.isCanvasImage ||
imagesUsage.some((i) => i.isInitialImage) || imageUsage.isInitialImage ||
imagesUsage.some((i) => i.isControlNetImage) || imageUsage.isControlNetImage ||
imagesUsage.some((i) => i.isNodesImage); imageUsage.isNodesImage;
if (shouldConfirmOnDelete || isImageInUse) { if (shouldConfirmOnDelete || isImageInUse) {
dispatch(isModalOpenChanged(true)); dispatch(isModalOpenChanged(true));
return; return;
} }
dispatch(imageDeletionConfirmed({ imageDTOs, imagesUsage })); dispatch(imageDeletionConfirmed({ imageDTO, imageUsage }));
}, },
}); });
}; };

View File

@@ -2,13 +2,14 @@ import { UseToastOptions } from '@chakra-ui/react';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { imagesAddedToBatch } from 'features/gallery/store/gallerySlice';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { omit } from 'lodash-es';
import { boardsApi } from 'services/api/endpoints/boards'; import { boardsApi } from 'services/api/endpoints/boards';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { imagesApi } from '../../../../../services/api/endpoints/images'; import { imagesApi } from '../../../../../services/api/endpoints/images';
import { omit } from 'lodash-es';
const DEFAULT_UPLOADED_TOAST: UseToastOptions = { const DEFAULT_UPLOADED_TOAST: UseToastOptions = {
title: 'Image Uploaded', title: 'Image Uploaded',
@@ -40,7 +41,7 @@ export const addImageUploadedFulfilledListener = () => {
// default action - just upload and alert user // default action - just upload and alert user
if (postUploadAction?.type === 'TOAST') { if (postUploadAction?.type === 'TOAST') {
const { toastOptions } = postUploadAction; const { toastOptions } = postUploadAction;
if (!autoAddBoardId || autoAddBoardId === 'none') { if (!autoAddBoardId) {
dispatch(addToast({ ...DEFAULT_UPLOADED_TOAST, ...toastOptions })); dispatch(addToast({ ...DEFAULT_UPLOADED_TOAST, ...toastOptions }));
} else { } else {
// Add this image to the board // Add this image to the board
@@ -120,6 +121,17 @@ export const addImageUploadedFulfilledListener = () => {
); );
return; return;
} }
if (postUploadAction?.type === 'ADD_TO_BATCH') {
dispatch(imagesAddedToBatch([imageDTO.image_name]));
dispatch(
addToast({
...DEFAULT_UPLOADED_TOAST,
description: 'Added to batch',
})
);
return;
}
}, },
}); });
}; };

View File

@@ -15,7 +15,7 @@ import {
setShouldUseSDXLRefiner, setShouldUseSDXLRefiner,
} from 'features/sdxl/store/sdxlSlice'; } from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es'; import { forEach, some } from 'lodash-es';
import { modelsApi, vaeModelsAdapter } from 'services/api/endpoints/models'; import { modelsApi } from 'services/api/endpoints/models';
import { startAppListening } from '..'; import { startAppListening } from '..';
export const addModelsLoadedListener = () => { export const addModelsLoadedListener = () => {
@@ -144,9 +144,8 @@ export const addModelsLoadedListener = () => {
return; return;
} }
const firstModel = vaeModelsAdapter const firstModelId = action.payload.ids[0];
.getSelectors() const firstModel = action.payload.entities[firstModelId];
.selectAll(action.payload)[0];
if (!firstModel) { if (!firstModel) {
// No custom VAEs loaded at all; use the default // No custom VAEs loaded at all; use the default

View File

@@ -8,10 +8,9 @@ import {
} from 'features/gallery/store/gallerySlice'; } from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types'; import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { progressImageSet } from 'features/system/store/systemSlice'; import { progressImageSet } from 'features/system/store/systemSlice';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesAdapter, imagesApi } from 'services/api/endpoints/images';
import { isImageOutput } from 'services/api/guards'; import { isImageOutput } from 'services/api/guards';
import { sessionCanceled } from 'services/api/thunks/session'; import { sessionCanceled } from 'services/api/thunks/session';
import { imagesAdapter } from 'services/api/util';
import { import {
appSocketInvocationComplete, appSocketInvocationComplete,
socketInvocationComplete, socketInvocationComplete,
@@ -68,7 +67,7 @@ export const addInvocationCompleteEventListener = () => {
*/ */
const { autoAddBoardId } = gallery; const { autoAddBoardId } = gallery;
if (autoAddBoardId && autoAddBoardId !== 'none') { if (autoAddBoardId) {
dispatch( dispatch(
imagesApi.endpoints.addImageToBoard.initiate({ imagesApi.endpoints.addImageToBoard.initiate({
board_id: autoAddBoardId, board_id: autoAddBoardId,
@@ -84,7 +83,10 @@ export const addInvocationCompleteEventListener = () => {
categories: IMAGE_CATEGORIES, categories: IMAGE_CATEGORIES,
}, },
(draft) => { (draft) => {
imagesAdapter.addOne(draft, imageDTO); const oldTotal = draft.total;
const newState = imagesAdapter.addOne(draft, imageDTO);
const delta = newState.total - oldTotal;
draft.total = draft.total + delta;
} }
) )
); );
@@ -92,8 +94,8 @@ export const addInvocationCompleteEventListener = () => {
dispatch( dispatch(
imagesApi.util.invalidateTags([ imagesApi.util.invalidateTags([
{ type: 'BoardImagesTotal', id: autoAddBoardId }, { type: 'BoardImagesTotal', id: autoAddBoardId ?? 'none' },
{ type: 'BoardAssetsTotal', id: autoAddBoardId }, { type: 'BoardAssetsTotal', id: autoAddBoardId ?? 'none' },
]) ])
); );
@@ -108,7 +110,7 @@ export const addInvocationCompleteEventListener = () => {
} else if (!autoAddBoardId) { } else if (!autoAddBoardId) {
dispatch(galleryViewChanged('images')); dispatch(galleryViewChanged('images'));
} }
dispatch(imageSelected(imageDTO)); dispatch(imageSelected(imageDTO.image_name));
} }
} }
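The hunk above drops the manual total bookkeeping and lets the entity adapter patch the cached list directly. A minimal sketch of patching normalized list state with a Redux Toolkit entity adapter, with the state shape simplified:

// Sketch: normalized image-list state kept consistent via an adapter.
import { createEntityAdapter } from '@reduxjs/toolkit';

type ImageDTO = { image_name: string; created_at: string };

const imagesAdapter = createEntityAdapter<ImageDTO>({
  selectId: (image) => image.image_name,
  sortComparer: (a, b) => b.created_at.localeCompare(a.created_at),
});

// Inside imagesApi.util.updateQueryData('listImages', args, draft => ...)
// the draft has this shape, so addOne keeps ids/entities in sync:
const initial = imagesAdapter.getInitialState({ total: 0 });
const next = imagesAdapter.addOne(initial, {
  image_name: 'foo.png',
  created_at: '2023-08-01T00:00:00Z',
});
console.log(next.ids.length); // 1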

Some files were not shown because too many files have changed in this diff