Compare commits

..

1 Commits

Author SHA1 Message Date
Sergey Borisov
18956a6186 Expose seamless variables to node 2023-09-20 01:10:37 +03:00
496 changed files with 6055 additions and 20290 deletions

View File

@@ -167,23 +167,6 @@ and so you'll have access to the same python environment as the InvokeAI app.
This is _super_ handy.
#### Enabling Type-Checking with Pylance
We use python's typing system in InvokeAI. PR reviews will include checking that types are present and correct. We don't enforce types with `mypy` at this time, but that is on the horizon.
Using a code analysis tool to automatically type check your code (and types) is very important when writing with types. These tools provide immediate feedback in your editor when types are incorrect, and following their suggestions lead to fewer runtime bugs.
Pylance, installed at the beginning of this guide, is the de-facto python LSP (language server protocol). It provides type checking in the editor (among many other features). Once installed, you do need to enable type checking manually:
- Open a python file
- Look along the status bar in VSCode for `{ } Python`
- Click the `{ }`
- Turn type checking on - basic is fine
You'll now see red squiggly lines where type issues are detected. Hover your cursor over the indicated symbols to see what's wrong.
In 99% of cases when the type checker says there is a problem, there really is a problem, and you should take some time to understand and resolve what it is pointing out.
#### Debugging configs with `launch.json`
Debugging configs are managed in a `launch.json` file. Like most VSCode configs,
@@ -225,14 +208,6 @@ Now we can create the InvokeAI debugging configs:
"program": "scripts/invokeai-cli.py",
"justMyCode": true
},
{
"type": "chrome",
"request": "launch",
"name": "InvokeAI UI",
// You have to run the UI with `yarn dev` for this to work
"url": "http://localhost:5173",
"webRoot": "${workspaceFolder}/invokeai/frontend/web"
},
{
// Run tests
"name": "InvokeAI Test",
@@ -268,8 +243,7 @@ Now we can create the InvokeAI debugging configs:
You'll see these configs in the debugging configs drop down. Running them will
start InvokeAI with attached debugger, in the correct environment, and work just
like the normal app, though the UI debugger requires you to run the UI in dev
mode. See the [frontend guide](contribution_guides/contributingToFrontend.md) for setting that up.
like the normal app.
Enjoy debugging InvokeAI with ease (not that we have any bugs of course).

View File

@@ -38,9 +38,9 @@ There are two paths to making a development contribution:
If you need help, you can ask questions in the [#dev-chat](https://discord.com/channels/1020123559063990373/1049495067846524939) channel of the Discord.
For frontend related work, **@psychedelicious** is the best person to reach out to.
For frontend related work, **@pyschedelicious** is the best person to reach out to.
For backend related work, please reach out to **@blessedcoolant**, **@lstein**, **@StAlKeR7779** or **@psychedelicious**.
For backend related work, please reach out to **@blessedcoolant**, **@lstein**, **@StAlKeR7779** or **@pyschedelicious**.
## **What does the Code of Conduct mean for me?**

View File

@@ -10,4 +10,4 @@ When updating or creating documentation, please keep in mind InvokeAI is a tool
## Help & Questions
Please ping @imic or @hipsterusername in the [Discord](https://discord.com/channels/1020123559063990373/1049495067846524939) if you have any questions.
Please ping @imic1 or @hipsterusername in the [Discord](https://discord.com/channels/1020123559063990373/1049495067846524939) if you have any questions.

View File

@@ -159,7 +159,7 @@ groups in `invokeia.yaml`:
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
| `port` | `9090` | Network port number that the web server will listen on |
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
| `allow_credentials` | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
| `allow_credentials | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |

View File

@@ -1,11 +1,13 @@
---
title: Control Adapters
title: ControlNet
---
# :material-loupe: Control Adapters
# :material-loupe: ControlNet
## ControlNet
ControlNet
ControlNet is a powerful set of features developed by the open-source
community (notably, Stanford researcher
[**@ilyasviel**](https://github.com/lllyasviel)) that allows you to
@@ -18,7 +20,7 @@ towards generating images that better fit your desired style or
outcome.
#### How it works
### How it works
ControlNet works by analyzing an input image, pre-processing that
image to identify relevant information that can be interpreted by each
@@ -28,7 +30,7 @@ composition, or other aspects of the image to better achieve a
specific result.
#### Models
### Models
InvokeAI provides access to a series of ControlNet models that provide
different effects or styles in your generated images. Currently
@@ -94,8 +96,6 @@ A model that generates normal maps from input images, allowing for more realisti
**Image Segmentation**:
A model that divides input images into segments or regions, each of which corresponds to a different object or part of the image. (More details coming soon)
**QR Code Monster**:
A model that helps generate creative QR codes that still scan. Can also be used to create images with text, logos or shapes within them.
**Openpose**:
The OpenPose control model allows for the identification of the general pose of a character by pre-processing an existing image with a clear human structure. With advanced options, Openpose can also detect the face or hands in the image.
@@ -120,7 +120,7 @@ With Pix2Pix, you can input an image into the controlnet, and then "instruct" th
Each of these models can be adjusted and combined with other ControlNet models to achieve different results, giving you even more control over your image generation process.
### Using ControlNet
## Using ControlNet
To use ControlNet, you can simply select the desired model and adjust both the ControlNet and Pre-processor settings to achieve the desired result. You can also use multiple ControlNet models at the same time, allowing you to achieve even more complex effects or styles in your generated images.
@@ -132,31 +132,3 @@ Weight - Strength of the Controlnet model applied to the generation for the sect
Start/End - 0 represents the start of the generation, 1 represents the end. The Start/end setting controls what steps during the generation process have the ControlNet applied.
Additionally, each ControlNet section can be expanded in order to manipulate settings for the image pre-processor that adjusts your uploaded image before using it in when you Invoke.
## IP-Adapter
[IP-Adapter](https://ip-adapter.github.io) is a tooling that allows for image prompt capabilities with text-to-image diffusion models. IP-Adapter works by analyzing the given image prompt to extract features, then passing those features to the UNet along with any other conditioning provided.
![IP-Adapter + T2I](https://github.com/tencent-ailab/IP-Adapter/raw/main/assets/demo/ip_adpter_plus_multi.jpg)
![IP-Adapter + IMG2IMG](https://github.com/tencent-ailab/IP-Adapter/blob/main/assets/demo/image-to-image.jpg)
#### Installation
There are several ways to install IP-Adapter models with an existing InvokeAI installation:
1. Through the command line interface launched from the invoke.sh / invoke.bat scripts, option [5] to download models.
2. Through the Model Manager UI with models from the *Tools* section of [www.models.invoke.ai](www.models.invoke.ai). To do this, copy the repo ID from the desired model page, and paste it in the Add Model field of the model manager. **Note** Both the IP-Adapter and the Image Encoder must be installed for IP-Adapter to work. For example, the [SD 1.5 IP-Adapter](https://models.invoke.ai/InvokeAI/ip_adapter_plus_sd15) and [SD1.5 Image Encoder](https://models.invoke.ai/InvokeAI/ip_adapter_sd_image_encoder) must be installed to use IP-Adapter with SD1.5 based models.
3. **Advanced -- Not recommended ** Manually downloading the IP-Adapter and Image Encoder files - Image Encoder folders shouid be placed in the `models\any\clip_vision` folders. IP Adapter Model folders should be placed in the relevant `ip-adapter` folder of relevant base model folder of Invoke root directory. For example, for the SDXL IP-Adapter, files should be added to the `model/sdxl/ip_adapter/` folder.
#### Using IP-Adapter
IP-Adapter can be used by navigating to the *Control Adapters* options and enabling IP-Adapter.
IP-Adapter requires an image to be used as the Image Prompt. It can also be used in conjunction with text prompts, Image-to-Image, Inpainting, Outpainting, ControlNets and LoRAs.
Each IP-Adapter has two settings that are applied to the IP-Adapter:
* Weight - Strength of the IP-Adapter model applied to the generation for the section, defined by start/end
* Start/End - 0 represents the start of the generation, 1 represents the end. The Start/end setting controls what steps during the generation process have the IP-Adapter applied.

View File

@@ -1,336 +0,0 @@
---
title: Command-line Utilities
---
# :material-file-document: Utilities
# Command-line Utilities
InvokeAI comes with several scripts that are accessible via the
command line. To access these commands, start the "developer's
console" from the launcher (`invoke.bat` menu item [8]). Users who are
familiar with Python can alternatively activate InvokeAI's virtual
environment (typically, but not necessarily `invokeai/.venv`).
In the developer's console, type the script's name to run it. To get a
synopsis of what a utility does and the command-line arguments it
accepts, pass it the `-h` argument, e.g.
```bash
invokeai-merge -h
```
## **invokeai-web**
This script launches the web server and is effectively identical to
selecting option [1] in the launcher. An advantage of launching the
server from the command line is that you can override any setting
configuration option in `invokeai.yaml` using like-named command-line
arguments. For example, to temporarily change the size of the RAM
cache to 7 GB, you can launch as follows:
```bash
invokeai-web --ram 7
```
## **invokeai-merge**
This is the model merge script, the same as launcher option [4]. Call
it with the `--gui` command-line argument to start the interactive
console-based GUI. Alternatively, you can run it non-interactively
using command-line arguments as illustrated in the example below which
merges models named `stable-diffusion-1.5` and `inkdiffusion` into a new model named
`my_new_model`:
```bash
invokeai-merge --force --base-model sd-1 --models stable-diffusion-1.5 inkdiffusion --merged_model_name my_new_model
```
## **invokeai-ti**
This is the textual inversion training script that is run by launcher
option [3]. Call it with `--gui` to run the interactive console-based
front end. It can also be run non-interactively. It has about a
zillion arguments, but a typical training session can be launched
with:
```bash
invokeai-ti --model stable-diffusion-1.5 \
--placeholder_token 'jello' \
--learnable_property object \
--num_train_epochs 50 \
--train_data_dir /path/to/training/images \
--output_dir /path/to/trained/model
```
(Note that \\ is the Linux/Mac long-line continuation character. Use ^
in Windows).
## **invokeai-install**
This is the console-based model install script that is run by launcher
option [5]. If called without arguments, it will launch the
interactive console-based interface. It can also be used
non-interactively to list, add and remove models as shown by these
examples:
* This will download and install three models from CivitAI, HuggingFace,
and local disk:
```bash
invokeai-install --add https://civitai.com/api/download/models/161302 ^
gsdf/Counterfeit-V3.0 ^
D:\Models\merge_model_two.safetensors
```
(Note that ^ is the Windows long-line continuation character. Use \\ on
Linux/Mac).
* This will list installed models of type `main`:
```bash
invokeai-model-install --list-models main
```
* This will delete the models named `voxel-ish` and `realisticVision`:
```bash
invokeai-model-install --delete voxel-ish realisticVision
```
## **invokeai-configure**
This is the console-based configure script that ran when InvokeAI was
first installed. You can run it again at any time to change the
configuration, repair a broken install.
Called without any arguments, `invokeai-configure` enters interactive
mode with two screens. The first screen is a form that provides access
to most of InvokeAI's configuration options. The second screen lets
you download, add, and delete models interactively. When you exit the
second screen, the script will add any missing "support models"
needed for core functionality, and any selected "sd weights" which are
the model checkpoint/diffusers files.
This behavior can be changed via a series of command-line
arguments. Here are some of the useful ones:
* `invokeai-configure --skip-sd-weights --skip-support-models`
This will run just the configuration part of the utility, skipping
downloading of support models and stable diffusion weights.
* `invokeai-configure --yes`
This will run the configure script non-interactively. It will set the
configuration options to their default values, install/repair support
models, and download the "recommended" set of SD models.
* `invokeai-configure --yes --default_only`
This will run the configure script non-interactively. In contrast to
the previous command, it will only download the default SD model,
Stable Diffusion v1.5
* `invokeai-configure --yes --default_only --skip-sd-weights`
This is similar to the previous command, but will not download any
SD models at all. It is usually used to repair a broken install.
By default, `invokeai-configure` runs on the currently active InvokeAI
root folder. To run it against a different root, pass it the `--root
</path/to/root>` argument.
Lastly, you can use `invokeai-configure` to create a working root
directory entirely from scratch. Assuming you wish to make a root directory
named `InvokeAI-New`, run this command:
```bash
invokeai-configure --root InvokeAI-New --yes --default_only
```
This will create a minimally functional root directory. You can now
launch the web server against it with `invokeai-web --root InvokeAI-New`.
## **invokeai-update**
This is the interactive console-based script that is run by launcher
menu item [9] to update to a new version of InvokeAI. It takes no
command-line arguments.
## **invokeai-metadata**
This is a script which takes a list of InvokeAI-generated images and
outputs their metadata in the same JSON format that you get from the
`</>` button in the Web GUI. For example:
```bash
$ invokeai-metadata ffe2a115-b492-493c-afff-7679aa034b50.png
ffe2a115-b492-493c-afff-7679aa034b50.png:
{
"app_version": "3.1.0",
"cfg_scale": 8.0,
"clip_skip": 0,
"controlnets": [],
"generation_mode": "sdxl_txt2img",
"height": 1024,
"loras": [],
"model": {
"base_model": "sdxl",
"model_name": "stable-diffusion-xl-base-1.0",
"model_type": "main"
},
"negative_prompt": "",
"negative_style_prompt": "",
"positive_prompt": "military grade sushi dinner for shock troopers",
"positive_style_prompt": "",
"rand_device": "cpu",
"refiner_cfg_scale": 7.5,
"refiner_model": {
"base_model": "sdxl-refiner",
"model_name": "sd_xl_refiner_1.0",
"model_type": "main"
},
"refiner_negative_aesthetic_score": 2.5,
"refiner_positive_aesthetic_score": 6.0,
"refiner_scheduler": "euler",
"refiner_start": 0.8,
"refiner_steps": 20,
"scheduler": "euler",
"seed": 387129902,
"steps": 25,
"width": 1024
}
```
You may list multiple files on the command line.
## **invokeai-import-images**
InvokeAI uses a database to store information about images it
generated, and just copying the image files from one InvokeAI root
directory to another does not automatically import those images into
the destination's gallery. This script allows you to bulk import
images generated by one instance of InvokeAI into a gallery maintained
by another. It also works on images generated by older versions of
InvokeAI, going way back to version 1.
This script has an interactive mode only. The following example shows
it in action:
```bash
$ invokeai-import-images
===============================================================================
This script will import images generated by earlier versions of
InvokeAI into the currently installed root directory:
/home/XXXX/invokeai-main
If this is not what you want to do, type ctrl-C now to cancel.
===============================================================================
= Configuration & Settings
Found invokeai.yaml file at /home/XXXX/invokeai-main/invokeai.yaml:
Database : /home/XXXX/invokeai-main/databases/invokeai.db
Outputs : /home/XXXX/invokeai-main/outputs/images
Use these paths for import (yes) or choose different ones (no) [Yn]:
Inputs: Specify absolute path containing InvokeAI .png images to import: /home/XXXX/invokeai-2.3/outputs/images/
Include files from subfolders recursively [yN]?
Options for board selection for imported images:
1) Select an existing board name. (found 4)
2) Specify a board name to create/add to.
3) Create/add to board named 'IMPORT'.
4) Create/add to board named 'IMPORT' with the current datetime string appended (.e.g IMPORT_20230919T203519Z).
5) Create/add to board named 'IMPORT' with a the original file app_version appended (.e.g IMPORT_2.2.5).
Specify desired board option: 3
===============================================================================
= Import Settings Confirmation
Database File Path : /home/XXXX/invokeai-main/databases/invokeai.db
Outputs/Images Directory : /home/XXXX/invokeai-main/outputs/images
Import Image Source Directory : /home/XXXX/invokeai-2.3/outputs/images/
Recurse Source SubDirectories : No
Count of .png file(s) found : 5785
Board name option specified : IMPORT
Database backup will be taken at : /home/XXXX/invokeai-main/databases/backup
Notes about the import process:
- Source image files will not be modified, only copied to the outputs directory.
- If the same file name already exists in the destination, the file will be skipped.
- If the same file name already has a record in the database, the file will be skipped.
- Invoke AI metadata tags will be updated/written into the imported copy only.
- On the imported copy, only Invoke AI known tags (latest and legacy) will be retained (dream, sd-metadata, invokeai, invokeai_metadata)
- A property 'imported_app_version' will be added to metadata that can be viewed in the UI's metadata viewer.
- The new 3.x InvokeAI outputs folder structure is flat so recursively found source imges will all be placed into the single outputs/images folder.
Do you wish to continue with the import [Yn] ?
Making DB Backup at /home/lstein/invokeai-main/databases/backup/backup-20230919T203519Z-invokeai.db...Done!
===============================================================================
Importing /home/XXXX/invokeai-2.3/outputs/images/17d09907-297d-4db3-a18a-60b337feac66.png
... (5785 more lines) ...
===============================================================================
= Import Complete - Elpased Time: 0.28 second(s)
Source File(s) : 5785
Total Imported : 5783
Skipped b/c file already exists on disk : 1
Skipped b/c file already exists in db : 0
Errors during import : 1
```
## **invokeai-db-maintenance**
This script helps maintain the integrity of your InvokeAI database by
finding and fixing three problems that can arise over time:
1. An image was manually deleted from the outputs directory, leaving a
dangling image record in the InvokeAI database. This will cause a
black image to appear in the gallery. This is an "orphaned database
image record." The script can fix this by running a "clean"
operation on the database, removing the orphaned entries.
2. An image is present in the outputs directory but there is no
corresponding entry in the database. This can happen when the image
is added manually to the outputs directory, or if a crash occurred
after the image was generated but before the database was
completely updated. The symptom is that the image is present in the
outputs folder but doesn't appear in the InvokeAI gallery. This is
called an "orphaned image file." The script can fix this problem by
running an "archive" operation in which orphaned files are moved
into a directory named `outputs/images-archive`. If you wish, you
can then run `invokeai-image-import` to reimport these images back
into the database.
3. The thumbnail for an image is missing, again causing a black
gallery thumbnail. This is fixed by running the "thumbnaiils"
operation, which simply regenerates and re-registers the missing
thumbnail.
You can find and fix all three of these problems in a single go by
executing this command:
```bash
invokeai-db-maintenance --operation all
```
Or you can run just the clean and thumbnail operations like this:
```bash
invokeai-db-maintenance -operation clean, thumbnail
```
If called without any arguments, the script will ask you which
operations you wish to perform.
## **invokeai-migrate3**
This script will migrate settings and models (but not images!) from an
InvokeAI v2.3 root folder to an InvokeAI 3.X folder. Call it with the
source and destination root folders like this:
```bash
invokeai-migrate3 --from ~/invokeai-2.3 --to invokeai-3.1.1
```
Both directories must previously have been properly created and
initialized by `invokeai-configure`. If you wish to migrate the images
contained in the older root as well, you can use the
`invokeai-image-migrate` script described earlier.
---
Copyright (c) 2023, Lincoln Stein and the InvokeAI Development Team

View File

@@ -51,9 +51,6 @@ Prevent InvokeAI from displaying unwanted racy images.
### * [Controlling Logging](LOGGING.md)
Control how InvokeAI logs status messages.
### * [Command-line Utilities](UTILITIES.md)
A list of the command-line utilities available with InvokeAI.
<!-- OUT OF DATE
### * [Miscellaneous](OTHER.md)
Run InvokeAI on Google Colab, generate images with repeating patterns,

View File

@@ -147,7 +147,6 @@ Mac and Linux machines, and runs on GPU cards with as little as 4 GB of RAM.
### InvokeAI Configuration
- [Guide to InvokeAI Runtime Settings](features/CONFIGURATION.md)
- [Database Maintenance and other Command Line Utilities](features/UTILITIES.md)
## :octicons-log-16: Important Changes Since Version 2.3

View File

@@ -296,18 +296,8 @@ code for InvokeAI. For this to work, you will need to install the
on your system, please see the [Git Installation
Guide](https://github.com/git-guides/install-git)
You will also need to install the [frontend development toolchain](https://github.com/invoke-ai/InvokeAI/blob/main/docs/contributing/contribution_guides/contributingToFrontend.md).
If you have a "normal" installation, you should create a totally separate virtual environment for the git-based installation, else the two may interfere.
> **Why do I need the frontend toolchain**?
>
> The InvokeAI project uses trunk-based development. That means our `main` branch is the development branch, and releases are tags on that branch. Because development is very active, we don't keep an updated build of the UI in `main` - we only build it for production releases.
>
> That means that between releases, to have a functioning application when running directly from the repo, you will need to run the UI in dev mode or build it regularly (any time the UI code changes).
1. Create a fork of the InvokeAI repository through the GitHub UI or [this link](https://github.com/invoke-ai/InvokeAI/fork)
2. From the command line, run this command:
1. From the command line, run this command:
```bash
git clone https://github.com/<your_github_username>/InvokeAI.git
```
@@ -315,10 +305,10 @@ If you have a "normal" installation, you should create a totally separate virtua
This will create a directory named `InvokeAI` and populate it with the
full source code from your fork of the InvokeAI repository.
3. Activate the InvokeAI virtual environment as per step (4) of the manual
2. Activate the InvokeAI virtual environment as per step (4) of the manual
installation protocol (important!)
4. Enter the InvokeAI repository directory and run one of these
3. Enter the InvokeAI repository directory and run one of these
commands, based on your GPU:
=== "CUDA (NVidia)"
@@ -344,15 +334,11 @@ installation protocol (important!)
Be sure to pass `-e` (for an editable install) and don't forget the
dot ("."). It is part of the command.
5. Install the [frontend toolchain](https://github.com/invoke-ai/InvokeAI/blob/main/docs/contributing/contribution_guides/contributingToFrontend.md) and do a production build of the UI as described.
6. You can now run `invokeai` and its related commands. The code will be
You can now run `invokeai` and its related commands. The code will be
read from the repository, so that you can edit the .py source files
and watch the code's behavior change.
When you pull in new changes to the repo, be sure to re-build the UI.
7. If you wish to contribute to the InvokeAI project, you are
4. If you wish to contribute to the InvokeAI project, you are
encouraged to establish a GitHub account and "fork"
https://github.com/invoke-ai/InvokeAI into your own copy of the
repository. You can then use GitHub functions to create and submit

View File

@@ -171,16 +171,3 @@ subfolders and organize them as you wish.
The location of the autoimport directories are controlled by settings
in `invokeai.yaml`. See [Configuration](../features/CONFIGURATION.md).
### Installing models that live in HuggingFace subfolders
On rare occasions you may need to install a diffusers-style model that
lives in a subfolder of a HuggingFace repo id. In this event, simply
add ":_subfolder-name_" to the end of the repo id. For example, if the
repo id is "monster-labs/control_v1p_sd15_qrcode_monster" and the model
you wish to fetch lives in a subfolder named "v2", then the repo id to
pass to the various model installers should be
```
monster-labs/control_v1p_sd15_qrcode_monster:v2
```

View File

@@ -4,12 +4,12 @@ The workflow editor is a blank canvas allowing for the use of individual functio
If you're not familiar with Diffusion, take a look at our [Diffusion Overview.](../help/diffusion.md) Understanding how diffusion works will enable you to more easily use the Workflow Editor and build workflows to suit your needs.
## Features
## UI Features
### Linear View
The Workflow Editor allows you to create a UI for your workflow, to make it easier to iterate on your generations.
To add an input to the Linear UI, right click on the input label and select "Add to Linear View".
To add an input to the Linear UI, right click on the input and select "Add to Linear View".
The Linear UI View will also be part of the saved workflow, allowing you share workflows and enable other to use them, regardless of complexity.
@@ -25,10 +25,6 @@ Any node or input field can be renamed in the workflow editor. If the input fiel
* Backspace/Delete to delete a node
* Shift+Click to drag and select multiple nodes
### Node Caching
Nodes have a "Use Cache" option in their footer. This allows for performance improvements by using the previously cached values during the workflow processing.
## Important Concepts

View File

@@ -8,7 +8,19 @@ To download a node, simply download the `.py` node file from the link and add it
To use a community workflow, download the the `.json` node graph file and load it into Invoke AI via the **Load Workflow** button in the Workflow Editor.
--------------------------------
## Community Nodes
### FaceTools
**Description:** FaceTools is a collection of nodes created to manipulate faces as you would in Unified Canvas. It includes FaceMask, FaceOff, and FacePlace. FaceMask autodetects a face in the image using MediaPipe and creates a mask from it. FaceOff similarly detects a face, then takes the face off of the image by adding a square bounding box around it and cropping/scaling it. FacePlace puts the bounded face image from FaceOff back onto the original image. Using these nodes with other inpainting node(s), you can put new faces on existing things, put new things around existing faces, and work closer with a face as a bounded image. Additionally, you can supply X and Y offset values to scale/change the shape of the mask for finer control on FaceMask and FaceOff. See GitHub repository below for usage examples.
**Node Link:** https://github.com/ymgenesis/FaceTools/
**FaceMask Output Examples**
![5cc8abce-53b0-487a-b891-3bf94dcc8960](https://github.com/invoke-ai/InvokeAI/assets/25252829/43f36d24-1429-4ab1-bd06-a4bedfe0955e)
![b920b710-1882-49a0-8d02-82dff2cca907](https://github.com/invoke-ai/InvokeAI/assets/25252829/7660c1ed-bf7d-4d0a-947f-1fc1679557ba)
![71a91805-fda5-481c-b380-264665703133](https://github.com/invoke-ai/InvokeAI/assets/25252829/f8f6a2ee-2b68-4482-87da-b90221d5c3e2)
--------------------------------
### Ideal Size
@@ -31,52 +43,6 @@ To use a community workflow, download the the `.json` node graph file and load i
**Node Link:** https://github.com/JPPhoto/image-picker-node
--------------------------------
### Thresholding
**Description:** This node generates masks for highlights, midtones, and shadows given an input image. You can optionally specify a blur for the lookup table used in making those masks from the source image.
**Node Link:** https://github.com/JPPhoto/thresholding-node
**Examples**
Input:
![image](https://github.com/invoke-ai/InvokeAI/assets/34005131/c88ada13-fb3d-484c-a4fe-947b44712632){: style="height:512px;width:512px"}
Highlights/Midtones/Shadows:
<img src="https://github.com/invoke-ai/InvokeAI/assets/34005131/727021c1-36ff-4ec8-90c8-105e00de986d" style="width: 30%" />
<img src="https://github.com/invoke-ai/InvokeAI/assets/34005131/0b721bfc-f051-404e-b905-2f16b824ddfe" style="width: 30%" />
<img src="https://github.com/invoke-ai/InvokeAI/assets/34005131/04c1297f-1c88-42b6-a7df-dd090b976286" style="width: 30%" />
Highlights/Midtones/Shadows (with LUT blur enabled):
<img src="https://github.com/invoke-ai/InvokeAI/assets/34005131/19aa718a-70c1-4668-8169-d68f4bd13771" style="width: 30%" />
<img src="https://github.com/invoke-ai/InvokeAI/assets/34005131/0a440e43-697f-4d17-82ee-f287467df0a5" style="width: 30%" />
<img src="https://github.com/invoke-ai/InvokeAI/assets/34005131/0701fd0f-2ca7-4fe2-8613-2b52547bafce" style="width: 30%" />
--------------------------------
### Halftone
**Description**: Halftone converts the source image to grayscale and then performs halftoning. CMYK Halftone converts the image to CMYK and applies a per-channel halftoning to make the source image look like a magazine or newspaper. For both nodes, you can specify angles and halftone dot spacing.
**Node Link:** https://github.com/JPPhoto/halftone-node
**Example**
Input:
![image](https://github.com/invoke-ai/InvokeAI/assets/34005131/fd5efb9f-4355-4409-a1c2-c1ca99e0cab4){: style="height:512px;width:512px"}
Halftone Output:
![image](https://github.com/invoke-ai/InvokeAI/assets/34005131/7e606f29-e68f-4d46-b3d5-97f799a4ec2f){: style="height:512px;width:512px"}
CMYK Halftone Output:
![image](https://github.com/invoke-ai/InvokeAI/assets/34005131/c59c578f-db8e-4d66-8c66-2851752d75ea){: style="height:512px;width:512px"}
--------------------------------
### Retroize
@@ -111,7 +77,7 @@ Generated Prompt: An enchanted weapon will be usable by any character regardless
**Example Node Graph:** https://github.com/helix4u/load_video_frame/blob/main/Example_Workflow.json
**Output Example:**
=======
![Example animation](https://github.com/helix4u/load_video_frame/blob/main/testmp4_embed_converted.gif)
[Full mp4 of Example Output test.mp4](https://github.com/helix4u/load_video_frame/blob/main/test.mp4)
@@ -155,6 +121,18 @@ To be imported, an .obj must use triangulated meshes, so make sure to enable tha
**Example Usage:**
![depth from obj usage graph](https://raw.githubusercontent.com/dwringer/depth-from-obj-node/main/depth_from_obj_usage.jpg)
--------------------------------
### Enhance Image (simple adjustments)
**Description:** Boost or reduce color saturation, contrast, brightness, sharpness, or invert colors of any image at any stage with this simple wrapper for pillow [PIL]'s ImageEnhance module.
Color inversion is toggled with a simple switch, while each of the four enhancer modes are activated by entering a value other than 1 in each corresponding input field. Values less than 1 will reduce the corresponding property, while values greater than 1 will enhance it.
**Node Link:** https://github.com/dwringer/image-enhance-node
**Example Usage:**
![enhance image usage graph](https://raw.githubusercontent.com/dwringer/image-enhance-node/main/image_enhance_usage.jpg)
--------------------------------
### Generative Grammar-Based Prompt Nodes
@@ -175,28 +153,16 @@ This includes 3 Nodes:
**Description:** This is a pack of nodes for composing masks and images, including a simple text mask creator and both image and latent offset nodes. The offsets wrap around, so these can be used in conjunction with the Seamless node to progressively generate centered on different parts of the seamless tiling.
This includes 15 Nodes:
- *Adjust Image Hue Plus* - Rotate the hue of an image in one of several different color spaces.
- *Blend Latents/Noise (Masked)* - Use a mask to blend part of one latents tensor [including Noise outputs] into another. Can be used to "renoise" sections during a multi-stage [masked] denoising process.
- *Enhance Image* - Boost or reduce color saturation, contrast, brightness, sharpness, or invert colors of any image at any stage with this simple wrapper for pillow [PIL]'s ImageEnhance module.
- *Equivalent Achromatic Lightness* - Calculates image lightness accounting for Helmholtz-Kohlrausch effect based on a method described by High, Green, and Nussbaum (2023).
- *Text to Mask (Clipseg)* - Input a prompt and an image to generate a mask representing areas of the image matched by the prompt.
- *Text to Mask Advanced (Clipseg)* - Output up to four prompt masks combined with logical "and", logical "or", or as separate channels of an RGBA image.
- *Image Layer Blend* - Perform a layered blend of two images using alpha compositing. Opacity of top layer is selectable, with optional mask and several different blend modes/color spaces.
This includes 4 Nodes:
- *Text Mask (simple 2D)* - create and position a white on black (or black on white) line of text using any font locally available to Invoke.
- *Image Compositor* - Take a subject from an image with a flat backdrop and layer it on another image using a chroma key or flood select background removal.
- *Image Dilate or Erode* - Dilate or expand a mask (or any image!). This is equivalent to an expand/contract operation.
- *Image Value Thresholds* - Clip an image to pure black/white beyond specified thresholds.
- *Offset Latents* - Offset a latents tensor in the vertical and/or horizontal dimensions, wrapping it around.
- *Offset Image* - Offset an image in the vertical and/or horizontal dimensions, wrapping it around.
- *Rotate/Flip Image* - Rotate an image in degrees clockwise/counterclockwise about its center, optionally resizing the image boundaries to fit, or flipping it about the vertical and/or horizontal axes.
- *Shadows/Highlights/Midtones* - Extract three masks (with adjustable hard or soft thresholds) representing shadows, midtones, and highlights regions of an image.
- *Text Mask (simple 2D)* - create and position a white on black (or black on white) line of text using any font locally available to Invoke.
**Node Link:** https://github.com/dwringer/composition-nodes
**Nodes and Output Examples:**
![composition nodes usage graph](https://raw.githubusercontent.com/dwringer/composition-nodes/main/composition_pack_overview.jpg)
**Example Usage:**
![composition nodes usage graph](https://raw.githubusercontent.com/dwringer/composition-nodes/main/composition_nodes_usage.jpg)
--------------------------------
### Size Stepper Nodes
@@ -230,70 +196,6 @@ Results after using the depth controlnet
--------------------------------
### Prompt Tools
**Description:** A set of InvokeAI nodes that add general prompt manipulation tools. These where written to accompany the PromptsFromFile node and other prompt generation nodes.
1. PromptJoin - Joins to prompts into one.
2. PromptReplace - performs a search and replace on a prompt. With the option of using regex.
3. PromptSplitNeg - splits a prompt into positive and negative using the old V2 method of [] for negative.
4. PromptToFile - saves a prompt or collection of prompts to a file. one per line. There is an append/overwrite option.
5. PTFieldsCollect - Converts image generation fields into a Json format string that can be passed to Prompt to file.
6. PTFieldsExpand - Takes Json string and converts it to individual generation parameters This can be fed from the Prompts to file node.
7. PromptJoinThree - Joins 3 prompt together.
8. PromptStrength - This take a string and float and outputs another string in the format of (string)strength like the weighted format of compel.
9. PromptStrengthCombine - This takes a collection of prompt strength strings and outputs a string in the .and() or .blend() format that can be fed into a proper prompt node.
See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/main/README.md
**Node Link:** https://github.com/skunkworxdark/Prompt-tools-nodes
--------------------------------
### XY Image to Grid and Images to Grids nodes
**Description:** Image to grid nodes and supporting tools.
1. "Images To Grids" node - Takes a collection of images and creates a grid(s) of images. If there are more images than the size of a single grid then mutilple grids will be created until it runs out of images.
2. "XYImage To Grid" node - Converts a collection of XYImages into a labeled Grid of images. The XYImages collection has to be built using the supporoting nodes. See example node setups for more details.
See full docs here: https://github.com/skunkworxdark/XYGrid_nodes/edit/main/README.md
**Node Link:** https://github.com/skunkworxdark/XYGrid_nodes
--------------------------------
### Image to Character Art Image Node's
**Description:** Group of nodes to convert an input image into ascii/unicode art Image
**Node Link:** https://github.com/mickr777/imagetoasciiimage
**Output Examples**
<img src="https://github.com/invoke-ai/InvokeAI/assets/115216705/8e061fcc-9a2c-4fa9-bcc7-c0f7b01e9056" width="300" />
<img src="https://github.com/mickr777/imagetoasciiimage/assets/115216705/3c4990eb-2f42-46b9-90f9-0088b939dc6a" width="300" /></br>
<img src="https://github.com/mickr777/imagetoasciiimage/assets/115216705/fee7f800-a4a8-41e2-a66b-c66e4343307e" width="300" />
<img src="https://github.com/mickr777/imagetoasciiimage/assets/115216705/1d9c1003-a45f-45c2-aac7-46470bb89330" width="300" />
--------------------------------
### Grid to Gif
**Description:** One node that turns a grid image into an image colletion, one node that turns an image collection into a gif
**Node Link:** https://github.com/mildmisery/invokeai-GridToGifNode/blob/main/GridToGif.py
**Example Node Graph:** https://github.com/mildmisery/invokeai-GridToGifNode/blob/main/Grid%20to%20Gif%20Example%20Workflow.json
**Output Examples**
<img src="https://raw.githubusercontent.com/mildmisery/invokeai-GridToGifNode/main/input.png" width="300" />
<img src="https://raw.githubusercontent.com/mildmisery/invokeai-GridToGifNode/main/output.gif" width="300" />
--------------------------------
### Example Node Template
**Description:** This node allows you to do super cool things with InvokeAI.

View File

@@ -1,6 +1,6 @@
# List of Default Nodes
The table below contains a list of the default nodes shipped with InvokeAI and their descriptions.
The table below contains a list of the default nodes shipped with InvokeAI and their descriptions.
| Node <img width=160 align="right"> | Function |
|: ---------------------------------- | :--------------------------------------------------------------------------------------|
@@ -17,12 +17,11 @@ The table below contains a list of the default nodes shipped with InvokeAI and t
|Conditioning Primitive | A conditioning tensor primitive value|
|Content Shuffle Processor | Applies content shuffle processing to image|
|ControlNet | Collects ControlNet info to pass to other nodes|
|OpenCV Inpaint | Simple inpaint using opencv.|
|Denoise Latents | Denoises noisy latents to decodable images|
|Divide Integers | Divides two numbers|
|Dynamic Prompt | Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator|
|[FaceMask](./detailedNodes/faceTools.md#facemask) | Generates masks for faces in an image to use with Inpainting|
|[FaceIdentifier](./detailedNodes/faceTools.md#faceidentifier) | Identifies and labels faces in an image|
|[FaceOff](./detailedNodes/faceTools.md#faceoff) | Creates a new image that is a scaled bounding box with a mask on the face for Inpainting|
|Upscale (RealESRGAN) | Upscales an image using RealESRGAN.|
|Float Math | Perform basic math operations on two floats|
|Float Primitive Collection | A collection of float primitive values|
|Float Primitive | A float primitive value|
@@ -77,7 +76,6 @@ The table below contains a list of the default nodes shipped with InvokeAI and t
|ONNX Prompt (Raw) | A node to process inputs and produce outputs. May use dependency injection in __init__ to receive providers.|
|ONNX Text to Latents | Generates latents from conditionings.|
|ONNX Model Loader | Loads a main model, outputting its submodels.|
|OpenCV Inpaint | Simple inpaint using opencv.|
|Openpose Processor | Applies Openpose processing to image|
|PIDI Processor | Applies PIDI processing to image|
|Prompts from File | Loads prompts from a text file|
@@ -99,6 +97,5 @@ The table below contains a list of the default nodes shipped with InvokeAI and t
|String Primitive | A string primitive value|
|Subtract Integers | Subtracts two numbers|
|Tile Resample Processor | Tile resampler processor|
|Upscale (RealESRGAN) | Upscales an image using RealESRGAN.|
|VAE Loader | Loads a VAE model, outputting a VaeLoaderOutput|
|Zoe (Depth) Processor | Applies Zoe depth processing to image|

View File

@@ -1,154 +0,0 @@
# Face Nodes
## FaceOff
FaceOff mimics a user finding a face in an image and resizing the bounding box
around the head in Canvas.
Enter a face ID (found with FaceIdentifier) to choose which face to mask.
Just as you would add more context inside the bounding box by making it larger
in Canvas, the node gives you a padding input (in pixels) which will
simultaneously add more context, and increase the resolution of the bounding box
so the face remains the same size inside it.
The "Minimum Confidence" input defaults to 0.5 (50%), and represents a pass/fail
threshold a detected face must reach for it to be processed. Lowering this value
may help if detection is failing. If the detected masks are imperfect and stray
too far outside/inside of faces, the node gives you X & Y offsets to shrink/grow
the masks by a multiplier.
FaceOff will output the face in a bounded image, taking the face off of the
original image for input into any node that accepts image inputs. The node also
outputs a face mask with the dimensions of the bounded image. The X & Y outputs
are for connecting to the X & Y inputs of the Paste Image node, which will place
the bounded image back on the original image using these coordinates.
###### Inputs/Outputs
| Input | Description |
| ------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Image | Image for face detection |
| Face ID | The face ID to process, numbered from 0. Multiple faces not supported. Find a face's ID with FaceIdentifier node. |
| Minimum Confidence | Minimum confidence for face detection (lower if detection is failing) |
| X Offset | X-axis offset of the mask |
| Y Offset | Y-axis offset of the mask |
| Padding | All-axis padding around the mask in pixels |
| Chunk | Chunk (or divide) the image into sections to greatly improve face detection success. Defaults to off, but will activate if no faces are detected normally. Activate to chunk by default. |
| Output | Description |
| ------------- | ------------------------------------------------ |
| Bounded Image | Original image bound, cropped, and resized |
| Width | The width of the bounded image in pixels |
| Height | The height of the bounded image in pixels |
| Mask | The output mask |
| X | The x coordinate of the bounding box's left side |
| Y | The y coordinate of the bounding box's top side |
## FaceMask
FaceMask mimics a user drawing masks on faces in an image in Canvas.
The "Face IDs" input allows the user to select specific faces to be masked.
Leave empty to detect and mask all faces, or a comma-separated list for a
specific combination of faces (ex: `1,2,4`). A single integer will detect and
mask that specific face. Find face IDs with the FaceIdentifier node.
The "Minimum Confidence" input defaults to 0.5 (50%), and represents a pass/fail
threshold a detected face must reach for it to be processed. Lowering this value
may help if detection is failing.
If the detected masks are imperfect and stray too far outside/inside of faces,
the node gives you X & Y offsets to shrink/grow the masks by a multiplier. All
masks shrink/grow together by the X & Y offset values.
By default, masks are created to change faces. When masks are inverted, they
change surrounding areas, protecting faces.
###### Inputs/Outputs
| Input | Description |
| ------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Image | Image for face detection |
| Face IDs | Comma-separated list of face ids to mask eg '0,2,7'. Numbered from 0. Leave empty to mask all. Find face IDs with FaceIdentifier node. |
| Minimum Confidence | Minimum confidence for face detection (lower if detection is failing) |
| X Offset | X-axis offset of the mask |
| Y Offset | Y-axis offset of the mask |
| Chunk | Chunk (or divide) the image into sections to greatly improve face detection success. Defaults to off, but will activate if no faces are detected normally. Activate to chunk by default. |
| Invert Mask | Toggle to invert the face mask |
| Output | Description |
| ------ | --------------------------------- |
| Image | The original image |
| Width | The width of the image in pixels |
| Height | The height of the image in pixels |
| Mask | The output face mask |
## FaceIdentifier
FaceIdentifier outputs an image with detected face IDs printed in white numbers
onto each face.
Face IDs can then be used in FaceMask and FaceOff to selectively mask all, a
specific combination, or single faces.
The FaceIdentifier output image is generated for user reference, and isn't meant
to be passed on to other image-processing nodes.
The "Minimum Confidence" input defaults to 0.5 (50%), and represents a pass/fail
threshold a detected face must reach for it to be processed. Lowering this value
may help if detection is failing. If an image is changed in the slightest, run
it through FaceIdentifier again to get updated FaceIDs.
###### Inputs/Outputs
| Input | Description |
| ------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Image | Image for face detection |
| Minimum Confidence | Minimum confidence for face detection (lower if detection is failing) |
| Chunk | Chunk (or divide) the image into sections to greatly improve face detection success. Defaults to off, but will activate if no faces are detected normally. Activate to chunk by default. |
| Output | Description |
| ------ | ------------------------------------------------------------------------------------------------ |
| Image | The original image with small face ID numbers printed in white onto each face for user reference |
| Width | The width of the original image in pixels |
| Height | The height of the original image in pixels |
## Tips
- If not all target faces are being detected, activate Chunk to bypass full
image face detection and greatly improve detection success.
- Final results will vary between full-image detection and chunking for faces
that are detectable by both due to the nature of the process. Try either to
your taste.
- Be sure Minimum Confidence is set the same when using FaceIdentifier with
FaceOff/FaceMask.
- For FaceOff, use the color correction node before faceplace to correct edges
being noticeable in the final image (see example screenshot).
- Non-inpainting models may struggle to paint/generate correctly around faces.
- If your face won't change the way you want it to no matter what you change,
consider that the change you're trying to make is too much at that resolution.
For example, if an image is only 512x768 total, the face might only be 128x128
or 256x256, much smaller than the 512x512 your SD1.5 model was probably
trained on. Try increasing the resolution of the image by upscaling or
resizing, add padding to increase the bounding box's resolution, or use an
image where the face takes up more pixels.
- If the resulting face seems out of place pasted back on the original image
(ie. too large, not proportional), add more padding on the FaceOff node to
give inpainting more context. Context and good prompting are important to
keeping things proportional.
- If you find the mask is too big/small and going too far outside/inside the
area you want to affect, adjust the x & y offsets to shrink/grow the mask area
- Use a higher denoise start value to resemble aspects of the original face or
surroundings. Denoise start = 0 & denoise end = 1 will make something new,
while denoise start = 0.50 & denoise end = 1 will be 50% old and 50% new.
- mediapipe isn't good at detecting faces with lots of face paint, hair covering
the face, etc. Anything that obstructs the face will likely result in no faces
being detected.
- If you find your face isn't being detected, try lowering the minimum
confidence value from 0.5. This could result in false positives, however
(random areas being detected as faces and masked).
- After altering an image and wanting to process a different face in the newly
altered image, run the altered image through FaceIdentifier again to see the
new Face IDs. MediaPipe will most likely detect faces in a different order
after an image has been changed in the slightest.

View File

@@ -9,6 +9,5 @@ If you're interested in finding more workflows, checkout the [#share-your-workfl
* [SD1.5 / SD2 Text to Image](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/Text_to_Image.json)
* [SDXL Text to Image](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/SDXL_Text_to_Image.json)
* [SDXL (with Refiner) Text to Image](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/SDXL_Text_to_Image.json)
* [Tiled Upscaling with ControlNet](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/ESRGAN_img2img_upscale w_Canny_ControlNet.json)
* [FaceMask](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/FaceMask.json)
* [FaceOff with 2x Face Scaling](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/FaceOff_FaceScale2x.json)
* [Tiled Upscaling with ControlNet](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/ESRGAN_img2img_upscale w_Canny_ControlNet.json)ß

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -332,7 +332,6 @@ class InvokeAiInstance:
Configure the InvokeAI runtime directory
"""
auto_install = False
# set sys.argv to a consistent state
new_argv = [sys.argv[0]]
for i in range(1, len(sys.argv)):
@@ -341,17 +340,13 @@ class InvokeAiInstance:
new_argv.append(el)
new_argv.append(sys.argv[i + 1])
elif el in ["-y", "--yes", "--yes-to-all"]:
auto_install = True
new_argv.append(el)
sys.argv = new_argv
import messages
import requests # to catch download exceptions
from messages import introduction
auto_install = auto_install or messages.user_wants_auto_configuration()
if auto_install:
sys.argv.append("--yes")
else:
messages.introduction()
introduction()
from invokeai.frontend.install.invokeai_configure import invokeai_configure

View File

@@ -7,7 +7,7 @@ import os
import platform
from pathlib import Path
from prompt_toolkit import HTML, prompt
from prompt_toolkit import prompt
from prompt_toolkit.completion import PathCompleter
from prompt_toolkit.validation import Validator
from rich import box, print
@@ -65,50 +65,17 @@ def confirm_install(dest: Path) -> bool:
if dest.exists():
print(f":exclamation: Directory {dest} already exists :exclamation:")
dest_confirmed = Confirm.ask(
":stop_sign: (re)install in this location?",
":stop_sign: Are you sure you want to (re)install in this location?",
default=False,
)
else:
print(f"InvokeAI will be installed in {dest}")
dest_confirmed = Confirm.ask("Use this location?", default=True)
dest_confirmed = not Confirm.ask("Would you like to pick a different location?", default=False)
console.line()
return dest_confirmed
def user_wants_auto_configuration() -> bool:
"""Prompt the user to choose between manual and auto configuration."""
console.rule("InvokeAI Configuration Section")
console.print(
Panel(
Group(
"\n".join(
[
"Libraries are installed and InvokeAI will now set up its root directory and configuration. Choose between:",
"",
" * AUTOMATIC configuration: install reasonable defaults and a minimal set of starter models.",
" * MANUAL configuration: manually inspect and adjust configuration options and pick from a larger set of starter models.",
"",
"Later you can fine tune your configuration by selecting option [6] 'Change InvokeAI startup options' from the invoke.bat/invoke.sh launcher script.",
]
),
),
box=box.MINIMAL,
padding=(1, 1),
)
)
choice = (
prompt(
HTML("Choose <b>&lt;a&gt;</b>utomatic or <b>&lt;m&gt;</b>anual configuration [a/m] (a): "),
validator=Validator.from_callable(
lambda n: n == "" or n.startswith(("a", "A", "m", "M")), error_message="Please select 'a' or 'm'"
),
)
or "a"
)
return choice.lower().startswith("a")
def dest_path(dest=None) -> Path:
"""
Prompt the user for the destination path and create the path

View File

@@ -17,10 +17,9 @@ echo 6. Change InvokeAI startup options
echo 7. Re-run the configure script to fix a broken install or to complete a major upgrade
echo 8. Open the developer console
echo 9. Update InvokeAI
echo 10. Run the InvokeAI image database maintenance script
echo 11. Command-line help
echo 10. Command-line help
echo Q - Quit
set /P choice="Please enter 1-11, Q: [1] "
set /P choice="Please enter 1-10, Q: [1] "
if not defined choice set choice=1
IF /I "%choice%" == "1" (
echo Starting the InvokeAI browser-based UI..
@@ -59,11 +58,8 @@ IF /I "%choice%" == "1" (
echo Running invokeai-update...
python -m invokeai.frontend.install.invokeai_update
) ELSE IF /I "%choice%" == "10" (
echo Running the db maintenance script...
python .venv\Scripts\invokeai-db-maintenance.exe
) ELSE IF /I "%choice%" == "11" (
echo Displaying command line help...
python .venv\Scripts\invokeai-web.exe --help %*
python .venv\Scripts\invokeai.exe --help %*
pause
exit /b
) ELSE IF /I "%choice%" == "q" (

View File

@@ -97,13 +97,13 @@ do_choice() {
;;
10)
clear
printf "Running the db maintenance script\n"
invokeai-db-maintenance --root ${INVOKEAI_ROOT}
printf "Command-line help\n"
invokeai --help
;;
11)
"HELP 1")
clear
printf "Command-line help\n"
invokeai-web --help
invokeai --help
;;
*)
clear
@@ -125,10 +125,7 @@ do_dialog() {
6 "Change InvokeAI startup options"
7 "Re-run the configure script to fix a broken install or to complete a major upgrade"
8 "Open the developer console"
9 "Update InvokeAI"
10 "Run the InvokeAI image database maintenance script"
11 "Command-line help"
)
9 "Update InvokeAI")
choice=$(dialog --clear \
--backtitle "\Zb\Zu\Z3InvokeAI" \
@@ -160,10 +157,9 @@ do_line_input() {
printf "7: Re-run the configure script to fix a broken install\n"
printf "8: Open the developer console\n"
printf "9: Update InvokeAI\n"
printf "10: Run the InvokeAI image database maintenance script\n"
printf "11: Command-line help\n"
printf "10: Command-line help\n"
printf "Q: Quit\n\n"
read -p "Please enter 1-11, Q: [1] " yn
read -p "Please enter 1-10, Q: [1] " yn
choice=${yn:='1'}
do_choice $choice
clear

View File

@@ -1,6 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import sqlite3
from logging import Logger
from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage
@@ -10,10 +9,7 @@ from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
@@ -29,7 +25,6 @@ from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsSto
from ..services.model_manager_service import ModelManagerService
from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from ..services.thread import lock
from .events import FastAPIEventService
@@ -49,7 +44,7 @@ def check_internet() -> bool:
return False
logger = InvokeAILogger.get_logger()
logger = InvokeAILogger.getLogger()
class ApiDependencies:
@@ -68,32 +63,22 @@ class ApiDependencies:
output_folder = config.output_path
# TODO: build a file/path manager?
if config.use_memory_db:
db_location = ":memory:"
else:
db_path = config.db_path
db_path.parent.mkdir(parents=True, exist_ok=True)
db_location = str(db_path)
logger.info(f"Using database at {db_location}")
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
if config.log_sql:
db_conn.set_trace_callback(print)
db_conn.execute("PRAGMA foreign_keys = ON;")
db_path = config.db_path
db_path.parent.mkdir(parents=True, exist_ok=True)
db_location = str(db_path)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
conn=db_conn, table_name="graph_executions", lock=lock
filename=db_location, table_name="graph_executions"
)
urls = LocalUrlService()
image_record_storage = SqliteImageRecordStorage(conn=db_conn, lock=lock)
image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService()
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
board_record_storage = SqliteBoardRecordStorage(conn=db_conn, lock=lock)
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn, lock=lock)
board_record_storage = SqliteBoardRecordStorage(db_location)
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
boards = BoardService(
services=BoardServiceDependencies(
@@ -135,29 +120,18 @@ class ApiDependencies:
boards=boards,
board_images=board_images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, lock=lock, table_name="graphs"),
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
configuration=config,
performance_statistics=InvocationStatsService(graph_execution_manager),
logger=logger,
session_queue=SqliteSessionQueue(conn=db_conn, lock=lock),
session_processor=DefaultSessionProcessor(),
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
)
create_system_graphs(services.graph_library)
ApiDependencies.invoker = Invoker(services)
try:
lock.acquire()
db_conn.execute("VACUUM;")
db_conn.commit()
logger.info("Cleaned database")
finally:
lock.release()
@staticmethod
def shutdown():
if ApiDependencies.invoker:

View File

@@ -7,7 +7,6 @@ from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
from invokeai.app.invocations.upscale import ESRGAN_MODELS
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.patchmatch import PatchMatch
from invokeai.backend.image_util.safety_checker import SafetyChecker
@@ -104,43 +103,3 @@ async def set_log_level(
"""Sets the log verbosity level"""
ApiDependencies.invoker.services.logger.setLevel(level)
return LogLevel(ApiDependencies.invoker.services.logger.level)
@app_router.delete(
"/invocation_cache",
operation_id="clear_invocation_cache",
responses={200: {"description": "The operation was successful"}},
)
async def clear_invocation_cache() -> None:
"""Clears the invocation cache"""
ApiDependencies.invoker.services.invocation_cache.clear()
@app_router.put(
"/invocation_cache/enable",
operation_id="enable_invocation_cache",
responses={200: {"description": "The operation was successful"}},
)
async def enable_invocation_cache() -> None:
"""Clears the invocation cache"""
ApiDependencies.invoker.services.invocation_cache.enable()
@app_router.put(
"/invocation_cache/disable",
operation_id="disable_invocation_cache",
responses={200: {"description": "The operation was successful"}},
)
async def disable_invocation_cache() -> None:
"""Clears the invocation cache"""
ApiDependencies.invoker.services.invocation_cache.disable()
@app_router.get(
"/invocation_cache/status",
operation_id="get_invocation_cache_status",
responses={200: {"model": InvocationCacheStatus}},
)
async def get_invocation_cache_status() -> InvocationCacheStatus:
"""Clears the invocation cache"""
return ApiDependencies.invoker.services.invocation_cache.get_status()

View File

@@ -146,8 +146,7 @@ async def update_model(
async def import_model(
location: str = Body(description="A model path, repo_id or URL to import"),
prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
description="Prediction type for SDv2 checkpoints and rare SDv1 checkpoints",
default=None,
description="Prediction type for SDv2 checkpoint files", default="v_prediction"
),
) -> ImportModelResponse:
"""Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""

View File

@@ -1,247 +0,0 @@
from typing import Optional
from fastapi import Body, Path, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
Batch,
BatchStatus,
CancelByBatchIDsResult,
ClearResult,
EnqueueBatchResult,
EnqueueGraphResult,
PruneResult,
SessionQueueItem,
SessionQueueItemDTO,
SessionQueueStatus,
)
from invokeai.app.services.shared.models import CursorPaginatedResults
from ...services.graph import Graph
from ..dependencies import ApiDependencies
session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
class SessionQueueAndProcessorStatus(BaseModel):
"""The overall status of session queue and processor"""
queue: SessionQueueStatus
processor: SessionProcessorStatus
@session_queue_router.post(
"/{queue_id}/enqueue_graph",
operation_id="enqueue_graph",
responses={
201: {"model": EnqueueGraphResult},
},
)
async def enqueue_graph(
queue_id: str = Path(description="The queue id to perform this operation on"),
graph: Graph = Body(description="The graph to enqueue"),
prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
) -> EnqueueGraphResult:
"""Enqueues a graph for single execution."""
return ApiDependencies.invoker.services.session_queue.enqueue_graph(queue_id=queue_id, graph=graph, prepend=prepend)
@session_queue_router.post(
"/{queue_id}/enqueue_batch",
operation_id="enqueue_batch",
responses={
201: {"model": EnqueueBatchResult},
},
)
async def enqueue_batch(
queue_id: str = Path(description="The queue id to perform this operation on"),
batch: Batch = Body(description="Batch to process"),
prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
) -> EnqueueBatchResult:
"""Processes a batch and enqueues the output graphs for execution."""
return ApiDependencies.invoker.services.session_queue.enqueue_batch(queue_id=queue_id, batch=batch, prepend=prepend)
@session_queue_router.get(
"/{queue_id}/list",
operation_id="list_queue_items",
responses={
200: {"model": CursorPaginatedResults[SessionQueueItemDTO]},
},
)
async def list_queue_items(
queue_id: str = Path(description="The queue id to perform this operation on"),
limit: int = Query(default=50, description="The number of items to fetch"),
status: Optional[QUEUE_ITEM_STATUS] = Query(default=None, description="The status of items to fetch"),
cursor: Optional[int] = Query(default=None, description="The pagination cursor"),
priority: int = Query(default=0, description="The pagination cursor priority"),
) -> CursorPaginatedResults[SessionQueueItemDTO]:
"""Gets all queue items (without graphs)"""
return ApiDependencies.invoker.services.session_queue.list_queue_items(
queue_id=queue_id, limit=limit, status=status, cursor=cursor, priority=priority
)
@session_queue_router.put(
"/{queue_id}/processor/resume",
operation_id="resume",
responses={200: {"model": SessionProcessorStatus}},
)
async def resume(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionProcessorStatus:
"""Resumes session processor"""
return ApiDependencies.invoker.services.session_processor.resume()
@session_queue_router.put(
"/{queue_id}/processor/pause",
operation_id="pause",
responses={200: {"model": SessionProcessorStatus}},
)
async def Pause(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionProcessorStatus:
"""Pauses session processor"""
return ApiDependencies.invoker.services.session_processor.pause()
@session_queue_router.put(
"/{queue_id}/cancel_by_batch_ids",
operation_id="cancel_by_batch_ids",
responses={200: {"model": CancelByBatchIDsResult}},
)
async def cancel_by_batch_ids(
queue_id: str = Path(description="The queue id to perform this operation on"),
batch_ids: list[str] = Body(description="The list of batch_ids to cancel all queue items for", embed=True),
) -> CancelByBatchIDsResult:
"""Immediately cancels all queue items from the given batch ids"""
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
@session_queue_router.put(
"/{queue_id}/clear",
operation_id="clear",
responses={
200: {"model": ClearResult},
},
)
async def clear(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> ClearResult:
"""Clears the queue entirely, immediately canceling the currently-executing session"""
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
if queue_item is not None:
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
return clear_result
@session_queue_router.put(
"/{queue_id}/prune",
operation_id="prune",
responses={
200: {"model": PruneResult},
},
)
async def prune(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> PruneResult:
"""Prunes all completed or errored queue items"""
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
@session_queue_router.get(
"/{queue_id}/current",
operation_id="get_current_queue_item",
responses={
200: {"model": Optional[SessionQueueItem]},
},
)
async def get_current_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the currently execution queue item"""
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
@session_queue_router.get(
"/{queue_id}/next",
operation_id="get_next_queue_item",
responses={
200: {"model": Optional[SessionQueueItem]},
},
)
async def get_next_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the next queue item, without executing it"""
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
@session_queue_router.get(
"/{queue_id}/status",
operation_id="get_queue_status",
responses={
200: {"model": SessionQueueAndProcessorStatus},
},
)
async def get_queue_status(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionQueueAndProcessorStatus:
"""Gets the status of the session queue"""
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
processor = ApiDependencies.invoker.services.session_processor.get_status()
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
@session_queue_router.get(
"/{queue_id}/b/{batch_id}/status",
operation_id="get_batch_status",
responses={
200: {"model": BatchStatus},
},
)
async def get_batch_status(
queue_id: str = Path(description="The queue id to perform this operation on"),
batch_id: str = Path(description="The batch to get the status of"),
) -> BatchStatus:
"""Gets the status of the session queue"""
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
@session_queue_router.get(
"/{queue_id}/i/{item_id}",
operation_id="get_queue_item",
responses={
200: {"model": SessionQueueItem},
},
)
async def get_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
item_id: int = Path(description="The queue item to get"),
) -> SessionQueueItem:
"""Gets a queue item"""
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
@session_queue_router.put(
"/{queue_id}/i/{item_id}/cancel",
operation_id="cancel_queue_item",
responses={
200: {"model": SessionQueueItem},
},
)
async def cancel_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
item_id: int = Path(description="The queue item to cancel"),
) -> SessionQueueItem:
"""Deletes a queue item"""
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)

View File

@@ -23,14 +23,12 @@ session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
200: {"model": GraphExecutionState},
400: {"description": "Invalid json"},
},
deprecated=True,
)
async def create_session(
queue_id: str = Query(default="", description="The id of the queue to associate the session with"),
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with")
) -> GraphExecutionState:
"""Creates a new session, optionally initializing it with an invocation graph"""
session = ApiDependencies.invoker.create_execution_state(queue_id=queue_id, graph=graph)
session = ApiDependencies.invoker.create_execution_state(graph)
return session
@@ -38,7 +36,6 @@ async def create_session(
"/",
operation_id="list_sessions",
responses={200: {"model": PaginatedResults[GraphExecutionState]}},
deprecated=True,
)
async def list_sessions(
page: int = Query(default=0, description="The page of results to get"),
@@ -60,7 +57,6 @@ async def list_sessions(
200: {"model": GraphExecutionState},
404: {"description": "Session not found"},
},
deprecated=True,
)
async def get_session(
session_id: str = Path(description="The id of the session to get"),
@@ -81,7 +77,6 @@ async def get_session(
400: {"description": "Invalid node or link"},
404: {"description": "Session not found"},
},
deprecated=True,
)
async def add_node(
session_id: str = Path(description="The id of the session"),
@@ -114,7 +109,6 @@ async def add_node(
400: {"description": "Invalid node or link"},
404: {"description": "Session not found"},
},
deprecated=True,
)
async def update_node(
session_id: str = Path(description="The id of the session"),
@@ -148,7 +142,6 @@ async def update_node(
400: {"description": "Invalid node or link"},
404: {"description": "Session not found"},
},
deprecated=True,
)
async def delete_node(
session_id: str = Path(description="The id of the session"),
@@ -179,7 +172,6 @@ async def delete_node(
400: {"description": "Invalid node or link"},
404: {"description": "Session not found"},
},
deprecated=True,
)
async def add_edge(
session_id: str = Path(description="The id of the session"),
@@ -211,7 +203,6 @@ async def add_edge(
400: {"description": "Invalid node or link"},
404: {"description": "Session not found"},
},
deprecated=True,
)
async def delete_edge(
session_id: str = Path(description="The id of the session"),
@@ -250,10 +241,8 @@ async def delete_edge(
400: {"description": "The session has no invocations ready to invoke"},
404: {"description": "Session not found"},
},
deprecated=True,
)
async def invoke_session(
queue_id: str = Query(description="The id of the queue to associate the session with"),
session_id: str = Path(description="The id of the session to invoke"),
all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
) -> Response:
@@ -265,7 +254,7 @@ async def invoke_session(
if session.is_complete():
raise HTTPException(status_code=400)
ApiDependencies.invoker.invoke(queue_id, session, invoke_all=all)
ApiDependencies.invoker.invoke(session, invoke_all=all)
return Response(status_code=202)
@@ -273,7 +262,6 @@ async def invoke_session(
"/{session_id}/invoke",
operation_id="cancel_session_invoke",
responses={202: {"description": "The invocation is canceled"}},
deprecated=True,
)
async def cancel_session_invoke(
session_id: str = Path(description="The id of the session to cancel"),

View File

@@ -1,41 +0,0 @@
from typing import Optional
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
from fastapi import Body
from fastapi.routing import APIRouter
from pydantic import BaseModel
from pyparsing import ParseException
utilities_router = APIRouter(prefix="/v1/utilities", tags=["utilities"])
class DynamicPromptsResponse(BaseModel):
prompts: list[str]
error: Optional[str] = None
@utilities_router.post(
"/dynamicprompts",
operation_id="parse_dynamicprompts",
responses={
200: {"model": DynamicPromptsResponse},
},
)
async def parse_dynamicprompts(
prompt: str = Body(description="The prompt to parse with dynamicprompts"),
max_prompts: int = Body(default=1000, description="The max number of prompts to generate"),
combinatorial: bool = Body(default=True, description="Whether to use the combinatorial generator"),
) -> DynamicPromptsResponse:
"""Creates a batch process"""
try:
error: Optional[str] = None
if combinatorial:
generator = CombinatorialPromptGenerator()
prompts = generator.generate(prompt, max_prompts=max_prompts)
else:
generator = RandomPromptGenerator()
prompts = generator.generate(prompt, num_images=max_prompts)
except ParseException as e:
prompts = [prompt]
error = str(e)
return DynamicPromptsResponse(prompts=prompts if prompts else [""], error=error)

View File

@@ -3,35 +3,34 @@
from fastapi import FastAPI
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event
from socketio import ASGIApp, AsyncServer
from fastapi_socketio import SocketManager
from ..services.events import EventServiceBase
class SocketIO:
__sio: AsyncServer
__app: ASGIApp
__sio: SocketManager
def __init__(self, app: FastAPI):
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="socket.io")
app.mount("/ws", self.__app)
self.__sio = SocketManager(app=app)
self.__sio.on("subscribe", handler=self._handle_sub)
self.__sio.on("unsubscribe", handler=self._handle_unsub)
self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
local_handler.register(event_name=EventServiceBase.session_event, _func=self._handle_session_event)
async def _handle_queue_event(self, event: Event):
async def _handle_session_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["queue_id"],
room=event[1]["data"]["graph_execution_state_id"],
)
async def _handle_sub_queue(self, sid, data, *args, **kwargs):
if "queue_id" in data:
self.__sio.enter_room(sid, data["queue_id"])
async def _handle_sub(self, sid, data, *args, **kwargs):
if "session" in data:
self.__sio.enter_room(sid, data["session"])
async def _handle_unsub_queue(self, sid, data, *args, **kwargs):
if "queue_id" in data:
self.__sio.enter_room(sid, data["queue_id"])
# @app.sio.on('unsubscribe')
async def _handle_unsub(self, sid, data, *args, **kwargs):
if "session" in data:
self.__sio.leave_room(sid, data["session"])

View File

@@ -1,3 +1,4 @@
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from .services.config import InvokeAIAppConfig
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
@@ -8,6 +9,7 @@ app_config.parse_args()
if True: # hack to make flake8 happy with imports coming after setting up the config
import asyncio
import logging
import mimetypes
import socket
from inspect import signature
@@ -31,7 +33,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
from .api.routers import app_info, board_images, boards, images, models, session_queue, sessions, utilities
from .api.routers import app_info, board_images, boards, images, models, sessions
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
@@ -40,9 +42,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
logger = InvokeAILogger.get_logger(config=app_config)
logger = InvokeAILogger.getLogger(config=app_config)
# fix for windows mimetypes registry entries being borked
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
@@ -92,8 +92,6 @@ async def shutdown_event():
app.include_router(sessions.session_router, prefix="/api")
app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
@@ -104,8 +102,6 @@ app.include_router(board_images.board_images_router, prefix="/api")
app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api")
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
@@ -224,7 +220,7 @@ def invoke_api():
exc_info=e,
)
else:
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
jurigged.watch(logger=InvokeAILogger.getLogger(name="jurigged").info)
port = find_port(app_config.port)
if port != app_config.port:
@@ -243,7 +239,7 @@ def invoke_api():
# replace uvicorn's loggers with InvokeAI's for consistent appearance
for logname in ["uvicorn.access", "uvicorn"]:
log = InvokeAILogger.get_logger(logname)
log = logging.getLogger(logname)
log.handlers.clear()
for ch in logger.handlers:
log.addHandler(ch)

View File

@@ -1,18 +1,17 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from .services.config import InvokeAIAppConfig
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
# which are imported/used before parse_args() is called will get the default config values instead of the
# values from the command line or config file.
config = InvokeAIAppConfig.get_config()
config.parse_args()
if True: # hack to make flake8 happy with imports coming after setting up the config
import argparse
import re
import shlex
import sqlite3
import sys
import time
from typing import Optional, Union, get_type_hints
@@ -59,9 +58,8 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger().get_logger(config=config)
logger = InvokeAILogger().getLogger(config=config)
class CliCommand(BaseModel):
@@ -251,18 +249,19 @@ def invoke_cli():
db_location = config.db_path
db_location.parent.mkdir(parents=True, exist_ok=True)
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
logger.info(f'InvokeAI database location is "{db_location}"')
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
)
urls = LocalUrlService()
image_record_storage = SqliteImageRecordStorage(conn=db_conn)
image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService()
board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)
board_record_storage = SqliteBoardRecordStorage(db_location)
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
boards = BoardService(
services=BoardServiceDependencies(
@@ -304,13 +303,12 @@ def invoke_cli():
boards=boards,
board_images=board_images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
performance_statistics=InvocationStatsService(graph_execution_manager),
logger=logger,
configuration=config,
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
)
system_graphs = create_system_graphs(services.graph_library)

View File

@@ -88,12 +88,6 @@ class FieldDescriptions:
num_1 = "The first number"
num_2 = "The second number"
mask = "The mask to use for the operation"
board = "The board to save the image to"
image = "The image to process"
tile_size = "Tile size"
inclusive_low = "The inclusive low value"
exclusive_high = "The exclusive high value"
decimal_places = "The number of decimal places to round to"
class Input(str, Enum):
@@ -179,7 +173,6 @@ class UIType(str, Enum):
WorkflowField = "WorkflowField"
IsIntermediate = "IsIntermediate"
MetadataField = "MetadataField"
BoardField = "BoardField"
# endregion
@@ -426,27 +419,12 @@ class UIConfigBase(BaseModel):
class InvocationContext:
"""Initialized and provided to on execution of invocations."""
services: InvocationServices
graph_execution_state_id: str
queue_id: str
queue_item_id: int
queue_batch_id: str
def __init__(
self,
services: InvocationServices,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
):
def __init__(self, services: InvocationServices, graph_execution_state_id: str):
self.services = services
self.graph_execution_state_id = graph_execution_state_id
self.queue_id = queue_id
self.queue_item_id = queue_item_id
self.queue_batch_id = queue_batch_id
class BaseInvocationOutput(BaseModel):
@@ -544,9 +522,6 @@ class BaseInvocation(ABC, BaseModel):
return signature(cls.invoke).return_annotation
class Config:
validate_assignment = True
validate_all = True
@staticmethod
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
uiconfig = getattr(model_class, "UIConfig", None)
@@ -595,29 +570,7 @@ class BaseInvocation(ABC, BaseModel):
raise RequiredConnectionException(self.__fields__["type"].default, field_name)
elif _input == Input.Any:
raise MissingInputException(self.__fields__["type"].default, field_name)
# skip node cache codepath if it's disabled
if context.services.configuration.node_cache_size == 0:
return self.invoke(context)
output: BaseInvocationOutput
if self.use_cache:
key = context.services.invocation_cache.create_key(self)
cached_value = context.services.invocation_cache.get(key)
if cached_value is None:
context.services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
output = self.invoke(context)
context.services.invocation_cache.save(key, output)
return output
else:
context.services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')
return cached_value
else:
context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
return self.invoke(context)
def get_type(self) -> str:
return self.__fields__["type"].default
return self.invoke(context)
id: str = Field(
description="The id of this instance of an invocation. Must be unique among all instances of invocations."
@@ -630,7 +583,6 @@ class BaseInvocation(ABC, BaseModel):
description="The workflow to save with the image",
ui_type=UIType.WorkflowField,
)
use_cache: bool = InputField(default=True, description="Whether or not to use the cache")
@validator("workflow", pre=True)
def validate_workflow_is_json(cls, v):
@@ -654,7 +606,6 @@ def invocation(
tags: Optional[list[str]] = None,
category: Optional[str] = None,
version: Optional[str] = None,
use_cache: Optional[bool] = True,
) -> Callable[[Type[GenericBaseInvocation]], Type[GenericBaseInvocation]]:
"""
Adds metadata to an invocation.
@@ -663,8 +614,6 @@ def invocation(
:param Optional[str] title: Adds a title to the invocation. Use if the auto-generated title isn't quite right. Defaults to None.
:param Optional[list[str]] tags: Adds tags to the invocation. Invocations may be searched for by their tags. Defaults to None.
:param Optional[str] category: Adds a category to the invocation. Used to group the invocations in the UI. Defaults to None.
:param Optional[str] version: Adds a version to the invocation. Must be a valid semver string. Defaults to None.
:param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor.
"""
def wrapper(cls: Type[GenericBaseInvocation]) -> Type[GenericBaseInvocation]:
@@ -689,8 +638,6 @@ def invocation(
except ValueError as e:
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
cls.UIConfig.version = version
if use_cache is not None:
cls.__fields__["use_cache"].default = use_cache
# Add the invocation type to the pydantic model of the invocation
invocation_type_annotation = Literal[invocation_type] # type: ignore

View File

@@ -56,7 +56,6 @@ class RangeOfSizeInvocation(BaseInvocation):
tags=["range", "integer", "random", "collection"],
category="collections",
version="1.0.0",
use_cache=False,
)
class RandomRangeInvocation(BaseInvocation):
"""Creates a collection of random numbers"""

View File

@@ -38,6 +38,7 @@ from .baseinvocation import (
InputField,
InvocationContext,
OutputField,
UIType,
invocation,
invocation_output,
)
@@ -99,7 +100,7 @@ class ControlNetInvocation(BaseInvocation):
image: ImageField = InputField(description="The control image")
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
control_weight: Union[float, List[float]] = InputField(
default=1.0, description="The weight given to the ControlNet"
default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
)
begin_step_percent: float = InputField(
default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
@@ -559,33 +560,3 @@ class SamDetectorReproducibleColors(SamDetector):
img[:, :] = ann_color
final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
return np.array(final_img, dtype=np.uint8)
@invocation(
"color_map_image_processor",
title="Color Map Processor",
tags=["controlnet"],
category="controlnet",
version="1.0.0",
)
class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
"""Generates a color map from the provided image"""
color_map_tile_size: int = InputField(default=64, ge=0, description=FieldDescriptions.tile_size)
def run_processor(self, image: Image.Image):
image = image.convert("RGB")
image = np.array(image, dtype=np.uint8)
height, width = image.shape[:2]
width_tile_size = min(self.color_map_tile_size, width)
height_tile_size = min(self.color_map_tile_size, height)
color_map = cv2.resize(
image,
(width // width_tile_size, height // height_tile_size),
interpolation=cv2.INTER_CUBIC,
)
color_map = cv2.resize(color_map, (width, height), interpolation=cv2.INTER_NEAREST)
color_map = Image.fromarray(color_map)
return color_map

View File

@@ -1,692 +0,0 @@
import math
import re
from pathlib import Path
from typing import Optional, TypedDict
import cv2
import numpy as np
from mediapipe.python.solutions.face_mesh import FaceMesh # type: ignore[import]
from PIL import Image, ImageDraw, ImageFilter, ImageFont, ImageOps
from PIL.Image import Image as ImageType
from pydantic import validator
import invokeai.assets.fonts as font_assets
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
InputField,
InvocationContext,
OutputField,
invocation,
invocation_output,
)
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.models.image import ImageCategory, ResourceOrigin
@invocation_output("face_mask_output")
class FaceMaskOutput(ImageOutput):
"""Base class for FaceMask output"""
mask: ImageField = OutputField(description="The output mask")
@invocation_output("face_off_output")
class FaceOffOutput(ImageOutput):
"""Base class for FaceOff Output"""
mask: ImageField = OutputField(description="The output mask")
x: int = OutputField(description="The x coordinate of the bounding box's left side")
y: int = OutputField(description="The y coordinate of the bounding box's top side")
class FaceResultData(TypedDict):
image: ImageType
mask: ImageType
x_center: float
y_center: float
mesh_width: int
mesh_height: int
class FaceResultDataWithId(FaceResultData):
face_id: int
class ExtractFaceData(TypedDict):
bounded_image: ImageType
bounded_mask: ImageType
x_min: int
y_min: int
x_max: int
y_max: int
class FaceMaskResult(TypedDict):
image: ImageType
mask: ImageType
def create_white_image(w: int, h: int) -> ImageType:
return Image.new("L", (w, h), color=255)
def create_black_image(w: int, h: int) -> ImageType:
return Image.new("L", (w, h), color=0)
FONT_SIZE = 32
FONT_STROKE_WIDTH = 4
def prepare_faces_list(
face_result_list: list[FaceResultData],
) -> list[FaceResultDataWithId]:
"""Deduplicates a list of faces, adding IDs to them."""
deduped_faces: list[FaceResultData] = []
if len(face_result_list) == 0:
return list()
for candidate in face_result_list:
should_add = True
candidate_x_center = candidate["x_center"]
candidate_y_center = candidate["y_center"]
for face in deduped_faces:
face_center_x = face["x_center"]
face_center_y = face["y_center"]
face_radius_w = face["mesh_width"] / 2
face_radius_h = face["mesh_height"] / 2
# Determine if the center of the candidate_face is inside the ellipse of the added face
# p < 1 -> Inside
# p = 1 -> Exactly on the ellipse
# p > 1 -> Outside
p = (math.pow((candidate_x_center - face_center_x), 2) / math.pow(face_radius_w, 2)) + (
math.pow((candidate_y_center - face_center_y), 2) / math.pow(face_radius_h, 2)
)
if p < 1: # Inside of the already-added face's radius
should_add = False
break
if should_add is True:
deduped_faces.append(candidate)
sorted_faces = sorted(deduped_faces, key=lambda x: x["y_center"])
sorted_faces = sorted(sorted_faces, key=lambda x: x["x_center"])
# add face_id for reference
sorted_faces_with_ids: list[FaceResultDataWithId] = []
face_id_counter = 0
for face in sorted_faces:
sorted_faces_with_ids.append(
FaceResultDataWithId(
**face,
face_id=face_id_counter,
)
)
face_id_counter += 1
return sorted_faces_with_ids
def generate_face_box_mask(
context: InvocationContext,
minimum_confidence: float,
x_offset: float,
y_offset: float,
pil_image: ImageType,
chunk_x_offset: int = 0,
chunk_y_offset: int = 0,
draw_mesh: bool = True,
check_bounds: bool = True,
) -> list[FaceResultData]:
result = []
mask_pil = None
# Convert the PIL image to a NumPy array.
np_image = np.array(pil_image, dtype=np.uint8)
# Check if the input image has four channels (RGBA).
if np_image.shape[2] == 4:
# Convert RGBA to RGB by removing the alpha channel.
np_image = np_image[:, :, :3]
# Create a FaceMesh object for face landmark detection and mesh generation.
face_mesh = FaceMesh(
max_num_faces=999,
min_detection_confidence=minimum_confidence,
min_tracking_confidence=minimum_confidence,
)
# Detect the face landmarks and mesh in the input image.
results = face_mesh.process(np_image)
# Check if any face is detected.
if results.multi_face_landmarks: # type: ignore # this are via protobuf and not typed
# Search for the face_id in the detected faces.
for face_id, face_landmarks in enumerate(results.multi_face_landmarks): # type: ignore #this are via protobuf and not typed
# Get the bounding box of the face mesh.
x_coordinates = [landmark.x for landmark in face_landmarks.landmark]
y_coordinates = [landmark.y for landmark in face_landmarks.landmark]
x_min, x_max = min(x_coordinates), max(x_coordinates)
y_min, y_max = min(y_coordinates), max(y_coordinates)
# Calculate the width and height of the face mesh.
mesh_width = int((x_max - x_min) * np_image.shape[1])
mesh_height = int((y_max - y_min) * np_image.shape[0])
# Get the center of the face.
x_center = np.mean([landmark.x * np_image.shape[1] for landmark in face_landmarks.landmark])
y_center = np.mean([landmark.y * np_image.shape[0] for landmark in face_landmarks.landmark])
face_landmark_points = np.array(
[
[landmark.x * np_image.shape[1], landmark.y * np_image.shape[0]]
for landmark in face_landmarks.landmark
]
)
# Apply the scaling offsets to the face landmark points with a multiplier.
scale_multiplier = 0.2
x_center = np.mean(face_landmark_points[:, 0])
y_center = np.mean(face_landmark_points[:, 1])
if draw_mesh:
x_scaled = face_landmark_points[:, 0] + scale_multiplier * x_offset * (
face_landmark_points[:, 0] - x_center
)
y_scaled = face_landmark_points[:, 1] + scale_multiplier * y_offset * (
face_landmark_points[:, 1] - y_center
)
convex_hull = cv2.convexHull(np.column_stack((x_scaled, y_scaled)).astype(np.int32))
# Generate a binary face mask using the face mesh.
mask_image = np.ones(np_image.shape[:2], dtype=np.uint8) * 255
cv2.fillConvexPoly(mask_image, convex_hull, 0)
# Convert the binary mask image to a PIL Image.
init_mask_pil = Image.fromarray(mask_image, mode="L")
w, h = init_mask_pil.size
mask_pil = create_white_image(w + chunk_x_offset, h + chunk_y_offset)
mask_pil.paste(init_mask_pil, (chunk_x_offset, chunk_y_offset))
left_side = x_center - mesh_width
right_side = x_center + mesh_width
top_side = y_center - mesh_height
bottom_side = y_center + mesh_height
im_width, im_height = pil_image.size
over_w = im_width * 0.1
over_h = im_height * 0.1
if not check_bounds or (
(left_side >= -over_w)
and (right_side < im_width + over_w)
and (top_side >= -over_h)
and (bottom_side < im_height + over_h)
):
x_center = float(x_center)
y_center = float(y_center)
face = FaceResultData(
image=pil_image,
mask=mask_pil or create_white_image(*pil_image.size),
x_center=x_center + chunk_x_offset,
y_center=y_center + chunk_y_offset,
mesh_width=mesh_width,
mesh_height=mesh_height,
)
result.append(face)
else:
context.services.logger.info("FaceTools --> Face out of bounds, ignoring.")
return result
def extract_face(
context: InvocationContext,
image: ImageType,
face: FaceResultData,
padding: int,
) -> ExtractFaceData:
mask = face["mask"]
center_x = face["x_center"]
center_y = face["y_center"]
mesh_width = face["mesh_width"]
mesh_height = face["mesh_height"]
# Determine the minimum size of the square crop
min_size = min(mask.width, mask.height)
# Calculate the crop boundaries for the output image and mask.
mesh_width += 128 + padding # add pixels to account for mask variance
mesh_height += 128 + padding # add pixels to account for mask variance
crop_size = min(
max(mesh_width, mesh_height, 128), min_size
) # Choose the smaller of the two (given value or face mask size)
if crop_size > 128:
crop_size = (crop_size + 7) // 8 * 8 # Ensure crop side is multiple of 8
# Calculate the actual crop boundaries within the bounds of the original image.
x_min = int(center_x - crop_size / 2)
y_min = int(center_y - crop_size / 2)
x_max = int(center_x + crop_size / 2)
y_max = int(center_y + crop_size / 2)
# Adjust the crop boundaries to stay within the original image's dimensions
if x_min < 0:
context.services.logger.warning("FaceTools --> -X-axis padding reached image edge.")
x_max -= x_min
x_min = 0
elif x_max > mask.width:
context.services.logger.warning("FaceTools --> +X-axis padding reached image edge.")
x_min -= x_max - mask.width
x_max = mask.width
if y_min < 0:
context.services.logger.warning("FaceTools --> +Y-axis padding reached image edge.")
y_max -= y_min
y_min = 0
elif y_max > mask.height:
context.services.logger.warning("FaceTools --> -Y-axis padding reached image edge.")
y_min -= y_max - mask.height
y_max = mask.height
# Ensure the crop is square and adjust the boundaries if needed
if x_max - x_min != crop_size:
context.services.logger.warning("FaceTools --> Limiting x-axis padding to constrain bounding box to a square.")
diff = crop_size - (x_max - x_min)
x_min -= diff // 2
x_max += diff - diff // 2
if y_max - y_min != crop_size:
context.services.logger.warning("FaceTools --> Limiting y-axis padding to constrain bounding box to a square.")
diff = crop_size - (y_max - y_min)
y_min -= diff // 2
y_max += diff - diff // 2
context.services.logger.info(f"FaceTools --> Calculated bounding box (8 multiple): {crop_size}")
# Crop the output image to the specified size with the center of the face mesh as the center.
mask = mask.crop((x_min, y_min, x_max, y_max))
bounded_image = image.crop((x_min, y_min, x_max, y_max))
# blur mask edge by small radius
mask = mask.filter(ImageFilter.GaussianBlur(radius=2))
return ExtractFaceData(
bounded_image=bounded_image,
bounded_mask=mask,
x_min=x_min,
y_min=y_min,
x_max=x_max,
y_max=y_max,
)
def get_faces_list(
context: InvocationContext,
image: ImageType,
should_chunk: bool,
minimum_confidence: float,
x_offset: float,
y_offset: float,
draw_mesh: bool = True,
) -> list[FaceResultDataWithId]:
result = []
# Generate the face box mask and get the center of the face.
if not should_chunk:
context.services.logger.info("FaceTools --> Attempting full image face detection.")
result = generate_face_box_mask(
context=context,
minimum_confidence=minimum_confidence,
x_offset=x_offset,
y_offset=y_offset,
pil_image=image,
chunk_x_offset=0,
chunk_y_offset=0,
draw_mesh=draw_mesh,
check_bounds=False,
)
if should_chunk or len(result) == 0:
context.services.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).")
width, height = image.size
image_chunks = []
x_offsets = []
y_offsets = []
result = []
# If width == height, there's nothing more we can do... otherwise...
if width > height:
# Landscape - slice the image horizontally
fx = 0.0
steps = int(width * 2 / height)
while fx <= (width - height):
x = int(fx)
image_chunks.append(image.crop((x, 0, x + height - 1, height - 1)))
x_offsets.append(x)
y_offsets.append(0)
fx += (width - height) / steps
context.services.logger.info(f"FaceTools --> Chunk starting at x = {x}")
elif height > width:
# Portrait - slice the image vertically
fy = 0.0
steps = int(height * 2 / width)
while fy <= (height - width):
y = int(fy)
image_chunks.append(image.crop((0, y, width - 1, y + width - 1)))
x_offsets.append(0)
y_offsets.append(y)
fy += (height - width) / steps
context.services.logger.info(f"FaceTools --> Chunk starting at y = {y}")
for idx in range(len(image_chunks)):
context.services.logger.info(f"FaceTools --> Evaluating faces in chunk {idx}")
result = result + generate_face_box_mask(
context=context,
minimum_confidence=minimum_confidence,
x_offset=x_offset,
y_offset=y_offset,
pil_image=image_chunks[idx],
chunk_x_offset=x_offsets[idx],
chunk_y_offset=y_offsets[idx],
draw_mesh=draw_mesh,
)
if len(result) == 0:
# Give up
context.services.logger.warning(
"FaceTools --> No face detected in chunked input image. Passing through original image."
)
all_faces = prepare_faces_list(result)
return all_faces
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.0.1")
class FaceOffInvocation(BaseInvocation):
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
image: ImageField = InputField(description="Image for face detection")
face_id: int = InputField(
default=0,
ge=0,
description="The face ID to process, numbered from 0. Multiple faces not supported. Find a face's ID with FaceIdentifier node.",
)
minimum_confidence: float = InputField(
default=0.5, description="Minimum confidence for face detection (lower if detection is failing)"
)
x_offset: float = InputField(default=0.0, description="X-axis offset of the mask")
y_offset: float = InputField(default=0.0, description="Y-axis offset of the mask")
padding: int = InputField(default=0, description="All-axis padding around the mask in pixels")
chunk: bool = InputField(
default=False,
description="Whether to bypass full image face detection and default to image chunking. Chunking will occur if no faces are found in the full image.",
)
def faceoff(self, context: InvocationContext, image: ImageType) -> Optional[ExtractFaceData]:
all_faces = get_faces_list(
context=context,
image=image,
should_chunk=self.chunk,
minimum_confidence=self.minimum_confidence,
x_offset=self.x_offset,
y_offset=self.y_offset,
draw_mesh=True,
)
if len(all_faces) == 0:
context.services.logger.warning("FaceOff --> No faces detected. Passing through original image.")
return None
if self.face_id > len(all_faces) - 1:
context.services.logger.warning(
f"FaceOff --> Face ID {self.face_id} is outside of the number of faces detected ({len(all_faces)}). Passing through original image."
)
return None
face_data = extract_face(context=context, image=image, face=all_faces[self.face_id], padding=self.padding)
# Convert the input image to RGBA mode to ensure it has an alpha channel.
face_data["bounded_image"] = face_data["bounded_image"].convert("RGBA")
return face_data
def invoke(self, context: InvocationContext) -> FaceOffOutput:
image = context.services.images.get_pil_image(self.image.image_name)
result = self.faceoff(context=context, image=image)
if result is None:
result_image = image
result_mask = create_white_image(*image.size)
x = 0
y = 0
else:
result_image = result["bounded_image"]
result_mask = result["bounded_mask"]
x = result["x_min"]
y = result["y_min"]
image_dto = context.services.images.create(
image=result_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
mask_dto = context.services.images.create(
image=result_mask,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.MASK,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
output = FaceOffOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
mask=ImageField(image_name=mask_dto.image_name),
x=x,
y=y,
)
return output
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.0.1")
class FaceMaskInvocation(BaseInvocation):
"""Face mask creation using mediapipe face detection"""
image: ImageField = InputField(description="Image to face detect")
face_ids: str = InputField(
default="",
description="Comma-separated list of face ids to mask eg '0,2,7'. Numbered from 0. Leave empty to mask all. Find face IDs with FaceIdentifier node.",
)
minimum_confidence: float = InputField(
default=0.5, description="Minimum confidence for face detection (lower if detection is failing)"
)
x_offset: float = InputField(default=0.0, description="Offset for the X-axis of the face mask")
y_offset: float = InputField(default=0.0, description="Offset for the Y-axis of the face mask")
chunk: bool = InputField(
default=False,
description="Whether to bypass full image face detection and default to image chunking. Chunking will occur if no faces are found in the full image.",
)
invert_mask: bool = InputField(default=False, description="Toggle to invert the mask")
@validator("face_ids")
def validate_comma_separated_ints(cls, v) -> str:
comma_separated_ints_regex = re.compile(r"^\d*(,\d+)*$")
if comma_separated_ints_regex.match(v) is None:
raise ValueError('Face IDs must be a comma-separated list of integers (e.g. "1,2,3")')
return v
def facemask(self, context: InvocationContext, image: ImageType) -> FaceMaskResult:
all_faces = get_faces_list(
context=context,
image=image,
should_chunk=self.chunk,
minimum_confidence=self.minimum_confidence,
x_offset=self.x_offset,
y_offset=self.y_offset,
draw_mesh=True,
)
mask_pil = create_white_image(*image.size)
id_range = list(range(0, len(all_faces)))
ids_to_extract = id_range
if self.face_ids != "":
parsed_face_ids = [int(id) for id in self.face_ids.split(",")]
# get requested face_ids that are in range
intersected_face_ids = set(parsed_face_ids) & set(id_range)
if len(intersected_face_ids) == 0:
id_range_str = ",".join([str(id) for id in id_range])
context.services.logger.warning(
f"Face IDs must be in range of detected faces - requested {self.face_ids}, detected {id_range_str}. Passing through original image."
)
return FaceMaskResult(
image=image, # original image
mask=mask_pil, # white mask
)
ids_to_extract = list(intersected_face_ids)
for face_id in ids_to_extract:
face_data = extract_face(context=context, image=image, face=all_faces[face_id], padding=0)
face_mask_pil = face_data["bounded_mask"]
x_min = face_data["x_min"]
y_min = face_data["y_min"]
x_max = face_data["x_max"]
y_max = face_data["y_max"]
mask_pil.paste(
create_black_image(x_max - x_min, y_max - y_min),
box=(x_min, y_min),
mask=ImageOps.invert(face_mask_pil),
)
if self.invert_mask:
mask_pil = ImageOps.invert(mask_pil)
# Create an RGBA image with transparency
image = image.convert("RGBA")
return FaceMaskResult(
image=image,
mask=mask_pil,
)
def invoke(self, context: InvocationContext) -> FaceMaskOutput:
image = context.services.images.get_pil_image(self.image.image_name)
result = self.facemask(context=context, image=image)
image_dto = context.services.images.create(
image=result["image"],
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
mask_dto = context.services.images.create(
image=result["mask"],
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.MASK,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
output = FaceMaskOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
mask=ImageField(image_name=mask_dto.image_name),
)
return output
@invocation(
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.0.1"
)
class FaceIdentifierInvocation(BaseInvocation):
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""
image: ImageField = InputField(description="Image to face detect")
minimum_confidence: float = InputField(
default=0.5, description="Minimum confidence for face detection (lower if detection is failing)"
)
chunk: bool = InputField(
default=False,
description="Whether to bypass full image face detection and default to image chunking. Chunking will occur if no faces are found in the full image.",
)
def faceidentifier(self, context: InvocationContext, image: ImageType) -> ImageType:
image = image.copy()
all_faces = get_faces_list(
context=context,
image=image,
should_chunk=self.chunk,
minimum_confidence=self.minimum_confidence,
x_offset=0,
y_offset=0,
draw_mesh=False,
)
# Note - font may be found either in the repo if running an editable install, or in the venv if running a package install
font_path = [x for x in [Path(y, "inter/Inter-Regular.ttf") for y in font_assets.__path__] if x.exists()]
font = ImageFont.truetype(font_path[0].as_posix(), FONT_SIZE)
# Paste face IDs on the output image
draw = ImageDraw.Draw(image)
for face in all_faces:
x_coord = face["x_center"]
y_coord = face["y_center"]
text = str(face["face_id"])
# get bbox of the text so we can center the id on the face
_, _, bbox_w, bbox_h = draw.textbbox(xy=(0, 0), text=text, font=font, stroke_width=FONT_STROKE_WIDTH)
x = x_coord - bbox_w / 2
y = y_coord - bbox_h / 2
draw.text(
xy=(x, y),
text=str(text),
fill=(255, 255, 255, 255),
font=font,
stroke_width=FONT_STROKE_WIDTH,
stroke_fill=(0, 0, 0, 255),
)
# Create an RGBA image with transparency
image = image.convert("RGBA")
return image
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
result_image = self.faceidentifier(context=context, image=image)
image_dto = context.services.images.create(
image=result_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@@ -8,12 +8,12 @@ import numpy
from PIL import Image, ImageChops, ImageFilter, ImageOps
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.invocations.primitives import BoardField, ColorField, ImageField, ImageOutput
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, FieldDescriptions, Input, InputField, InvocationContext, invocation
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0")
@@ -965,44 +965,3 @@ class ImageChannelMultiplyInvocation(BaseInvocation):
width=image_dto.width,
height=image_dto.height,
)
@invocation(
"save_image",
title="Save Image",
tags=["primitives", "image"],
category="primitives",
version="1.0.1",
use_cache=False,
)
class SaveImageInvocation(BaseInvocation):
"""Saves an image. Unlike an image primitive, this invocation stores a copy of the image."""
image: ImageField = InputField(description=FieldDescriptions.image)
board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
metadata: CoreMetadata = InputField(
default=None,
description=FieldDescriptions.core_metadata,
ui_hidden=True,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
image_dto = context.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
board_id=self.board.board_id if self.board else None,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
workflow=self.workflow,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@@ -269,7 +269,7 @@ class LaMaInfillInvocation(BaseInvocation):
)
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint")
class CV2InfillInvocation(BaseInvocation):
"""Infills transparent areas of an image using OpenCV Inpainting"""

View File

@@ -58,7 +58,9 @@ class IPAdapterInvocation(BaseInvocation):
# Inputs
image: ImageField = InputField(description="The IP-Adapter image prompt.")
ip_adapter_model: IPAdapterModelField = InputField(
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
description="The IP-Adapter model.",
title="IP-Adapter Model",
input=Input.Direct,
)
# weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float)

View File

@@ -1,14 +1,12 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from contextlib import ExitStack
from functools import singledispatchmethod
from typing import List, Literal, Optional, Union
import einops
import numpy as np
import torch
import torchvision.transforms as T
from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import UNet2DConditionModel
from diffusers.models.attention_processor import (
@@ -210,7 +208,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection, ui_order=3)
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
cfg_scale: Union[float, List[float]] = InputField(
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, title="CFG Scale"
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type=UIType.Float, title="CFG Scale"
)
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
@@ -220,6 +218,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2)
control: Union[ControlField, list[ControlField]] = InputField(
default=None,
description=FieldDescriptions.control,
input=Input.Connection,
ui_order=5,
)
@@ -546,7 +545,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
with (
ExitStack() as exit_stack,
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),
set_seamless(unet_info.context.model, self.unet.seamless_axes),
set_seamless(unet_info.context.model, **self.unet.seamless.dict()),
unet_info as unet,
):
latents = latents.to(device=unet.device, dtype=unet.dtype)
@@ -649,7 +648,7 @@ class LatentsToImageInvocation(BaseInvocation):
context=context,
)
with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
with set_seamless(vae_info.context.model, **self.vae.seamless.dict()), vae_info as vae:
latents = latents.to(vae.device)
if self.fp32:
vae.to(dtype=torch.float32)
@@ -858,7 +857,8 @@ class ImageToLatentsInvocation(BaseInvocation):
# non_noised_latents_from_image
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
image_tensor_dist = vae.encode(image_tensor).latent_dist
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
latents = vae.config.scaling_factor * latents
latents = latents.to(dtype=orig_dtype)
@@ -885,18 +885,6 @@ class ImageToLatentsInvocation(BaseInvocation):
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents, seed=None)
@singledispatchmethod
@staticmethod
def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
image_tensor_dist = vae.encode(image_tensor).latent_dist
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
return latents
@_encode_to_tensor.register
@staticmethod
def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
return vae.encode(image_tensor).latents
@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents", version="1.0.0")
class BlendLatentsInvocation(BaseInvocation):

View File

@@ -54,38 +54,17 @@ class DivideInvocation(BaseInvocation):
return IntegerOutput(value=int(self.a / self.b))
@invocation(
"rand_int",
title="Random Integer",
tags=["math", "random"],
category="math",
version="1.0.0",
use_cache=False,
)
@invocation("rand_int", title="Random Integer", tags=["math", "random"], category="math", version="1.0.0")
class RandomIntInvocation(BaseInvocation):
"""Outputs a single random integer."""
low: int = InputField(default=0, description=FieldDescriptions.inclusive_low)
high: int = InputField(default=np.iinfo(np.int32).max, description=FieldDescriptions.exclusive_high)
low: int = InputField(default=0, description="The inclusive low value")
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntegerOutput(value=np.random.randint(self.low, self.high))
@invocation("rand_float", title="Random Float", tags=["math", "float", "random"], category="math", version="1.0.0")
class RandomFloatInvocation(BaseInvocation):
"""Outputs a single random float"""
low: float = InputField(default=0.0, description=FieldDescriptions.inclusive_low)
high: float = InputField(default=1.0, description=FieldDescriptions.exclusive_high)
decimals: int = InputField(default=2, description=FieldDescriptions.decimal_places)
def invoke(self, context: InvocationContext) -> FloatOutput:
random_float = np.random.uniform(self.low, self.high)
rounded_float = round(random_float, self.decimals)
return FloatOutput(value=rounded_float)
@invocation(
"float_to_int",
title="Float To Integer",

View File

@@ -12,9 +12,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output,
)
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from ...version import __version__
@@ -27,18 +25,6 @@ class LoRAMetadataField(BaseModelExcludeNull):
weight: float = Field(description="The weight of the LoRA model")
class IPAdapterMetadataField(BaseModelExcludeNull):
image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
weight: float = Field(description="The weight of the IP-Adapter model")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
)
class CoreMetadata(BaseModelExcludeNull):
"""Core generation metadata for an image generated in InvokeAI."""
@@ -56,13 +42,11 @@ class CoreMetadata(BaseModelExcludeNull):
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
steps: int = Field(description="The number of steps used for inference")
scheduler: str = Field(description="The scheduler used for inference")
clip_skip: Optional[int] = Field(
default=None,
clip_skip: int = Field(
description="The number of skipped CLIP layers",
)
model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
ipAdapters: list[IPAdapterMetadataField] = Field(description="The IP Adapters used for inference")
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
vae: Optional[VAEModelField] = Field(
default=None,
@@ -132,13 +116,11 @@ class MetadataAccumulatorInvocation(BaseInvocation):
cfg_scale: float = InputField(description="The classifier-free guidance scale parameter")
steps: int = InputField(description="The number of steps used for inference")
scheduler: str = InputField(description="The scheduler used for inference")
clip_skip: Optional[int] = Field(
default=None,
clip_skip: int = InputField(
description="The number of skipped CLIP layers",
)
model: MainModelField = InputField(description="The main model used for inference")
controlnets: list[ControlField] = InputField(description="The ControlNets used for inference")
ipAdapters: list[IPAdapterMetadataField] = InputField(description="The IP Adapters used for inference")
loras: list[LoRAMetadataField] = InputField(description="The LoRAs used for inference")
strength: Optional[float] = InputField(
default=None,

View File

@@ -18,6 +18,13 @@ from .baseinvocation import (
)
class SeamlessSettings(BaseModel):
axes: List[str] = Field(description="Axes('x' and 'y') to which apply seamless")
skipped_layers: int = Field(description="How much down layers skip when applying seamless")
skip_second_resnet: bool = Field(description="Skip or not second resnet in down blocks when applying seamless")
skip_conv2: bool = Field(description="Skip or not conv2 in down blocks when applying seamless")
class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load submodel")
base_model: BaseModelType = Field(description="Base model")
@@ -33,7 +40,7 @@ class UNetField(BaseModel):
unet: ModelInfo = Field(description="Info to load unet submodel")
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
seamless: Optional[SeamlessSettings] = Field(default=None, description="Seamless settings applied to model")
class ClipField(BaseModel):
@@ -46,7 +53,7 @@ class ClipField(BaseModel):
class VaeField(BaseModel):
# TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
seamless: Optional[SeamlessSettings] = Field(default=None, description="Seamless settings applied to model")
@invocation_output("model_loader_output")
@@ -388,6 +395,11 @@ class SeamlessModeInvocation(BaseInvocation):
)
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
skipped_layers: int = InputField(default=0, input=Input.Any, description="How much model's down layers to skip")
skip_second_resnet: bool = InputField(
default=True, input=Input.Any, description="Skip or not second resnet in down layers"
)
skip_conv2: bool = InputField(default=True, input=Input.Any, description="Skip or not conv2 in down layers")
def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
# Conditionally append 'x' and 'y' based on seamless_x and seamless_y
@@ -402,8 +414,18 @@ class SeamlessModeInvocation(BaseInvocation):
seamless_axes_list.append("y")
if unet is not None:
unet.seamless_axes = seamless_axes_list
unet.seamless = SeamlessSettings(
axes=seamless_axes_list,
skipped_layers=self.skipped_layers,
skip_second_resnet=self.skip_second_resnet,
skip_conv2=self.skip_conv2,
)
if vae is not None:
vae.seamless_axes = seamless_axes_list
vae.seamless = SeamlessSettings(
axes=seamless_axes_list,
skipped_layers=self.skipped_layers,
skip_second_resnet=self.skip_second_resnet,
skip_conv2=self.skip_conv2,
)
return SeamlessModeOutput(unet=unet, vae=vae)

View File

@@ -166,6 +166,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
default=7.5,
ge=1,
description=FieldDescriptions.cfg_scale,
ui_type=UIType.Float,
)
scheduler: SAMPLER_NAME_VALUES = InputField(
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler
@@ -178,6 +179,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
default=None,
description=FieldDescriptions.control,
ui_type=UIType.Control,
)
# seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")

View File

@@ -226,12 +226,6 @@ class ImageField(BaseModel):
image_name: str = Field(description="The name of the image")
class BoardField(BaseModel):
"""A board primitive field"""
board_id: str = Field(description="The id of the board")
@invocation_output("image_output")
class ImageOutput(BaseInvocationOutput):
"""Base class for nodes that output a single image"""

View File

@@ -10,14 +10,7 @@ from invokeai.app.invocations.primitives import StringCollectionOutput
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation
@invocation(
"dynamic_prompt",
title="Dynamic Prompt",
tags=["prompt", "collection"],
category="prompt",
version="1.0.0",
use_cache=False,
)
@invocation("dynamic_prompt", title="Dynamic Prompt", tags=["prompt", "collection"], category="prompt", version="1.0.0")
class DynamicPromptInvocation(BaseInvocation):
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""

View File

@@ -53,20 +53,24 @@ class BoardImageRecordStorageBase(ABC):
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
def __init__(self, filename: str) -> None:
super().__init__()
self._conn = conn
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = lock
self._lock = threading.Lock()
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:

View File

@@ -1,5 +1,6 @@
import sqlite3
import threading
import uuid
from abc import ABC, abstractmethod
from typing import Optional, Union, cast
@@ -7,7 +8,6 @@ from pydantic import BaseModel, Extra, Field
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.board_record import BoardRecord, deserialize_board_record
from invokeai.app.util.misc import uuid_string
class BoardChanges(BaseModel, extra=Extra.forbid):
@@ -87,20 +87,24 @@ class BoardRecordStorageBase(ABC):
class SqliteBoardRecordStorage(BoardRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
def __init__(self, filename: str) -> None:
super().__init__()
self._conn = conn
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = lock
self._lock = threading.Lock()
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
@@ -170,7 +174,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
board_name: str,
) -> BoardRecord:
try:
board_id = uuid_string()
board_id = str(uuid.uuid4())
self._lock.acquire()
self._cursor.execute(
"""--sql

View File

@@ -16,7 +16,7 @@ import pydoc
import sys
from argparse import ArgumentParser
from pathlib import Path
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
from typing import ClassVar, Dict, List, Literal, Union, get_args, get_origin, get_type_hints
from omegaconf import DictConfig, ListConfig, OmegaConf
from pydantic import BaseSettings
@@ -39,10 +39,10 @@ class InvokeAISettings(BaseSettings):
read from an omegaconf .yaml file.
"""
initconf: ClassVar[Optional[DictConfig]] = None
initconf: ClassVar[DictConfig] = None
argparse_groups: ClassVar[Dict] = {}
def parse_args(self, argv: Optional[list] = sys.argv[1:]):
def parse_args(self, argv: list = sys.argv[1:]):
parser = self.get_parser()
opt, unknown_opts = parser.parse_known_args(argv)
if len(unknown_opts) > 0:
@@ -83,8 +83,7 @@ class InvokeAISettings(BaseSettings):
else:
settings_stanza = "Uncategorized"
env_prefix = getattr(cls.Config, "env_prefix", None)
env_prefix = env_prefix if env_prefix is not None else settings_stanza.upper()
env_prefix = cls.Config.env_prefix if hasattr(cls.Config, "env_prefix") else settings_stanza.upper()
initconf = (
cls.initconf.get(settings_stanza)
@@ -117,8 +116,8 @@ class InvokeAISettings(BaseSettings):
field.default = current_default
@classmethod
def cmd_name(cls, command_field: str = "type") -> str:
hints = get_type_hints(cls)
def cmd_name(self, command_field: str = "type") -> str:
hints = get_type_hints(self)
if command_field in hints:
return get_args(hints[command_field])[0]
else:
@@ -134,12 +133,16 @@ class InvokeAISettings(BaseSettings):
return parser
@classmethod
def _excluded(cls) -> List[str]:
def add_subparser(cls, parser: argparse.ArgumentParser):
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
@classmethod
def _excluded(self) -> List[str]:
# internal fields that shouldn't be exposed as command line options
return ["type", "initconf"]
@classmethod
def _excluded_from_yaml(cls) -> List[str]:
def _excluded_from_yaml(self) -> List[str]:
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
return [
"type",

View File

@@ -194,8 +194,8 @@ class InvokeAIAppConfig(InvokeAISettings):
setting environment variables INVOKEAI_<setting>.
"""
singleton_config: ClassVar[Optional[InvokeAIAppConfig]] = None
singleton_init: ClassVar[Optional[Dict]] = None
singleton_config: ClassVar[InvokeAIAppConfig] = None
singleton_init: ClassVar[Dict] = None
# fmt: off
type: Literal["InvokeAI"] = "InvokeAI"
@@ -234,35 +234,29 @@ class InvokeAIAppConfig(InvokeAISettings):
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
log_sql : bool = Field(default=False, description="Log SQL queries", category="Logging")
dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed.", category="Development")
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
# CACHE
ram : Union[float, Literal["auto"]] = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", )
ram : Union[float, Literal["auto"]] = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", )
vram : Union[float, Literal["auto"]] = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number or 'auto')", category="Model Cache", )
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", )
# DEVICE
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", category="Device", )
precision : Literal["auto", "float16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", category="Device", )
device : Literal[tuple(["auto", "cpu", "cuda", "cuda:1", "mps"])] = Field(default="auto", description="Generation device", category="Device", )
precision: Literal[tuple(["auto", "float16", "float32", "autocast"])] = Field(default="auto", description="Floating point precision", category="Device", )
# GENERATION
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category="Generation", )
attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", category="Generation", )
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
attention_type : Literal[tuple(["auto", "normal", "xformers", "sliced", "torch-sdp"])] = Field(default="auto", description="Attention type", category="Generation", )
attention_slice_size: Literal[tuple(["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8])] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
# QUEUE
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", category="Queue", )
# NODES
allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", category="Nodes")
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", category="Nodes")
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", category="Nodes", )
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
@@ -277,9 +271,8 @@ class InvokeAIAppConfig(InvokeAISettings):
class Config:
validate_assignment = True
env_prefix = "INVOKEAI"
def parse_args(self, argv: Optional[list[str]] = None, conf: Optional[DictConfig] = None, clobber=False):
def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False):
"""
Update settings with contents of init file, environment, and
command-line settings.
@@ -290,16 +283,12 @@ class InvokeAIAppConfig(InvokeAISettings):
# Set the runtime root directory. We parse command-line switches here
# in order to pick up the --root_dir option.
super().parse_args(argv)
loaded_conf = None
if conf is None:
try:
loaded_conf = OmegaConf.load(self.root_dir / INIT_FILE)
conf = OmegaConf.load(self.root_dir / INIT_FILE)
except Exception:
pass
if isinstance(loaded_conf, DictConfig):
InvokeAISettings.initconf = loaded_conf
else:
InvokeAISettings.initconf = conf
InvokeAISettings.initconf = conf
# parse args again in order to pick up settings in configuration file
super().parse_args(argv)
@@ -387,6 +376,13 @@ class InvokeAIAppConfig(InvokeAISettings):
"""
return self._resolve(self.models_dir)
@property
def autoconvert_path(self) -> Path:
"""
Path to the directory containing models to be imported automatically at startup.
"""
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
# the following methods support legacy calls leftover from the Globals era
@property
def full_precision(self) -> bool:
@@ -409,11 +405,11 @@ class InvokeAIAppConfig(InvokeAISettings):
return True
@property
def ram_cache_size(self) -> Union[Literal["auto"], float]:
def ram_cache_size(self) -> float:
return self.max_cache_size or self.ram
@property
def vram_cache_size(self) -> Union[Literal["auto"], float]:
def vram_cache_size(self) -> float:
return self.max_vram_cache_size or self.vram
@property

View File

@@ -10,58 +10,57 @@ default_text_to_image_graph_id = "539b2af5-2b4d-4d8c-8071-e54a3255fc74"
def create_text_to_image() -> LibraryGraph:
graph = Graph(
nodes={
"width": IntegerInvocation(id="width", value=512),
"height": IntegerInvocation(id="height", value=512),
"seed": IntegerInvocation(id="seed", value=-1),
"3": NoiseInvocation(id="3"),
"4": CompelInvocation(id="4"),
"5": CompelInvocation(id="5"),
"6": DenoiseLatentsInvocation(id="6"),
"7": LatentsToImageInvocation(id="7"),
"8": ImageNSFWBlurInvocation(id="8"),
},
edges=[
Edge(
source=EdgeConnection(node_id="width", field="value"),
destination=EdgeConnection(node_id="3", field="width"),
),
Edge(
source=EdgeConnection(node_id="height", field="value"),
destination=EdgeConnection(node_id="3", field="height"),
),
Edge(
source=EdgeConnection(node_id="seed", field="value"),
destination=EdgeConnection(node_id="3", field="seed"),
),
Edge(
source=EdgeConnection(node_id="3", field="noise"),
destination=EdgeConnection(node_id="6", field="noise"),
),
Edge(
source=EdgeConnection(node_id="6", field="latents"),
destination=EdgeConnection(node_id="7", field="latents"),
),
Edge(
source=EdgeConnection(node_id="4", field="conditioning"),
destination=EdgeConnection(node_id="6", field="positive_conditioning"),
),
Edge(
source=EdgeConnection(node_id="5", field="conditioning"),
destination=EdgeConnection(node_id="6", field="negative_conditioning"),
),
Edge(
source=EdgeConnection(node_id="7", field="image"),
destination=EdgeConnection(node_id="8", field="image"),
),
],
)
return LibraryGraph(
id=default_text_to_image_graph_id,
name="t2i",
description="Converts text to an image",
graph=graph,
graph=Graph(
nodes={
"width": IntegerInvocation(id="width", value=512),
"height": IntegerInvocation(id="height", value=512),
"seed": IntegerInvocation(id="seed", value=-1),
"3": NoiseInvocation(id="3"),
"4": CompelInvocation(id="4"),
"5": CompelInvocation(id="5"),
"6": DenoiseLatentsInvocation(id="6"),
"7": LatentsToImageInvocation(id="7"),
"8": ImageNSFWBlurInvocation(id="8"),
},
edges=[
Edge(
source=EdgeConnection(node_id="width", field="value"),
destination=EdgeConnection(node_id="3", field="width"),
),
Edge(
source=EdgeConnection(node_id="height", field="value"),
destination=EdgeConnection(node_id="3", field="height"),
),
Edge(
source=EdgeConnection(node_id="seed", field="value"),
destination=EdgeConnection(node_id="3", field="seed"),
),
Edge(
source=EdgeConnection(node_id="3", field="noise"),
destination=EdgeConnection(node_id="6", field="noise"),
),
Edge(
source=EdgeConnection(node_id="6", field="latents"),
destination=EdgeConnection(node_id="7", field="latents"),
),
Edge(
source=EdgeConnection(node_id="4", field="conditioning"),
destination=EdgeConnection(node_id="6", field="positive_conditioning"),
),
Edge(
source=EdgeConnection(node_id="5", field="conditioning"),
destination=EdgeConnection(node_id="6", field="negative_conditioning"),
),
Edge(
source=EdgeConnection(node_id="7", field="image"),
destination=EdgeConnection(node_id="8", field="image"),
),
],
),
exposed_inputs=[
ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"),
ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"),

View File

@@ -4,23 +4,21 @@ from typing import Any, Optional
from invokeai.app.models.image import ProgressImage
from invokeai.app.services.model_manager_service import BaseModelType, ModelInfo, ModelType, SubModelType
from invokeai.app.services.session_queue.session_queue_common import EnqueueBatchResult, SessionQueueItem
from invokeai.app.util.misc import get_timestamp
class EventServiceBase:
queue_event: str = "queue_event"
session_event: str = "session_event"
"""Basic event bus, to have an empty stand-in when not needed"""
def dispatch(self, event_name: str, payload: Any) -> None:
pass
def __emit_queue_event(self, event_name: str, payload: dict) -> None:
"""Queue events are emitted to a room with queue_id as the room name"""
def __emit_session_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.queue_event,
event_name=EventServiceBase.session_event,
payload=dict(event=event_name, data=payload),
)
@@ -28,9 +26,6 @@ class EventServiceBase:
# This will make them easier to integrate until we find a schema generator.
def emit_generator_progress(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
@@ -40,14 +35,11 @@ class EventServiceBase:
total_steps: int,
) -> None:
"""Emitted when there is generation progress"""
self.__emit_queue_event(
self.__emit_session_event(
event_name="generator_progress",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node_id=node.get("id"),
node=node,
source_node_id=source_node_id,
progress_image=progress_image.dict() if progress_image is not None else None,
step=step,
@@ -58,21 +50,15 @@ class EventServiceBase:
def emit_invocation_complete(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
result: dict,
node: dict,
source_node_id: str,
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_queue_event(
self.__emit_session_event(
event_name="invocation_complete",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
@@ -82,9 +68,6 @@ class EventServiceBase:
def emit_invocation_error(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
@@ -92,12 +75,9 @@ class EventServiceBase:
error: str,
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_queue_event(
self.__emit_session_event(
event_name="invocation_error",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
@@ -106,47 +86,28 @@ class EventServiceBase:
),
)
def emit_invocation_started(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
) -> None:
def emit_invocation_started(self, graph_execution_state_id: str, node: dict, source_node_id: str) -> None:
"""Emitted when an invocation has started"""
self.__emit_queue_event(
self.__emit_session_event(
event_name="invocation_started",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
),
)
def emit_graph_execution_complete(
self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str
) -> None:
def emit_graph_execution_complete(self, graph_execution_state_id: str) -> None:
"""Emitted when a session has completed all invocations"""
self.__emit_queue_event(
self.__emit_session_event(
event_name="graph_execution_state_complete",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
),
)
def emit_model_load_started(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_name: str,
base_model: BaseModelType,
@@ -154,12 +115,9 @@ class EventServiceBase:
submodel: SubModelType,
) -> None:
"""Emitted when a model is requested"""
self.__emit_queue_event(
self.__emit_session_event(
event_name="model_load_started",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
model_name=model_name,
base_model=base_model,
@@ -170,9 +128,6 @@ class EventServiceBase:
def emit_model_load_completed(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_name: str,
base_model: BaseModelType,
@@ -181,12 +136,9 @@ class EventServiceBase:
model_info: ModelInfo,
) -> None:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_queue_event(
self.__emit_session_event(
event_name="model_load_completed",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
model_name=model_name,
base_model=base_model,
@@ -200,20 +152,14 @@ class EventServiceBase:
def emit_session_retrieval_error(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
error_type: str,
error: str,
) -> None:
"""Emitted when session retrieval fails"""
self.__emit_queue_event(
self.__emit_session_event(
event_name="session_retrieval_error",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
error_type=error_type,
error=error,
@@ -222,78 +168,18 @@ class EventServiceBase:
def emit_invocation_retrieval_error(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node_id: str,
error_type: str,
error: str,
) -> None:
"""Emitted when invocation retrieval fails"""
self.__emit_queue_event(
self.__emit_session_event(
event_name="invocation_retrieval_error",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node_id=node_id,
error_type=error_type,
error=error,
),
)
def emit_session_canceled(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
) -> None:
"""Emitted when a session is canceled"""
self.__emit_queue_event(
event_name="session_canceled",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
),
)
def emit_queue_item_status_changed(self, session_queue_item: SessionQueueItem) -> None:
"""Emitted when a queue item's status changes"""
self.__emit_queue_event(
event_name="queue_item_status_changed",
payload=dict(
queue_id=session_queue_item.queue_id,
queue_item_id=session_queue_item.item_id,
status=session_queue_item.status,
batch_id=session_queue_item.batch_id,
session_id=session_queue_item.session_id,
error=session_queue_item.error,
created_at=str(session_queue_item.created_at) if session_queue_item.created_at else None,
updated_at=str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
),
)
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
"""Emitted when a batch is enqueued"""
self.__emit_queue_event(
event_name="batch_enqueued",
payload=dict(
queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id,
enqueued=enqueue_result.enqueued,
),
)
def emit_queue_cleared(self, queue_id: str) -> None:
"""Emitted when the queue is cleared"""
self.__emit_queue_event(
event_name="queue_cleared",
payload=dict(queue_id=queue_id),
)

View File

@@ -2,14 +2,13 @@
import copy
import itertools
from typing import Annotated, Any, Optional, Union, cast, get_args, get_origin, get_type_hints
import uuid
from typing import Annotated, Any, Optional, Union, get_args, get_origin, get_type_hints
import networkx as nx
from pydantic import BaseModel, root_validator, validator
from pydantic.fields import Field
from invokeai.app.util.misc import uuid_string
# Importing * is bad karma but needed here for node detection
from ..invocations import * # noqa: F401 F403
from ..invocations.baseinvocation import (
@@ -117,10 +116,6 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
if from_type is int and to_type is float:
return True
# allow int|float -> str, pydantic will cast for us
if (from_type is int or from_type is float) and to_type is str:
return True
# if not issubclass(from_type, to_type):
if not is_union_subtype(from_type, to_type):
return False
@@ -142,31 +137,19 @@ def are_connections_compatible(
return are_connection_types_compatible(from_node_field, to_node_field)
class NodeAlreadyInGraphError(ValueError):
class NodeAlreadyInGraphError(Exception):
pass
class InvalidEdgeError(ValueError):
class InvalidEdgeError(Exception):
pass
class NodeNotFoundError(ValueError):
class NodeNotFoundError(Exception):
pass
class NodeAlreadyExecutedError(ValueError):
pass
class DuplicateNodeIdError(ValueError):
pass
class NodeFieldNotFoundError(ValueError):
pass
class NodeIdMismatchError(ValueError):
class NodeAlreadyExecutedError(Exception):
pass
@@ -244,7 +227,7 @@ InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]
class Graph(BaseModel):
id: str = Field(description="The id of this graph", default_factory=uuid_string)
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
description="The nodes in this graph", default_factory=dict
@@ -254,59 +237,6 @@ class Graph(BaseModel):
default_factory=list,
)
@root_validator
def validate_nodes_and_edges(cls, values):
"""Validates that all edges match nodes in the graph"""
nodes = cast(Optional[dict[str, BaseInvocation]], values.get("nodes"))
edges = cast(Optional[list[Edge]], values.get("edges"))
if nodes is not None:
# Validate that all node ids are unique
node_ids = [n.id for n in nodes.values()]
duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
if duplicate_node_ids:
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
# Validate that all node ids match the keys in the nodes dict
for k, v in nodes.items():
if k != v.id:
raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")
if edges is not None and nodes is not None:
# Validate that all edges match nodes in the graph
node_ids = set([e.source.node_id for e in edges] + [e.destination.node_id for e in edges])
missing_node_ids = [node_id for node_id in node_ids if node_id not in nodes]
if missing_node_ids:
raise NodeNotFoundError(
f"All edges must reference nodes in the graph, missing nodes: {missing_node_ids}"
)
# Validate that all edge fields match node fields in the graph
for edge in edges:
source_node = nodes.get(edge.source.node_id, None)
if source_node is None:
raise NodeFieldNotFoundError(f"Edge source node {edge.source.node_id} does not exist in the graph")
destination_node = nodes.get(edge.destination.node_id, None)
if destination_node is None:
raise NodeFieldNotFoundError(
f"Edge destination node {edge.destination.node_id} does not exist in the graph"
)
# output fields are not on the node object directly, they are on the output type
if edge.source.field not in source_node.get_output_type().__fields__:
raise NodeFieldNotFoundError(
f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
)
# input fields are on the node
if edge.destination.field not in destination_node.__fields__:
raise NodeFieldNotFoundError(
f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
)
return values
def add_node(self, node: BaseInvocation) -> None:
"""Adds a node to a graph
@@ -767,7 +697,8 @@ class Graph(BaseModel):
class GraphExecutionState(BaseModel):
"""Tracks the state of a graph execution"""
id: str = Field(description="The id of the execution state", default_factory=uuid_string)
id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__())
# TODO: Store a reference to the graph instead of the actual graph?
graph: Graph = Field(description="The graph being executed")
@@ -916,7 +847,7 @@ class GraphExecutionState(BaseModel):
new_node = copy.deepcopy(node)
# Create the node id (use a random uuid)
new_node.id = uuid_string()
new_node.id = str(uuid.uuid4())
# Set the iteration index for iteration invocations
if isinstance(new_node, IterateInvocation):
@@ -1151,7 +1082,7 @@ class ExposedNodeOutput(BaseModel):
class LibraryGraph(BaseModel):
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid_string)
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
graph: Graph = Field(description="The graph")
name: str = Field(description="The name of the graph")
description: str = Field(description="The description of the graph")

View File

@@ -148,20 +148,24 @@ class ImageRecordStorageBase(ABC):
class SqliteImageRecordStorage(ImageRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
def __init__(self, filename: str) -> None:
super().__init__()
self._conn = conn
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = lock
self._lock = threading.Lock()
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
@@ -584,7 +588,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
ORDER BY images.starred DESC, images.created_at DESC
ORDER BY images.created_at DESC
LIMIT 1;
""",
(board_id,),

View File

@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from logging import Logger
from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Optional
from PIL.Image import Image as PILImageType
@@ -38,29 +38,6 @@ if TYPE_CHECKING:
class ImageServiceABC(ABC):
"""High-level service for image management."""
_on_changed_callbacks: list[Callable[[ImageDTO], None]]
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
self._on_changed_callbacks = list()
self._on_deleted_callbacks = list()
def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
"""Register a callback for when an image is changed"""
self._on_changed_callbacks.append(on_changed)
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
"""Register a callback for when an image is deleted"""
self._on_deleted_callbacks.append(on_deleted)
def _on_changed(self, item: ImageDTO) -> None:
for callback in self._on_changed_callbacks:
callback(item)
def _on_deleted(self, item_id: str) -> None:
for callback in self._on_deleted_callbacks:
callback(item_id)
@abstractmethod
def create(
self,
@@ -184,7 +161,6 @@ class ImageService(ImageServiceABC):
_services: ImageServiceDependencies
def __init__(self, services: ImageServiceDependencies):
super().__init__()
self._services = services
def create(
@@ -241,7 +217,6 @@ class ImageService(ImageServiceABC):
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow)
image_dto = self.get_dto(image_name)
self._on_changed(image_dto)
return image_dto
except ImageRecordSaveException:
self._services.logger.error("Failed to save image record")
@@ -260,9 +235,7 @@ class ImageService(ImageServiceABC):
) -> ImageDTO:
try:
self._services.image_records.update(image_name, changes)
image_dto = self.get_dto(image_name)
self._on_changed(image_dto)
return image_dto
return self.get_dto(image_name)
except ImageRecordSaveException:
self._services.logger.error("Failed to update image record")
raise
@@ -401,7 +374,6 @@ class ImageService(ImageServiceABC):
try:
self._services.image_files.delete(image_name)
self._services.image_records.delete(image_name)
self._on_deleted(image_name)
except ImageRecordDeleteException:
self._services.logger.error("Failed to delete image record")
raise
@@ -418,8 +390,6 @@ class ImageService(ImageServiceABC):
for image_name in image_names:
self._services.image_files.delete(image_name)
self._services.image_records.delete_many(image_names)
for image_name in image_names:
self._on_deleted(image_name)
except ImageRecordDeleteException:
self._services.logger.error("Failed to delete image records")
raise
@@ -436,7 +406,6 @@ class ImageService(ImageServiceABC):
count = len(image_names)
for image_name in image_names:
self._services.image_files.delete(image_name)
self._on_deleted(image_name)
return count
except ImageRecordDeleteException:
self._services.logger.error("Failed to delete image records")

View File

@@ -1,62 +0,0 @@
from abc import ABC, abstractmethod
from typing import Optional, Union
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
class InvocationCacheBase(ABC):
"""
Base class for invocation caches.
When an invocation is executed, it is hashed and its output stored in the cache.
When new invocations are executed, if they are flagged with `use_cache`, they
will attempt to pull their value from the cache before executing.
Implementations should register for the `on_deleted` event of the `images` and `latents`
services, and delete any cached outputs that reference the deleted image or latent.
See the memory implementation for an example.
Implementations should respect the `node_cache_size` configuration value, and skip all
cache logic if the value is set to 0.
"""
@abstractmethod
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
"""Retrieves an invocation output from the cache"""
pass
@abstractmethod
def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
"""Stores an invocation output in the cache"""
pass
@abstractmethod
def delete(self, key: Union[int, str]) -> None:
"""Deletes an invocation output from the cache"""
pass
@abstractmethod
def clear(self) -> None:
"""Clears the cache"""
pass
@abstractmethod
def create_key(self, invocation: BaseInvocation) -> int:
"""Gets the key for the invocation's cache item"""
pass
@abstractmethod
def disable(self) -> None:
"""Disables the cache, overriding the max cache size"""
pass
@abstractmethod
def enable(self) -> None:
"""Enables the cache, letting the the max cache size take effect"""
pass
@abstractmethod
def get_status(self) -> InvocationCacheStatus:
"""Returns the status of the cache"""
pass

View File

@@ -1,9 +0,0 @@
from pydantic import BaseModel, Field
class InvocationCacheStatus(BaseModel):
size: int = Field(description="The current size of the invocation cache")
hits: int = Field(description="The number of cache hits")
misses: int = Field(description="The number of cache misses")
enabled: bool = Field(description="Whether the invocation cache is enabled")
max_size: int = Field(description="The maximum size of the invocation cache")

View File

@@ -1,126 +0,0 @@
from collections import OrderedDict
from dataclasses import dataclass, field
from threading import Lock
from typing import Optional, Union
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
from invokeai.app.services.invoker import Invoker
@dataclass(order=True)
class CachedItem:
invocation_output: BaseInvocationOutput = field(compare=False)
invocation_output_json: str = field(compare=False)
class MemoryInvocationCache(InvocationCacheBase):
_cache: OrderedDict[Union[int, str], CachedItem]
_max_cache_size: int
_disabled: bool
_hits: int
_misses: int
_invoker: Invoker
_lock: Lock
def __init__(self, max_cache_size: int = 0) -> None:
self._cache = OrderedDict()
self._max_cache_size = max_cache_size
self._disabled = False
self._hits = 0
self._misses = 0
self._lock = Lock()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
if self._max_cache_size == 0:
return
self._invoker.services.images.on_deleted(self._delete_by_match)
self._invoker.services.latents.on_deleted(self._delete_by_match)
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
with self._lock:
if self._max_cache_size == 0 or self._disabled:
return None
item = self._cache.get(key, None)
if item is not None:
self._hits += 1
self._cache.move_to_end(key)
return item.invocation_output
self._misses += 1
return None
def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
with self._lock:
if self._max_cache_size == 0 or self._disabled or key in self._cache:
return
# If the cache is full, we need to remove the least used
number_to_delete = len(self._cache) + 1 - self._max_cache_size
self._delete_oldest_access(number_to_delete)
self._cache[key] = CachedItem(invocation_output, invocation_output.json())
def _delete_oldest_access(self, number_to_delete: int) -> None:
number_to_delete = min(number_to_delete, len(self._cache))
for _ in range(number_to_delete):
self._cache.popitem(last=False)
def _delete(self, key: Union[int, str]) -> None:
if self._max_cache_size == 0:
return
if key in self._cache:
del self._cache[key]
def delete(self, key: Union[int, str]) -> None:
with self._lock:
return self._delete(key)
def clear(self, *args, **kwargs) -> None:
with self._lock:
if self._max_cache_size == 0:
return
self._cache.clear()
self._misses = 0
self._hits = 0
@staticmethod
def create_key(invocation: BaseInvocation) -> int:
return hash(invocation.json(exclude={"id"}))
def disable(self) -> None:
with self._lock:
if self._max_cache_size == 0:
return
self._disabled = True
def enable(self) -> None:
with self._lock:
if self._max_cache_size == 0:
return
self._disabled = False
def get_status(self) -> InvocationCacheStatus:
with self._lock:
return InvocationCacheStatus(
hits=self._hits,
misses=self._misses,
enabled=not self._disabled and self._max_cache_size > 0,
size=len(self._cache),
max_size=self._max_cache_size,
)
def _delete_by_match(self, to_match: str) -> None:
with self._lock:
if self._max_cache_size == 0:
return
keys_to_delete = set()
for key, cached_item in self._cache.items():
if to_match in cached_item.invocation_output_json:
keys_to_delete.add(key)
if not keys_to_delete:
return
for key in keys_to_delete:
self._delete(key)
self._invoker.services.logger.debug(
f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}"
)

View File

@@ -11,13 +11,6 @@ from pydantic import BaseModel, Field
class InvocationQueueItem(BaseModel):
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
invocation_id: str = Field(description="The ID of the node being invoked")
session_queue_id: str = Field(description="The ID of the session queue from which this invocation queue item came")
session_queue_item_id: int = Field(
description="The ID of session queue item from which this invocation queue item came"
)
session_queue_batch_id: str = Field(
description="The ID of the session batch from which this invocation queue item came"
)
invoke_all: bool = Field(default=False)
timestamp: float = Field(default_factory=time.time)

View File

@@ -12,15 +12,12 @@ if TYPE_CHECKING:
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
from invokeai.app.services.images import ImageServiceABC
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
from invokeai.app.services.invocation_queue import InvocationQueueABC
from invokeai.app.services.invocation_stats import InvocationStatsServiceBase
from invokeai.app.services.invoker import InvocationProcessorABC
from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
class InvocationServices:
@@ -31,8 +28,8 @@ class InvocationServices:
boards: "BoardServiceABC"
configuration: "InvokeAIAppConfig"
events: "EventServiceBase"
graph_execution_manager: "ItemStorageABC[GraphExecutionState]"
graph_library: "ItemStorageABC[LibraryGraph]"
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
graph_library: "ItemStorageABC"["LibraryGraph"]
images: "ImageServiceABC"
latents: "LatentsStorageBase"
logger: "Logger"
@@ -40,9 +37,6 @@ class InvocationServices:
processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
queue: "InvocationQueueABC"
session_queue: "SessionQueueBase"
session_processor: "SessionProcessorBase"
invocation_cache: "InvocationCacheBase"
def __init__(
self,
@@ -50,8 +44,8 @@ class InvocationServices:
boards: "BoardServiceABC",
configuration: "InvokeAIAppConfig",
events: "EventServiceBase",
graph_execution_manager: "ItemStorageABC[GraphExecutionState]",
graph_library: "ItemStorageABC[LibraryGraph]",
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
graph_library: "ItemStorageABC"["LibraryGraph"],
images: "ImageServiceABC",
latents: "LatentsStorageBase",
logger: "Logger",
@@ -59,12 +53,10 @@ class InvocationServices:
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
session_queue: "SessionQueueBase",
session_processor: "SessionProcessorBase",
invocation_cache: "InvocationCacheBase",
):
self.board_images = board_images
self.boards = boards
self.boards = boards
self.configuration = configuration
self.events = events
self.graph_execution_manager = graph_execution_manager
@@ -76,6 +68,3 @@ class InvocationServices:
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue
self.session_queue = session_queue
self.session_processor = session_processor
self.invocation_cache = invocation_cache

View File

@@ -17,14 +17,7 @@ class Invoker:
self.services = services
self._start()
def invoke(
self,
session_queue_id: str,
session_queue_item_id: int,
session_queue_batch_id: str,
graph_execution_state: GraphExecutionState,
invoke_all: bool = False,
) -> Optional[str]:
def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> Optional[str]:
"""Determines the next node to invoke and enqueues it, preparing if needed.
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
@@ -39,9 +32,7 @@ class Invoker:
# Queue the invocation
self.services.queue.put(
InvocationQueueItem(
session_queue_id=session_queue_id,
session_queue_item_id=session_queue_item_id,
session_queue_batch_id=session_queue_batch_id,
# session_id = session.id,
graph_execution_state_id=graph_execution_state.id,
invocation_id=invocation.id,
invoke_all=invoke_all,

View File

@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Callable, Dict, Optional, Union
from typing import Dict, Optional, Union
import torch
@@ -11,13 +11,6 @@ import torch
class LatentsStorageBase(ABC):
"""Responsible for storing and retrieving latents."""
_on_changed_callbacks: list[Callable[[torch.Tensor], None]]
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
self._on_changed_callbacks = list()
self._on_deleted_callbacks = list()
@abstractmethod
def get(self, name: str) -> torch.Tensor:
pass
@@ -30,22 +23,6 @@ class LatentsStorageBase(ABC):
def delete(self, name: str) -> None:
pass
def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None:
"""Register a callback for when an item is changed"""
self._on_changed_callbacks.append(on_changed)
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
"""Register a callback for when an item is deleted"""
self._on_deleted_callbacks.append(on_deleted)
def _on_changed(self, item: torch.Tensor) -> None:
for callback in self._on_changed_callbacks:
callback(item)
def _on_deleted(self, item_id: str) -> None:
for callback in self._on_deleted_callbacks:
callback(item_id)
class ForwardCacheLatentsStorage(LatentsStorageBase):
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
@@ -56,7 +33,6 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
__underlying_storage: LatentsStorageBase
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
super().__init__()
self.__underlying_storage = underlying_storage
self.__cache = dict()
self.__cache_ids = Queue()
@@ -74,13 +50,11 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
def save(self, name: str, data: torch.Tensor) -> None:
self.__underlying_storage.save(name, data)
self.__set_cache(name, data)
self._on_changed(data)
def delete(self, name: str) -> None:
self.__underlying_storage.delete(name)
if name in self.__cache:
del self.__cache[name]
self._on_deleted(name)
def __get_cache(self, name: str) -> Optional[torch.Tensor]:
return None if name not in self.__cache else self.__cache[name]

View File

@@ -525,7 +525,7 @@ class ModelManagerService(ModelManagerServiceBase):
def _emit_load_event(
self,
context: InvocationContext,
context,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
@@ -537,9 +537,6 @@ class ModelManagerService(ModelManagerServiceBase):
if model_info:
context.services.events.emit_model_load_completed(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_name=model_name,
base_model=base_model,
@@ -549,9 +546,6 @@ class ModelManagerService(ModelManagerServiceBase):
)
else:
context.services.events.emit_model_load_started(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_name=model_name,
base_model=base_model,

View File

@@ -1,7 +1,6 @@
import time
import traceback
from threading import BoundedSemaphore, Event, Thread
from typing import Optional
import invokeai.backend.util.logging as logger
@@ -38,11 +37,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
try:
self.__threadLimit.acquire()
statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
queue_item: Optional[InvocationQueueItem] = None
while not stop_event.is_set():
try:
queue_item = self.__invoker.services.queue.get()
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
except Exception as e:
self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e)
@@ -50,6 +48,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# do not hammer the queue
time.sleep(0.5)
continue
try:
graph_execution_state = self.__invoker.services.graph_execution_manager.get(
queue_item.graph_execution_state_id
@@ -57,9 +56,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
except Exception as e:
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
self.__invoker.services.events.emit_session_retrieval_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=queue_item.graph_execution_state_id,
error_type=e.__class__.__name__,
error=traceback.format_exc(),
@@ -71,9 +67,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
except Exception as e:
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
self.__invoker.services.events.emit_invocation_retrieval_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=queue_item.graph_execution_state_id,
node_id=queue_item.invocation_id,
error_type=e.__class__.__name__,
@@ -86,9 +79,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send starting event
self.__invoker.services.events.emit_invocation_started(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
@@ -99,17 +89,13 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
graph_id = graph_execution_state.id
model_manager = self.__invoker.services.model_manager
with statistics.collect_stats(invocation, graph_id, model_manager):
# use the internal invoke_internal(), which wraps the node's invoke() method,
# which handles a few things:
# - nodes that require a value, but get it only from a connection
# - referencing the invocation cache instead of executing the node
# use the internal invoke_internal(), which wraps the node's invoke() method in
# this accomodates nodes which require a value, but get it only from a
# connection
outputs = invocation.invoke_internal(
InvocationContext(
services=self.__invoker.services,
graph_execution_state_id=graph_execution_state.id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
queue_batch_id=queue_item.session_queue_batch_id,
)
)
@@ -125,9 +111,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send complete event
self.__invoker.services.events.emit_invocation_complete(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
@@ -155,9 +138,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
# Send error event
self.__invoker.services.events.emit_invocation_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
@@ -175,19 +155,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
is_complete = graph_execution_state.is_complete()
if queue_item.invoke_all and not is_complete:
try:
self.__invoker.invoke(
session_queue_batch_id=queue_item.session_queue_batch_id,
session_queue_item_id=queue_item.session_queue_item_id,
session_queue_id=queue_item.session_queue_id,
graph_execution_state=graph_execution_state,
invoke_all=True,
)
self.__invoker.invoke(graph_execution_state, invoke_all=True)
except Exception as e:
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
self.__invoker.services.events.emit_invocation_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
@@ -195,12 +166,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
error=traceback.format_exc(),
)
elif is_complete:
self.__invoker.services.events.emit_graph_execution_complete(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
)
self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
except KeyboardInterrupt:
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor

View File

@@ -1,8 +1,7 @@
import uuid
from abc import ABC, abstractmethod
from enum import Enum, EnumMeta
from invokeai.app.util.misc import uuid_string
class ResourceType(str, Enum, metaclass=EnumMeta):
"""Enum for resource types."""
@@ -26,6 +25,6 @@ class SimpleNameService(NameServiceBase):
# TODO: Add customizable naming schemes
def create_image_name(self) -> str:
uuid_str = uuid_string()
uuid_str = str(uuid.uuid4())
filename = f"{uuid_str}.png"
return filename

View File

@@ -1,28 +0,0 @@
from abc import ABC, abstractmethod
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
class SessionProcessorBase(ABC):
"""
Base class for session processor.
The session processor is responsible for executing sessions. It runs a simple polling loop,
checking the session queue for new sessions to execute. It must coordinate with the
invocation queue to ensure only one session is executing at a time.
"""
@abstractmethod
def resume(self) -> SessionProcessorStatus:
"""Starts or resumes the session processor"""
pass
@abstractmethod
def pause(self) -> SessionProcessorStatus:
"""Pauses the session processor"""
pass
@abstractmethod
def get_status(self) -> SessionProcessorStatus:
"""Gets the status of the session processor"""
pass

View File

@@ -1,6 +0,0 @@
from pydantic import BaseModel, Field
class SessionProcessorStatus(BaseModel):
is_started: bool = Field(description="Whether the session processor is started")
is_processing: bool = Field(description="Whether a session is being processed")

View File

@@ -1,135 +0,0 @@
from threading import BoundedSemaphore
from threading import Event as ThreadEvent
from threading import Thread
from typing import Optional
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from ..invoker import Invoker
from .session_processor_base import SessionProcessorBase
from .session_processor_common import SessionProcessorStatus
POLLING_INTERVAL = 1
THREAD_LIMIT = 1
class DefaultSessionProcessor(SessionProcessorBase):
def start(self, invoker: Invoker) -> None:
self.__invoker: Invoker = invoker
self.__queue_item: Optional[SessionQueueItem] = None
self.__resume_event = ThreadEvent()
self.__stop_event = ThreadEvent()
self.__poll_now_event = ThreadEvent()
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
self.__threadLimit = BoundedSemaphore(THREAD_LIMIT)
self.__thread = Thread(
name="session_processor",
target=self.__process,
kwargs=dict(
stop_event=self.__stop_event, poll_now_event=self.__poll_now_event, resume_event=self.__resume_event
),
)
self.__thread.start()
def stop(self, *args, **kwargs) -> None:
self.__stop_event.set()
def _poll_now(self) -> None:
self.__poll_now_event.set()
async def _on_queue_event(self, event: FastAPIEvent) -> None:
event_name = event[1]["event"]
# This was a match statement, but match is not supported on python 3.9
if event_name in [
"graph_execution_state_complete",
"invocation_error",
"session_retrieval_error",
"invocation_retrieval_error",
]:
self.__queue_item = None
self._poll_now()
elif (
event_name == "session_canceled"
and self.__queue_item is not None
and self.__queue_item.session_id == event[1]["data"]["graph_execution_state_id"]
):
self.__queue_item = None
self._poll_now()
elif event_name == "batch_enqueued":
self._poll_now()
elif event_name == "queue_cleared":
self.__queue_item = None
self._poll_now()
def resume(self) -> SessionProcessorStatus:
if not self.__resume_event.is_set():
self.__resume_event.set()
return self.get_status()
def pause(self) -> SessionProcessorStatus:
if self.__resume_event.is_set():
self.__resume_event.clear()
return self.get_status()
def get_status(self) -> SessionProcessorStatus:
return SessionProcessorStatus(
is_started=self.__resume_event.is_set(),
is_processing=self.__queue_item is not None,
)
def __process(
self,
stop_event: ThreadEvent,
poll_now_event: ThreadEvent,
resume_event: ThreadEvent,
):
try:
stop_event.clear()
resume_event.set()
self.__threadLimit.acquire()
queue_item: Optional[SessionQueueItem] = None
self.__invoker.services.logger
while not stop_event.is_set():
poll_now_event.clear()
try:
# do not dequeue if there is already a session running
if self.__queue_item is None and resume_event.is_set():
queue_item = self.__invoker.services.session_queue.dequeue()
if queue_item is not None:
self.__invoker.services.logger.debug(f"Executing queue item {queue_item.item_id}")
self.__queue_item = queue_item
self.__invoker.services.graph_execution_manager.set(queue_item.session)
self.__invoker.invoke(
session_queue_batch_id=queue_item.batch_id,
session_queue_id=queue_item.queue_id,
session_queue_item_id=queue_item.item_id,
graph_execution_state=queue_item.session,
invoke_all=True,
)
queue_item = None
if queue_item is None:
self.__invoker.services.logger.debug("Waiting for next polling interval or event")
poll_now_event.wait(POLLING_INTERVAL)
continue
except Exception as e:
self.__invoker.services.logger.error(f"Error in session processor: {e}")
poll_now_event.wait(POLLING_INTERVAL)
continue
except Exception as e:
self.__invoker.services.logger.error(f"Fatal Error in session processor: {e}")
pass
finally:
stop_event.clear()
poll_now_event.clear()
self.__queue_item = None
self.__threadLimit.release()

View File

@@ -1,112 +0,0 @@
from abc import ABC, abstractmethod
from typing import Optional
from invokeai.app.services.graph import Graph
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
Batch,
BatchStatus,
CancelByBatchIDsResult,
CancelByQueueIDResult,
ClearResult,
EnqueueBatchResult,
EnqueueGraphResult,
IsEmptyResult,
IsFullResult,
PruneResult,
SessionQueueItem,
SessionQueueItemDTO,
SessionQueueStatus,
)
from invokeai.app.services.shared.models import CursorPaginatedResults
class SessionQueueBase(ABC):
"""Base class for session queue"""
@abstractmethod
def dequeue(self) -> Optional[SessionQueueItem]:
"""Dequeues the next session queue item."""
pass
@abstractmethod
def enqueue_graph(self, queue_id: str, graph: Graph, prepend: bool) -> EnqueueGraphResult:
"""Enqueues a single graph for execution."""
pass
@abstractmethod
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
"""Enqueues all permutations of a batch for execution."""
pass
@abstractmethod
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
"""Gets the currently-executing session queue item"""
pass
@abstractmethod
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
"""Gets the next session queue item (does not dequeue it)"""
pass
@abstractmethod
def clear(self, queue_id: str) -> ClearResult:
"""Deletes all session queue items"""
pass
@abstractmethod
def prune(self, queue_id: str) -> PruneResult:
"""Deletes all completed and errored session queue items"""
pass
@abstractmethod
def is_empty(self, queue_id: str) -> IsEmptyResult:
"""Checks if the queue is empty"""
pass
@abstractmethod
def is_full(self, queue_id: str) -> IsFullResult:
"""Checks if the queue is empty"""
pass
@abstractmethod
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
"""Gets the status of the queue"""
pass
@abstractmethod
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
"""Gets the status of a batch"""
pass
@abstractmethod
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
"""Cancels a session queue item"""
pass
@abstractmethod
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
"""Cancels all queue items with matching batch IDs"""
pass
@abstractmethod
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
"""Cancels all queue items with matching queue ID"""
pass
@abstractmethod
def list_queue_items(
self,
queue_id: str,
limit: int,
priority: int,
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
) -> CursorPaginatedResults[SessionQueueItemDTO]:
"""Gets a page of session queue items"""
pass
@abstractmethod
def get_queue_item(self, item_id: int) -> SessionQueueItem:
"""Gets a session queue item by ID"""
pass

View File

@@ -1,418 +0,0 @@
import datetime
import json
from itertools import chain, product
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
from pydantic import BaseModel, Field, StrictStr, parse_raw_as, root_validator, validator
from pydantic.json import pydantic_encoder
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.graph import Graph, GraphExecutionState, NodeNotFoundError
from invokeai.app.util.misc import uuid_string
# region Errors
class BatchZippedLengthError(ValueError):
"""Raise when a batch has items of different lengths."""
class BatchItemsTypeError(TypeError):
"""Raise when a batch has items of different types."""
class BatchDuplicateNodeFieldError(ValueError):
"""Raise when a batch has duplicate node_path and field_name."""
class TooManySessionsError(ValueError):
"""Raise when too many sessions are requested."""
class SessionQueueItemNotFoundError(ValueError):
"""Raise when a queue item is not found."""
# endregion
# region Batch
BatchDataType = Union[
StrictStr,
float,
int,
]
class NodeFieldValue(BaseModel):
node_path: str = Field(description="The node into which this batch data item will be substituted.")
field_name: str = Field(description="The field into which this batch data item will be substituted.")
value: BatchDataType = Field(description="The value to substitute into the node/field.")
class BatchDatum(BaseModel):
node_path: str = Field(description="The node into which this batch data collection will be substituted.")
field_name: str = Field(description="The field into which this batch data collection will be substituted.")
items: list[BatchDataType] = Field(
default_factory=list, description="The list of items to substitute into the node/field."
)
BatchDataCollection: TypeAlias = list[list[BatchDatum]]
class Batch(BaseModel):
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
graph: Graph = Field(description="The graph to initialize the session with")
runs: int = Field(
default=1, ge=1, description="Int stating how many times to iterate through all possible batch indices"
)
@validator("data")
def validate_lengths(cls, v: Optional[BatchDataCollection]):
if v is None:
return v
for batch_data_list in v:
first_item_length = len(batch_data_list[0].items) if batch_data_list and batch_data_list[0].items else 0
for i in batch_data_list:
if len(i.items) != first_item_length:
raise BatchZippedLengthError("Zipped batch items must all have the same length")
return v
@validator("data")
def validate_types(cls, v: Optional[BatchDataCollection]):
if v is None:
return v
for batch_data_list in v:
for datum in batch_data_list:
# Get the type of the first item in the list
first_item_type = type(datum.items[0]) if datum.items else None
for item in datum.items:
if type(item) is not first_item_type:
raise BatchItemsTypeError("All items in a batch must have the same type")
return v
@validator("data")
def validate_unique_field_mappings(cls, v: Optional[BatchDataCollection]):
if v is None:
return v
paths: set[tuple[str, str]] = set()
for batch_data_list in v:
for datum in batch_data_list:
pair = (datum.node_path, datum.field_name)
if pair in paths:
raise BatchDuplicateNodeFieldError("Each batch data must have unique node_id and field_name")
paths.add(pair)
return v
@root_validator(skip_on_failure=True)
def validate_batch_nodes_and_edges(cls, values):
batch_data_collection = cast(Optional[BatchDataCollection], values["data"])
if batch_data_collection is None:
return values
graph = cast(Graph, values["graph"])
for batch_data_list in batch_data_collection:
for batch_data in batch_data_list:
try:
node = cast(BaseInvocation, graph.get_node(batch_data.node_path))
except NodeNotFoundError:
raise NodeNotFoundError(f"Node {batch_data.node_path} not found in graph")
if batch_data.field_name not in node.__fields__:
raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
return values
class Config:
schema_extra = {
"required": [
"graph",
"runs",
]
}
# endregion Batch
# region Queue Items
DEFAULT_QUEUE_ID = "default"
QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]
def get_field_values(queue_item_dict: dict) -> Optional[list[NodeFieldValue]]:
field_values_raw = queue_item_dict.get("field_values", None)
return parse_raw_as(list[NodeFieldValue], field_values_raw) if field_values_raw is not None else None
def get_session(queue_item_dict: dict) -> GraphExecutionState:
session_raw = queue_item_dict.get("session", "{}")
return parse_raw_as(GraphExecutionState, session_raw)
class SessionQueueItemWithoutGraph(BaseModel):
"""Session queue item without the full graph. Used for serialization."""
item_id: int = Field(description="The identifier of the session queue item")
status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
priority: int = Field(default=0, description="The priority of this queue item")
batch_id: str = Field(description="The ID of the batch associated with this queue item")
session_id: str = Field(
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
)
error: Optional[str] = Field(default=None, description="The error message if this queue item errored")
created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created")
updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated")
started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")
completed_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was completed")
queue_id: str = Field(description="The id of the queue with which this item is associated")
field_values: Optional[list[NodeFieldValue]] = Field(
default=None, description="The field values that were used for this queue item"
)
@classmethod
def from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
# must parse these manually
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
return SessionQueueItemDTO(**queue_item_dict)
class Config:
schema_extra = {
"required": [
"item_id",
"status",
"batch_id",
"queue_id",
"session_id",
"priority",
"session_id",
"created_at",
"updated_at",
]
}
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
pass
class SessionQueueItem(SessionQueueItemWithoutGraph):
session: GraphExecutionState = Field(description="The fully-populated session to be executed")
@classmethod
def from_dict(cls, queue_item_dict: dict) -> "SessionQueueItem":
# must parse these manually
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
queue_item_dict["session"] = get_session(queue_item_dict)
return SessionQueueItem(**queue_item_dict)
class Config:
schema_extra = {
"required": [
"item_id",
"status",
"batch_id",
"queue_id",
"session_id",
"session",
"priority",
"session_id",
"created_at",
"updated_at",
]
}
# endregion Queue Items
# region Query Results
class SessionQueueStatus(BaseModel):
queue_id: str = Field(..., description="The ID of the queue")
item_id: Optional[int] = Field(description="The current queue item id")
batch_id: Optional[str] = Field(description="The current queue item's batch id")
session_id: Optional[str] = Field(description="The current queue item's session id")
pending: int = Field(..., description="Number of queue items with status 'pending'")
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
completed: int = Field(..., description="Number of queue items with status 'complete'")
failed: int = Field(..., description="Number of queue items with status 'error'")
canceled: int = Field(..., description="Number of queue items with status 'canceled'")
total: int = Field(..., description="Total number of queue items")
class BatchStatus(BaseModel):
queue_id: str = Field(..., description="The ID of the queue")
batch_id: str = Field(..., description="The ID of the batch")
pending: int = Field(..., description="Number of queue items with status 'pending'")
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
completed: int = Field(..., description="Number of queue items with status 'complete'")
failed: int = Field(..., description="Number of queue items with status 'error'")
canceled: int = Field(..., description="Number of queue items with status 'canceled'")
total: int = Field(..., description="Total number of queue items")
class EnqueueBatchResult(BaseModel):
queue_id: str = Field(description="The ID of the queue")
enqueued: int = Field(description="The total number of queue items enqueued")
requested: int = Field(description="The total number of queue items requested to be enqueued")
batch: Batch = Field(description="The batch that was enqueued")
priority: int = Field(description="The priority of the enqueued batch")
class EnqueueGraphResult(BaseModel):
enqueued: int = Field(description="The total number of queue items enqueued")
requested: int = Field(description="The total number of queue items requested to be enqueued")
batch: Batch = Field(description="The batch that was enqueued")
priority: int = Field(description="The priority of the enqueued batch")
queue_item: SessionQueueItemDTO = Field(description="The queue item that was enqueued")
class ClearResult(BaseModel):
"""Result of clearing the session queue"""
deleted: int = Field(..., description="Number of queue items deleted")
class PruneResult(ClearResult):
"""Result of pruning the session queue"""
pass
class CancelByBatchIDsResult(BaseModel):
"""Result of canceling by list of batch ids"""
canceled: int = Field(..., description="Number of queue items canceled")
class CancelByQueueIDResult(CancelByBatchIDsResult):
"""Result of canceling by queue id"""
pass
class IsEmptyResult(BaseModel):
"""Result of checking if the session queue is empty"""
is_empty: bool = Field(..., description="Whether the session queue is empty")
class IsFullResult(BaseModel):
"""Result of checking if the session queue is full"""
is_full: bool = Field(..., description="Whether the session queue is full")
# endregion Query Results
# region Util
def populate_graph(graph: Graph, node_field_values: Iterable[NodeFieldValue]) -> Graph:
"""
Populates the given graph with the given batch data items.
"""
graph_clone = graph.copy(deep=True)
for item in node_field_values:
node = graph_clone.get_node(item.node_path)
if node is None:
continue
setattr(node, item.field_name, item.value)
graph_clone.update_node(item.node_path, node)
return graph_clone
def create_session_nfv_tuples(
batch: Batch, maximum: int
) -> Generator[tuple[GraphExecutionState, list[NodeFieldValue]], None, None]:
"""
Create all graph permutations from the given batch data and graph. Yields tuples
of the form (graph, batch_data_items) where batch_data_items is the list of BatchDataItems
that was applied to the graph.
"""
# TODO: Should this be a class method on Batch?
data: list[list[tuple[NodeFieldValue]]] = []
batch_data_collection = batch.data if batch.data is not None else []
for batch_datum_list in batch_data_collection:
# each batch_datum_list needs to be convered to NodeFieldValues and then zipped
node_field_values_to_zip: list[list[NodeFieldValue]] = []
for batch_datum in batch_datum_list:
node_field_values = [
NodeFieldValue(node_path=batch_datum.node_path, field_name=batch_datum.field_name, value=item)
for item in batch_datum.items
]
node_field_values_to_zip.append(node_field_values)
data.append(list(zip(*node_field_values_to_zip)))
# create generator to yield session,nfv tuples
count = 0
for _ in range(batch.runs):
for d in product(*data):
if count >= maximum:
return
flat_node_field_values = list(chain.from_iterable(d))
graph = populate_graph(batch.graph, flat_node_field_values)
yield (GraphExecutionState(graph=graph), flat_node_field_values)
count += 1
def calc_session_count(batch: Batch) -> int:
"""
Calculates the number of sessions that would be created by the batch, without incurring
the overhead of actually generating them. Adapted from `create_sessions().
"""
# TODO: Should this be a class method on Batch?
if not batch.data:
return batch.runs
data = []
for batch_datum_list in batch.data:
to_zip = []
for batch_datum in batch_datum_list:
batch_data_items = range(len(batch_datum.items))
to_zip.append(batch_data_items)
data.append(list(zip(*to_zip)))
data_product = list(product(*data))
return len(data_product) * batch.runs
class SessionQueueValueToInsert(NamedTuple):
"""A tuple of values to insert into the session_queue table"""
queue_id: str # queue_id
session: str # session json
session_id: str # session_id
batch_id: str # batch_id
field_values: Optional[str] # field_values json
priority: int # priority
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> ValuesToInsert:
values_to_insert: ValuesToInsert = []
for session, field_values in create_session_nfv_tuples(batch, max_new_queue_items):
# sessions must have unique id
session.id = uuid_string()
values_to_insert.append(
SessionQueueValueToInsert(
queue_id, # queue_id
session.json(), # session (json)
session.id, # session_id
batch.batch_id, # batch_id
# must use pydantic_encoder bc field_values is a list of models
json.dumps(field_values, default=pydantic_encoder) if field_values else None, # field_values (json)
priority, # priority
)
)
return values_to_insert
# endregion Util

View File

@@ -1,816 +0,0 @@
import sqlite3
import threading
from typing import Optional, Union, cast
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import Graph
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
from invokeai.app.services.session_queue.session_queue_common import (
DEFAULT_QUEUE_ID,
QUEUE_ITEM_STATUS,
Batch,
BatchStatus,
CancelByBatchIDsResult,
CancelByQueueIDResult,
ClearResult,
EnqueueBatchResult,
EnqueueGraphResult,
IsEmptyResult,
IsFullResult,
PruneResult,
SessionQueueItem,
SessionQueueItemDTO,
SessionQueueItemNotFoundError,
SessionQueueStatus,
calc_session_count,
prepare_values_to_insert,
)
from invokeai.app.services.shared.models import CursorPaginatedResults
class SqliteSessionQueue(SessionQueueBase):
__invoker: Invoker
__conn: sqlite3.Connection
__cursor: sqlite3.Cursor
__lock: threading.Lock
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
self._set_in_progress_to_canceled()
prune_result = self.prune(DEFAULT_QUEUE_ID)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
super().__init__()
self.__conn = conn
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self.__conn.row_factory = sqlite3.Row
self.__cursor = self.__conn.cursor()
self.__lock = lock
self._create_tables()
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
return event[1]["event"] in match_in
async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent:
event_name = event[1]["event"]
# This was a match statement, but match is not supported on python 3.9
if event_name == "graph_execution_state_complete":
await self._handle_complete_event(event)
elif event_name in ["invocation_error", "session_retrieval_error", "invocation_retrieval_error"]:
await self._handle_error_event(event)
elif event_name == "session_canceled":
await self._handle_cancel_event(event)
return event
async def _handle_complete_event(self, event: FastAPIEvent) -> None:
try:
item_id = event[1]["data"]["queue_item_id"]
# When a queue item has an error, we get an error event, then a completed event.
# Mark the queue item completed only if it isn't already marked completed, e.g.
# by a previously-handled error event.
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["completed", "failed", "canceled"]:
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed")
except SessionQueueItemNotFoundError:
return
async def _handle_error_event(self, event: FastAPIEvent) -> None:
try:
item_id = event[1]["data"]["queue_item_id"]
error = event[1]["data"]["error"]
queue_item = self.get_queue_item(item_id)
# always set to failed if have an error, even if previously the item was marked completed or canceled
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error)
except SessionQueueItemNotFoundError:
return
async def _handle_cancel_event(self, event: FastAPIEvent) -> None:
try:
item_id = event[1]["data"]["queue_item_id"]
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["completed", "failed", "canceled"]:
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
except SessionQueueItemNotFoundError:
return
def _create_tables(self) -> None:
"""Creates the session queue tables, indicies, and triggers"""
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS session_queue (
item_id INTEGER PRIMARY KEY AUTOINCREMENT, -- used for ordering, cursor pagination
batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
field_values TEXT, -- NULL if no values are associated with this queue item
session TEXT NOT NULL, -- the session to be executed
status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled'
priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
error TEXT, -- any errors associated with this queue item
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
started_at DATETIME, -- updated via trigger
completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
-- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
-- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
);
"""
)
self.__cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);
"""
)
self.__cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);
"""
)
self.__cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);
"""
)
self.__cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);
"""
)
self.__cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);
"""
)
self.__cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'completed'
OR NEW.status = 'failed'
OR NEW.status = 'canceled'
BEGIN
UPDATE session_queue
SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
"""
)
self.__cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'in_progress'
BEGIN
UPDATE session_queue
SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
"""
)
self.__cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
AFTER UPDATE
ON session_queue FOR EACH ROW
BEGIN
UPDATE session_queue
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = old.item_id;
END;
"""
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
def _set_in_progress_to_canceled(self) -> None:
"""
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
This is necessary because the invoker may have been killed while processing a queue item.
"""
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
UPDATE session_queue
SET status = 'canceled'
WHERE status = 'in_progress';
"""
)
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
def _get_current_queue_size(self, queue_id: str) -> int:
"""Gets the current number of pending queue items"""
self.__cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
return cast(int, self.__cursor.fetchone()[0])
def _get_highest_priority(self, queue_id: str) -> int:
"""Gets the highest priority value in the queue"""
self.__cursor.execute(
"""--sql
SELECT MAX(priority)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
return cast(Union[int, None], self.__cursor.fetchone()[0]) or 0
def enqueue_graph(self, queue_id: str, graph: Graph, prepend: bool) -> EnqueueGraphResult:
enqueue_result = self.enqueue_batch(queue_id=queue_id, batch=Batch(graph=graph), prepend=prepend)
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
AND batch_id = ?
""",
(queue_id, enqueue_result.batch.batch_id),
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
if result is None:
raise SessionQueueItemNotFoundError(f"No queue item with batch id {enqueue_result.batch.batch_id}")
return EnqueueGraphResult(
**enqueue_result.dict(),
queue_item=SessionQueueItemDTO.from_dict(dict(result)),
)
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
try:
self.__lock.acquire()
# TODO: how does this work in a multi-user scenario?
current_queue_size = self._get_current_queue_size(queue_id)
max_queue_size = self.__invoker.services.configuration.get_config().max_queue_size
max_new_queue_items = max_queue_size - current_queue_size
priority = 0
if prepend:
priority = self._get_highest_priority(queue_id) + 1
requested_count = calc_session_count(batch)
values_to_insert = prepare_values_to_insert(
queue_id=queue_id,
batch=batch,
priority=priority,
max_new_queue_items=max_new_queue_items,
)
enqueued_count = len(values_to_insert)
if requested_count > enqueued_count:
values_to_insert = values_to_insert[:max_new_queue_items]
self.__cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority)
VALUES (?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
enqueue_result = EnqueueBatchResult(
queue_id=queue_id,
requested=requested_count,
enqueued=enqueued_count,
batch=batch,
priority=priority,
)
self.__invoker.services.events.emit_batch_enqueued(enqueue_result)
return enqueue_result
def dequeue(self) -> Optional[SessionQueueItem]:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE status = 'pending'
ORDER BY
priority DESC,
item_id ASC
LIMIT 1
"""
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
if result is None:
return None
queue_item = SessionQueueItem.from_dict(dict(result))
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress")
return queue_item
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
ORDER BY
priority DESC,
created_at ASC
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
if result is None:
return None
return SessionQueueItem.from_dict(dict(result))
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'in_progress'
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
if result is None:
return None
return SessionQueueItem.from_dict(dict(result))
def _set_queue_item_status(
self, item_id: int, status: QUEUE_ITEM_STATUS, error: Optional[str] = None
) -> SessionQueueItem:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
UPDATE session_queue
SET status = ?, error = ?
WHERE item_id = ?
""",
(status, error, item_id),
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
queue_item = self.get_queue_item(item_id)
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
is_empty = cast(int, self.__cursor.fetchone()[0]) == 0
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return IsEmptyResult(is_empty=is_empty)
def is_full(self, queue_id: str) -> IsFullResult:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
max_queue_size = self.__invoker.services.configuration.max_queue_size
is_full = cast(int, self.__cursor.fetchone()[0]) >= max_queue_size
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return IsFullResult(is_full=is_full)
def delete_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self.get_queue_item(item_id=item_id)
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
DELETE FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return queue_item
def clear(self, queue_id: str) -> ClearResult:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT COUNT(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
"""--sql
DELETE
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
self.__invoker.services.events.emit_queue_cleared(queue_id)
return ClearResult(deleted=count)
def prune(self, queue_id: str) -> PruneResult:
try:
where = """--sql
WHERE
queue_id = ?
AND (
status = 'completed'
OR status = 'failed'
OR status = 'canceled'
)
"""
self.__lock.acquire()
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
{where};
""",
(queue_id,),
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
DELETE
FROM session_queue
{where};
""",
(queue_id,),
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return PruneResult(deleted=count)
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["canceled", "failed", "completed"]:
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
self.__invoker.services.queue.cancel(queue_item.session_id)
self.__invoker.services.events.emit_session_canceled(
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
queue_batch_id=queue_item.batch_id,
graph_execution_state_id=queue_item.session_id,
)
return queue_item
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
try:
current_queue_item = self.get_current(queue_id)
self.__lock.acquire()
placeholders = ", ".join(["?" for _ in batch_ids])
where = f"""--sql
WHERE
queue_id == ?
AND batch_id IN ({placeholders})
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
"""
params = [queue_id] + batch_ids
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
{where};
""",
tuple(params),
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
{where};
""",
tuple(params),
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self.__invoker.services.queue.cancel(current_queue_item.session_id)
self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id,
)
self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return CancelByBatchIDsResult(canceled=count)
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
try:
current_queue_item = self.get_current(queue_id)
self.__lock.acquire()
where = """--sql
WHERE
queue_id is ?
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
"""
params = [queue_id]
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
{where};
""",
tuple(params),
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
{where};
""",
tuple(params),
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
self.__invoker.services.queue.cancel(current_queue_item.session_id)
self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id,
)
self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return CancelByQueueIDResult(canceled=count)
def get_queue_item(self, item_id: int) -> SessionQueueItem:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
if result is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
return SessionQueueItem.from_dict(dict(result))
def list_queue_items(
self,
queue_id: str,
limit: int,
priority: int,
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
) -> CursorPaginatedResults[SessionQueueItemDTO]:
try:
item_id = cursor
self.__lock.acquire()
query = """--sql
SELECT item_id,
status,
priority,
field_values,
error,
created_at,
updated_at,
completed_at,
started_at,
session_id,
batch_id,
queue_id
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if status is not None:
query += """--sql
AND status = ?
"""
params.append(status)
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
"""
params.extend([priority, priority, item_id])
query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
"""
params.append(limit + 1)
self.__cursor.execute(query, params)
results = cast(list[sqlite3.Row], self.__cursor.fetchall())
items = [SessionQueueItemDTO.from_dict(dict(result)) for result in results]
has_more = False
if len(items) > limit:
# remove the extra item
items.pop()
has_more = True
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
GROUP BY status
""",
(queue_id,),
)
counts_result = cast(list[sqlite3.Row], self.__cursor.fetchall())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
current_item = self.get_current(queue_id=queue_id)
total = sum(row[1] for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
return SessionQueueStatus(
queue_id=queue_id,
item_id=current_item.item_id if current_item else None,
session_id=current_item.session_id if current_item else None,
batch_id=current_item.batch_id if current_item else None,
pending=counts.get("pending", 0),
in_progress=counts.get("in_progress", 0),
completed=counts.get("completed", 0),
failed=counts.get("failed", 0),
canceled=counts.get("canceled", 0),
total=total,
)
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE
queue_id = ?
AND batch_id = ?
GROUP BY status
""",
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
total = sum(row[1] for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return BatchStatus(
batch_id=batch_id,
queue_id=queue_id,
pending=counts.get("pending", 0),
in_progress=counts.get("in_progress", 0),
completed=counts.get("completed", 0),
failed=counts.get("failed", 0),
canceled=counts.get("canceled", 0),
total=total,
)

View File

@@ -1,14 +0,0 @@
from typing import Generic, TypeVar
from pydantic import BaseModel, Field
from pydantic.generics import GenericModel
GenericBaseModel = TypeVar("GenericBaseModel", bound=BaseModel)
class CursorPaginatedResults(GenericModel, Generic[GenericBaseModel]):
"""Cursor-paginated results"""
limit: int = Field(..., description="Limit of items to get")
has_more: bool = Field(..., description="Whether there are more items available")
items: list[GenericBaseModel] = Field(..., description="Items")

View File

@@ -1,5 +1,5 @@
import sqlite3
import threading
from threading import Lock
from typing import Generic, Optional, TypeVar, get_args
from pydantic import BaseModel, parse_raw_as
@@ -12,19 +12,23 @@ sqlite_memory = ":memory:"
class SqliteItemStorage(ItemStorageABC, Generic[T]):
_filename: str
_table_name: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_id_field: str
_lock: threading.Lock
_lock: Lock
def __init__(self, conn: sqlite3.Connection, table_name: str, lock: threading.Lock, id_field: str = "id"):
def __init__(self, filename: str, table_name: str, id_field: str = "id"):
super().__init__()
self._filename = filename
self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field
self._lock = lock
self._conn = conn
self._lock = Lock()
self._conn = sqlite3.connect(
self._filename, check_same_thread=False
) # TODO: figure out a better threading solution
self._cursor = self._conn.cursor()
self._create_table()
@@ -45,7 +49,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def _parse_item(self, item: str) -> T:
item_type = get_args(self.__orig_class__)[0]
return parse_raw_as(item_type, item)
parsed = parse_raw_as(item_type, item)
return parsed
def set(self, item: T):
try:

View File

@@ -1,3 +0,0 @@
import threading
lock = threading.Lock()

View File

@@ -1,5 +1,4 @@
import datetime
import uuid
import numpy as np
@@ -22,8 +21,3 @@ SEED_MAX = np.iinfo(np.uint32).max
def get_random_seed():
rng = np.random.default_rng(seed=None)
return int(rng.integers(0, SEED_MAX))
def uuid_string():
res = uuid.uuid4()
return str(res)

View File

@@ -110,9 +110,6 @@ def stable_diffusion_step_callback(
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
node=node,
source_node_id=source_node_id,

View File

@@ -1,94 +0,0 @@
Copyright (c) 2016-2020 The Inter Project Authors.
"Inter" is trademark of Rasmus Andersson.
https://github.com/rsms/inter
This Font Software is licensed under the SIL Open Font License, Version 1.1.
This license is copied below, and is also available with a FAQ at:
http://scripts.sil.org/OFL
-----------------------------------------------------------
SIL OPEN FONT LICENSE Version 1.1 - 26 February 2007
-----------------------------------------------------------
PREAMBLE
The goals of the Open Font License (OFL) are to stimulate worldwide
development of collaborative font projects, to support the font creation
efforts of academic and linguistic communities, and to provide a free and
open framework in which fonts may be shared and improved in partnership
with others.
The OFL allows the licensed fonts to be used, studied, modified and
redistributed freely as long as they are not sold by themselves. The
fonts, including any derivative works, can be bundled, embedded,
redistributed and/or sold with any software provided that any reserved
names are not used by derivative works. The fonts and derivatives,
however, cannot be released under any other type of license. The
requirement for fonts to remain under this license does not apply
to any document created using the fonts or their derivatives.
DEFINITIONS
"Font Software" refers to the set of files released by the Copyright
Holder(s) under this license and clearly marked as such. This may
include source files, build scripts and documentation.
"Reserved Font Name" refers to any names specified as such after the
copyright statement(s).
"Original Version" refers to the collection of Font Software components as
distributed by the Copyright Holder(s).
"Modified Version" refers to any derivative made by adding to, deleting,
or substituting -- in part or in whole -- any of the components of the
Original Version, by changing formats or by porting the Font Software to a
new environment.
"Author" refers to any designer, engineer, programmer, technical
writer or other person who contributed to the Font Software.
PERMISSION AND CONDITIONS
Permission is hereby granted, free of charge, to any person obtaining
a copy of the Font Software, to use, study, copy, merge, embed, modify,
redistribute, and sell modified and unmodified copies of the Font
Software, subject to the following conditions:
1) Neither the Font Software nor any of its individual components,
in Original or Modified Versions, may be sold by itself.
2) Original or Modified Versions of the Font Software may be bundled,
redistributed and/or sold with any software, provided that each copy
contains the above copyright notice and this license. These can be
included either as stand-alone text files, human-readable headers or
in the appropriate machine-readable metadata fields within text or
binary files as long as those fields can be easily viewed by the user.
3) No Modified Version of the Font Software may use the Reserved Font
Name(s) unless explicit written permission is granted by the corresponding
Copyright Holder. This restriction only applies to the primary font name as
presented to the users.
4) The name(s) of the Copyright Holder(s) or the Author(s) of the Font
Software shall not be used to promote, endorse or advertise any
Modified Version, except to acknowledge the contribution(s) of the
Copyright Holder(s) and the Author(s) or with their explicit written
permission.
5) The Font Software, modified or unmodified, in part or in whole,
must be distributed entirely under this license, and must not be
distributed under any other license. The requirement for fonts to
remain under this license does not apply to any document created
using the Font Software.
TERMINATION
This license becomes null and void if any of the above conditions are
not met.
DISCLAIMER
THE FONT SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT
OF COPYRIGHT, PATENT, TRADEMARK, OR OTHER RIGHT. IN NO EVENT SHALL THE
COPYRIGHT HOLDER BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
INCLUDING ANY GENERAL, SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL
DAMAGES, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF THE USE OR INABILITY TO USE THE FONT SOFTWARE OR FROM
OTHER DEALINGS IN THE FONT SOFTWARE.

View File

@@ -1,46 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""Very simple functions to fetch and print metadata from InvokeAI-generated images."""
import json
import sys
from pathlib import Path
from typing import Any, Dict
from PIL import Image
def get_invokeai_metadata(image_path: Path) -> Dict[str, Any]:
"""
Retrieve "invokeai_metadata" field from png image.
:param image_path: Path to the image to read metadata from.
May raise:
OSError -- image path not found
KeyError -- image doesn't contain the metadata field
"""
image: Image = Image.open(image_path)
return json.loads(image.text["invokeai_metadata"])
def print_invokeai_metadata(image_path: Path):
"""Pretty-print the metadata."""
try:
metadata = get_invokeai_metadata(image_path)
print(f"{image_path}:\n{json.dumps(metadata, sort_keys=True, indent=4)}")
except OSError:
print(f"{image_path}:\nNo file found.")
except KeyError:
print(f"{image_path}:\nNo metadata found.")
print()
def main():
"""Run the command-line utility."""
image_paths = sys.argv[1:]
if not image_paths:
print(f"Usage: {Path(sys.argv[0]).name} image1 image2 image3 ...")
print("\nPretty-print InvokeAI image metadata from the listed png files.")
sys.exit(-1)
for img in image_paths:
print_invokeai_metadata(img)

View File

@@ -70,6 +70,7 @@ def get_literal_fields(field) -> list[Any]:
config = InvokeAIAppConfig.get_config()
Model_dir = "models"
Default_config_file = config.model_conf_path
SD_Configs = config.legacy_conf_path
@@ -92,7 +93,7 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# or renaming it and then running invokeai-configure again.
"""
logger = InvokeAILogger.get_logger()
logger = InvokeAILogger.getLogger()
class DummyWidgetValue(Enum):
@@ -457,7 +458,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
)
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model (2GB for SD-1, 6GB for SDXL).",
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model.",
begin_entry_at=0,
editable=False,
color="CONTROL",
@@ -650,19 +651,8 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
return editApp.new_opts()
def default_ramcache() -> float:
"""Run a heuristic for the default RAM cache based on installed RAM."""
# Note that on my 64 GB machine, psutil.virtual_memory().total gives 62 GB,
# So we adjust everthing down a bit.
return (
15.0 if MAX_RAM >= 60 else 7.5 if MAX_RAM >= 30 else 4 if MAX_RAM >= 14 else 2.1
) # 2.1 is just large enough for sd 1.5 ;-)
def default_startup_options(init_file: Path) -> Namespace:
opts = InvokeAIAppConfig.get_config()
opts.ram = default_ramcache()
return opts
@@ -904,7 +894,7 @@ def main():
if opt.full_precision:
invoke_args.extend(["--precision", "float32"])
config.parse_args(invoke_args)
logger = InvokeAILogger().get_logger(config=config)
logger = InvokeAILogger().getLogger(config=config)
errors = set()

View File

@@ -2,7 +2,6 @@
Utility (backend) functions used by model_install.py
"""
import os
import re
import shutil
import warnings
from dataclasses import dataclass, field
@@ -31,7 +30,7 @@ warnings.filterwarnings("ignore")
# --------------------------globals-----------------------
config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger(name="InvokeAI")
logger = InvokeAILogger.getLogger(name="InvokeAI")
# the initial "configs" dir is now bundled in the `invokeai.configs` package
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
@@ -48,14 +47,8 @@ Config_preamble = """
LEGACY_CONFIGS = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: {
SchedulerPredictionType.Epsilon: "v1-inference.yaml",
SchedulerPredictionType.VPrediction: "v1-inference-v.yaml",
},
ModelVariantType.Inpaint: {
SchedulerPredictionType.Epsilon: "v1-inpainting-inference.yaml",
SchedulerPredictionType.VPrediction: "v1-inpainting-inference-v.yaml",
},
ModelVariantType.Normal: "v1-inference.yaml",
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: {
@@ -76,6 +69,14 @@ LEGACY_CONFIGS = {
}
@dataclass
class ModelInstallList:
"""Class for listing models to be installed/removed"""
install_models: List[str] = field(default_factory=list)
remove_models: List[str] = field(default_factory=list)
@dataclass
class InstallSelections:
install_models: List[str] = field(default_factory=list)
@@ -89,12 +90,10 @@ class ModelLoadInfo:
base_type: BaseModelType
path: Optional[Path] = None
repo_id: Optional[str] = None
subfolder: Optional[str] = None
description: str = ""
installed: bool = False
recommended: bool = False
default: bool = False
requires: Optional[List[str]] = field(default_factory=list)
class ModelInstall(object):
@@ -128,13 +127,12 @@ class ModelInstall(object):
value["name"] = name
value["base_type"] = base
value["model_type"] = model_type
model_info = ModelLoadInfo(**value)
if model_info.subfolder and model_info.repo_id:
model_info.repo_id += f":{model_info.subfolder}"
model_dict[key] = model_info
model_dict[key] = ModelLoadInfo(**value)
# supplement with entries in models.yaml
installed_models = [x for x in self.mgr.list_models()]
# suppresses autoloaded models
# installed_models = [x for x in self.mgr.list_models() if not self._is_autoloaded(x)]
for md in installed_models:
base = md["base_model"]
@@ -166,12 +164,9 @@ class ModelInstall(object):
def list_models(self, model_type):
installed = self.mgr.list_models(model_type=model_type)
print()
print(f"Installed models of type `{model_type}`:")
print(f"{'Model Key':50} Model Path")
for i in installed:
print(f"{'/'.join([i['base_model'],i['model_type'],i['model_name']]):50} {i['path']}")
print()
print(f"{i['model_name']}\t{i['base_model']}\t{i['path']}")
# logic here a little reversed to maintain backward compatibility
def starter_models(self, all_models: bool = False) -> Set[str]:
@@ -209,8 +204,6 @@ class ModelInstall(object):
job += 1
# add requested models
self._remove_installed(selections.install_models)
self._add_required_models(selections.install_models)
for path in selections.install_models:
logger.info(f"Installing {path} [{job}/{jobs}]")
try:
@@ -270,26 +263,6 @@ class ModelInstall(object):
return models_installed
def _remove_installed(self, model_list: List[str]):
all_models = self.all_models()
for path in model_list:
key = self.reverse_paths.get(path)
if key and all_models[key].installed:
logger.warning(f"{path} already installed. Skipping.")
model_list.remove(path)
def _add_required_models(self, model_list: List[str]):
additional_models = []
all_models = self.all_models()
for path in model_list:
if not (key := self.reverse_paths.get(path)):
continue
for requirement in all_models[key].requires:
requirement_key = self.reverse_paths.get(requirement)
if not all_models[requirement_key].installed:
additional_models.append(requirement)
model_list.extend(additional_models)
# install a model from a local path. The optional info parameter is there to prevent
# the model from being probed twice in the event that it has already been probed.
def _install_path(self, path: Path, info: ModelProbeInfo = None) -> AddModelResult:
@@ -313,7 +286,7 @@ class ModelInstall(object):
location = download_with_resume(url, Path(staging))
if not location:
logger.error(f"Unable to download {url}. Skipping.")
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
info = ModelProbe().heuristic_probe(location)
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
dest.parent.mkdir(parents=True, exist_ok=True)
models_path = shutil.move(location, dest)
@@ -322,63 +295,46 @@ class ModelInstall(object):
return self._install_path(Path(models_path), info)
def _install_repo(self, repo_id: str) -> AddModelResult:
# hack to recover models stored in subfolders --
# Required to get the "v2" model of monster-labs/control_v1p_sd15_qrcode_monster
subfolder = None
if match := re.match(r"^([^/]+/[^/]+):(\w+)$", repo_id):
repo_id = match.group(1)
subfolder = match.group(2)
hinfo = HfApi().model_info(repo_id)
# we try to figure out how to download this most economically
# list all the files in the repo
files = [x.rfilename for x in hinfo.siblings]
if subfolder:
files = [x for x in files if x.startswith("v2/")]
prefix = f"{subfolder}/" if subfolder else ""
location = None
with TemporaryDirectory(dir=self.config.models_path) as staging:
staging = Path(staging)
if f"{prefix}model_index.json" in files:
location = self._download_hf_pipeline(repo_id, staging, subfolder=subfolder) # pipeline
elif f"{prefix}unet/model.onnx" in files:
if "model_index.json" in files:
location = self._download_hf_pipeline(repo_id, staging) # pipeline
elif "unet/model.onnx" in files:
location = self._download_hf_model(repo_id, files, staging)
else:
for suffix in ["safetensors", "bin"]:
if f"{prefix}pytorch_lora_weights.{suffix}" in files:
location = self._download_hf_model(
repo_id, ["pytorch_lora_weights.bin"], staging, subfolder=subfolder
) # LoRA
if f"pytorch_lora_weights.{suffix}" in files:
location = self._download_hf_model(repo_id, ["pytorch_lora_weights.bin"], staging) # LoRA
break
elif (
self.config.precision == "float16" and f"{prefix}diffusion_pytorch_model.fp16.{suffix}" in files
self.config.precision == "float16" and f"diffusion_pytorch_model.fp16.{suffix}" in files
): # vae, controlnet or some other standalone
files = ["config.json", f"diffusion_pytorch_model.fp16.{suffix}"]
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
location = self._download_hf_model(repo_id, files, staging)
break
elif f"{prefix}diffusion_pytorch_model.{suffix}" in files:
elif f"diffusion_pytorch_model.{suffix}" in files:
files = ["config.json", f"diffusion_pytorch_model.{suffix}"]
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
location = self._download_hf_model(repo_id, files, staging)
break
elif f"{prefix}learned_embeds.{suffix}" in files:
location = self._download_hf_model(
repo_id, [f"learned_embeds.{suffix}"], staging, subfolder=subfolder
)
elif f"learned_embeds.{suffix}" in files:
location = self._download_hf_model(repo_id, [f"learned_embeds.{suffix}"], staging)
break
elif (
f"{prefix}image_encoder.txt" in files and f"{prefix}ip_adapter.{suffix}" in files
): # IP-Adapter
elif "image_encoder.txt" in files and f"ip_adapter.{suffix}" in files: # IP-Adapter
files = ["image_encoder.txt", f"ip_adapter.{suffix}"]
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
location = self._download_hf_model(repo_id, files, staging)
break
elif f"{prefix}model.{suffix}" in files and f"{prefix}config.json" in files:
elif f"model.{suffix}" in files and "config.json" in files:
# This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted
# by InvokeAI for use with IP-Adapters.
files = ["config.json", f"model.{suffix}"]
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
location = self._download_hf_model(repo_id, files, staging)
break
if not location:
logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
@@ -437,7 +393,7 @@ class ModelInstall(object):
possible_conf = path.with_suffix(".yaml")
if possible_conf.exists():
legacy_conf = str(self.relative_to_root(possible_conf))
elif info.base_type in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
elif info.base_type == BaseModelType.StableDiffusion2:
legacy_conf = Path(
self.config.legacy_conf_dir,
LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type],
@@ -465,9 +421,9 @@ class ModelInstall(object):
else:
return path
def _download_hf_pipeline(self, repo_id: str, staging: Path, subfolder: str = None) -> Path:
def _download_hf_pipeline(self, repo_id: str, staging: Path) -> Path:
"""
Retrieve a StableDiffusion model from cache or remote and then
This retrieves a StableDiffusion model from cache or remote and then
does a save_pretrained() to the indicated staging area.
"""
_, name = repo_id.split("/")
@@ -482,7 +438,6 @@ class ModelInstall(object):
variant=variant,
torch_dtype=precision,
safety_checker=None,
subfolder=subfolder,
)
except Exception as e: # most errors are due to fp16 not being present. Fix this to catch other errors
if "fp16" not in str(e):
@@ -497,7 +452,7 @@ class ModelInstall(object):
model.save_pretrained(staging / name, safe_serialization=True)
return staging / name
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path:
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path) -> Path:
_, name = repo_id.split("/")
location = staging / name
paths = list()
@@ -508,7 +463,7 @@ class ModelInstall(object):
model_dir=location / filePath.parent,
model_name=filePath.name,
access_token=self.access_token,
subfolder=filePath.parent / subfolder if subfolder else filePath.parent,
subfolder=filePath.parent,
)
if p:
paths.append(p)
@@ -537,7 +492,7 @@ def yes_or_no(prompt: str, default_yes=True):
# ---------------------------------------------
def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
logger = InvokeAILogger.get_logger("InvokeAI")
logger = InvokeAILogger.getLogger("InvokeAI")
logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage())
model = model_class.from_pretrained(

View File

@@ -42,4 +42,4 @@ IP-Adapters:
- [InvokeAI/ip_adapter_plus_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_sd15)
- [InvokeAI/ip_adapter_plus_face_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15)
- [InvokeAI/ip_adapter_sdxl](https://huggingface.co/InvokeAI/ip_adapter_sdxl)
- [InvokeAI/ip_adapter_sdxl_vit_h](https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h)
- Not yet supported: [InvokeAI/ip_adapter_sdxl_vit_h](https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h)

View File

@@ -9,8 +9,6 @@ from diffusers.models import UNet2DConditionModel
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.model_management.models.base import calc_model_size_by_data
from .attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
from .resampler import Resampler
@@ -89,20 +87,6 @@ class IPAdapter:
if self._attn_processors is not None:
torch.nn.ModuleList(self._attn_processors.values()).to(device=self.device, dtype=self.dtype)
def calc_size(self):
if self._state_dict is not None:
image_proj_size = sum(
[tensor.nelement() * tensor.element_size() for tensor in self._state_dict["image_proj"].values()]
)
ip_adapter_size = sum(
[tensor.nelement() * tensor.element_size() for tensor in self._state_dict["ip_adapter"].values()]
)
return image_proj_size + ip_adapter_size
else:
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(
torch.nn.ModuleList(self._attn_processors.values())
)
def _init_image_proj_model(self, state_dict):
return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)

View File

@@ -74,7 +74,7 @@ if is_accelerate_available():
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
logger = InvokeAILogger.get_logger(__name__)
logger = InvokeAILogger.getLogger(__name__)
CONVERT_MODEL_ROOT = InvokeAIAppConfig.get_config().models_path / "core/convert"
@@ -1279,12 +1279,12 @@ def download_from_original_stable_diffusion_ckpt(
extract_ema = original_config["model"]["params"]["use_ema"]
if (
model_version in [BaseModelType.StableDiffusion2, BaseModelType.StableDiffusion1]
model_version == BaseModelType.StableDiffusion2
and original_config["model"]["params"].get("parameterization") == "v"
):
prediction_type = "v_prediction"
upcast_attention = True
image_size = 768 if model_version == BaseModelType.StableDiffusion2 else 512
image_size = 768
else:
prediction_type = "epsilon"
upcast_attention = False

View File

@@ -1,5 +1,4 @@
import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, Literal, Optional, Union
@@ -54,7 +53,6 @@ class ModelProbe(object):
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"AutoencoderKL": ModelType.Vae,
"AutoencoderTiny": ModelType.Vae,
"ControlNetModel": ModelType.ControlNet,
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
}
@@ -90,7 +88,8 @@ class ModelProbe(object):
to place it somewhere in the models directory hierarchy. If the model is
already loaded into memory, you may provide it as model in order to avoid
opening it a second time. The prediction_type_helper callable is a function that receives
the path to the model and returns the SchedulerPredictionType.
the path to the model and returns the BaseModelType. It is called to distinguish
between V2-Base and V2-768 SD models.
"""
if model_path:
format_type = "diffusers" if model_path.is_dir() else "checkpoint"
@@ -178,7 +177,6 @@ class ModelProbe(object):
Get the model type of a hugging-face style folder.
"""
class_name = None
error_hint = None
if model:
class_name = model.__class__.__name__
else:
@@ -204,18 +202,12 @@ class ModelProbe(object):
class_name = conf["architectures"][0]
else:
class_name = None
else:
error_hint = f"No model_index.json or config.json found in {folder_path}."
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
return type
else:
error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"
# give up
raise InvalidModelException(
f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
)
raise InvalidModelException(f"Unable to determine model type for {folder_path}")
@classmethod
def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
@@ -304,36 +296,25 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
else:
raise InvalidModelException("Cannot determine base type")
def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
"""Return model prediction type."""
# if there is a .yaml associated with this checkpoint, then we do not need
# to probe for the prediction type as it will be ignored.
if self.checkpoint_path and self.checkpoint_path.with_suffix(".yaml").exists():
return None
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
type = self.get_base_type()
if type == BaseModelType.StableDiffusion2:
checkpoint = self.checkpoint
state_dict = self.checkpoint.get("state_dict") or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
if "global_step" in checkpoint:
if checkpoint["global_step"] == 220000:
return SchedulerPredictionType.Epsilon
elif checkpoint["global_step"] == 110000:
return SchedulerPredictionType.VPrediction
if self.helper and self.checkpoint_path:
if helper_guess := self.helper(self.checkpoint_path):
return helper_guess
return SchedulerPredictionType.VPrediction # a guess for sd2 ckpts
elif type == BaseModelType.StableDiffusion1:
if self.helper and self.checkpoint_path:
if helper_guess := self.helper(self.checkpoint_path):
return helper_guess
return SchedulerPredictionType.Epsilon # a reasonable guess for sd1 ckpts
else:
return None
if type == BaseModelType.StableDiffusion1:
return SchedulerPredictionType.Epsilon
checkpoint = self.checkpoint
state_dict = self.checkpoint.get("state_dict") or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
if "global_step" in checkpoint:
if checkpoint["global_step"] == 220000:
return SchedulerPredictionType.Epsilon
elif checkpoint["global_step"] == 110000:
return SchedulerPredictionType.VPrediction
if (
self.checkpoint_path and self.helper and not self.checkpoint_path.with_suffix(".yaml").exists()
): # if a .yaml config file exists, then this step not needed
return self.helper(self.checkpoint_path)
else:
return None
class VaeCheckpointProbe(CheckpointProbeBase):
@@ -480,32 +461,16 @@ class PipelineFolderProbe(FolderProbeBase):
class VaeFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
if self._config_looks_like_sdxl():
return BaseModelType.StableDiffusionXL
elif self._name_looks_like_sdxl():
# but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
# by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
return BaseModelType.StableDiffusionXL
else:
return BaseModelType.StableDiffusion1
def _config_looks_like_sdxl(self) -> bool:
# config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
config_file = self.folder_path / "config.json"
if not config_file.exists():
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
with open(config_file, "r") as file:
config = json.load(file)
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
def _name_looks_like_sdxl(self) -> bool:
return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
def _guess_name(self) -> str:
name = self.folder_path.name
if name == "vae":
name = self.folder_path.parent.name
return name
return (
BaseModelType.StableDiffusionXL
if config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
else BaseModelType.StableDiffusion1
)
class TextualInversionFolderProbe(FolderProbeBase):

View File

@@ -71,13 +71,7 @@ class ModelSearch(ABC):
if any(
[
(path / x).exists()
for x in {
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"image_encoder.txt",
}
for x in {"config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"}
]
):
try:

View File

@@ -13,7 +13,6 @@ from invokeai.backend.model_management.models.base import (
ModelConfigBase,
ModelType,
SubModelType,
calc_model_size_by_fs,
classproperty,
)
@@ -31,7 +30,7 @@ class IPAdapterModel(ModelBase):
assert model_type == ModelType.IPAdapter
super().__init__(model_path, base_model, model_type)
self.model_size = calc_model_size_by_fs(self.model_path)
self.model_size = os.path.getsize(self.model_path)
@classmethod
def detect_format(cls, path: str) -> str:
@@ -64,13 +63,10 @@ class IPAdapterModel(ModelBase):
if child_type is not None:
raise ValueError("There are no child models in an IP-Adapter model.")
model = build_ip_adapter(
return build_ip_adapter(
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
)
self.model_size = model.calc_size()
return model
@classmethod
def convert_if_required(
cls,

View File

@@ -25,71 +25,55 @@ def _conv_forward_asymmetric(self, input, weight, bias):
@contextmanager
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
def set_seamless(
model: Union[UNet2DConditionModel, AutoencoderKL],
axes: List[str],
skipped_layers: int,
skip_second_resnet: bool,
skip_conv2: bool,
):
try:
to_restore = []
for m_name, m in model.named_modules():
if isinstance(model, UNet2DConditionModel):
if ".attentions." in m_name:
if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
continue
if isinstance(model, UNet2DConditionModel) and m_name.startswith("down_blocks.") and ".resnets." in m_name:
# down_blocks.1.resnets.1.conv1
_, block_num, _, resnet_num, submodule_name = m_name.split(".")
block_num = int(block_num)
resnet_num = int(resnet_num)
# if block_num >= seamless_down_blocks:
if block_num >= len(model.down_blocks) - skipped_layers:
continue
if ".resnets." in m_name:
if ".conv2" in m_name:
continue
if ".conv_shortcut" in m_name:
continue
"""
if isinstance(model, UNet2DConditionModel):
if False and ".upsamplers." in m_name:
if resnet_num > 0 and skip_second_resnet:
continue
if False and ".downsamplers." in m_name:
if submodule_name == "conv2" and skip_conv2:
continue
if True and ".resnets." in m_name:
if True and ".conv1" in m_name:
if False and "down_blocks" in m_name:
continue
if False and "mid_block" in m_name:
continue
if False and "up_blocks" in m_name:
continue
m.asymmetric_padding_mode = {}
m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = "circular" if ("x" in axes) else "constant"
m.asymmetric_padding["x"] = (
m._reversed_padding_repeated_twice[0],
m._reversed_padding_repeated_twice[1],
0,
0,
)
m.asymmetric_padding_mode["y"] = "circular" if ("y" in axes) else "constant"
m.asymmetric_padding["y"] = (
0,
0,
m._reversed_padding_repeated_twice[2],
m._reversed_padding_repeated_twice[3],
)
if True and ".conv2" in m_name:
continue
if True and ".conv_shortcut" in m_name:
continue
if True and ".attentions." in m_name:
continue
if False and m_name in ["conv_in", "conv_out"]:
continue
"""
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
m.asymmetric_padding_mode = {}
m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
m.asymmetric_padding["x"] = (
m._reversed_padding_repeated_twice[0],
m._reversed_padding_repeated_twice[1],
0,
0,
)
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
m.asymmetric_padding["y"] = (
0,
0,
m._reversed_padding_repeated_twice[2],
m._reversed_padding_repeated_twice[3],
)
to_restore.append((m, m._conv_forward))
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
to_restore.append((m, m._conv_forward))
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
yield

View File

@@ -1,568 +0,0 @@
# pylint: disable=line-too-long
# pylint: disable=broad-exception-caught
# pylint: disable=missing-function-docstring
"""Script to peform db maintenance and outputs directory management."""
import argparse
import datetime
import enum
import glob
import locale
import os
import shutil
import sqlite3
from pathlib import Path
import PIL
import PIL.ImageOps
import PIL.PngImagePlugin
import yaml
class ConfigMapper:
"""Configuration loader."""
def __init__(self): # noqa D107
pass
TIMESTAMP_STRING = datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
INVOKE_DIRNAME = "invokeai"
YAML_FILENAME = "invokeai.yaml"
DATABASE_FILENAME = "invokeai.db"
database_path = None
database_backup_dir = None
outputs_path = None
archive_path = None
thumbnails_path = None
thumbnails_archive_path = None
def load(self):
"""Read paths from yaml config and validate."""
root = "."
if not self.__load_from_root_config(os.path.abspath(root)):
return False
return True
def __load_from_root_config(self, invoke_root):
"""Validate a yaml path exists, confirm the user wants to use it and load config."""
yaml_path = os.path.join(invoke_root, self.YAML_FILENAME)
if os.path.exists(yaml_path):
db_dir, outdir = self.__load_paths_from_yaml_file(yaml_path)
if db_dir is None or outdir is None:
print("The invokeai.yaml file was found but is missing the db_dir and/or outdir setting!")
return False
if os.path.isabs(db_dir):
self.database_path = os.path.join(db_dir, self.DATABASE_FILENAME)
else:
self.database_path = os.path.join(invoke_root, db_dir, self.DATABASE_FILENAME)
self.database_backup_dir = os.path.join(os.path.dirname(self.database_path), "backup")
if os.path.isabs(outdir):
self.outputs_path = os.path.join(outdir, "images")
self.archive_path = os.path.join(outdir, "images-archive")
else:
self.outputs_path = os.path.join(invoke_root, outdir, "images")
self.archive_path = os.path.join(invoke_root, outdir, "images-archive")
self.thumbnails_path = os.path.join(self.outputs_path, "thumbnails")
self.thumbnails_archive_path = os.path.join(self.archive_path, "thumbnails")
db_exists = os.path.exists(self.database_path)
outdir_exists = os.path.exists(self.outputs_path)
text = f"Found {self.YAML_FILENAME} file at {yaml_path}:"
text += f"\n Database : {self.database_path} - {'Exists!' if db_exists else 'Not Found!'}"
text += f"\n Outputs : {self.outputs_path}- {'Exists!' if outdir_exists else 'Not Found!'}"
print(text)
if db_exists and outdir_exists:
return True
else:
print(
"\nOne or more paths specified in invoke.yaml do not exist. Please inspect/correct the configuration and ensure the script is run in the developer console mode (option 8) from an Invoke AI root directory."
)
return False
else:
print(
f"Auto-discovery of configuration failed! Could not find ({yaml_path})!\n\nPlease ensure the script is run in the developer console mode (option 8) from an Invoke AI root directory."
)
return False
def __load_paths_from_yaml_file(self, yaml_path):
"""Load an Invoke AI yaml file and get the database and outputs paths."""
try:
with open(yaml_path, "rt", encoding=locale.getpreferredencoding()) as file:
yamlinfo = yaml.safe_load(file)
db_dir = yamlinfo.get("InvokeAI", {}).get("Paths", {}).get("db_dir", None)
outdir = yamlinfo.get("InvokeAI", {}).get("Paths", {}).get("outdir", None)
return db_dir, outdir
except Exception:
print(f"Failed to load paths from yaml file! {yaml_path}!")
return None, None
class MaintenanceStats:
"""DTO for tracking work progress."""
def __init__(self): # noqa D107
pass
time_start = datetime.datetime.utcnow()
count_orphaned_db_entries_cleaned = 0
count_orphaned_disk_files_cleaned = 0
count_orphaned_thumbnails_cleaned = 0
count_thumbnails_regenerated = 0
count_errors = 0
@staticmethod
def get_elapsed_time_string():
"""Get a friendly time string for the time elapsed since processing start."""
time_now = datetime.datetime.utcnow()
total_seconds = (time_now - MaintenanceStats.time_start).total_seconds()
hours = int((total_seconds) / 3600)
minutes = int(((total_seconds) % 3600) / 60)
seconds = total_seconds % 60
out_str = f"{hours} hour(s) -" if hours > 0 else ""
out_str += f"{minutes} minute(s) -" if minutes > 0 else ""
out_str += f"{seconds:.2f} second(s)"
return out_str
class DatabaseMapper:
"""Class to abstract database functionality."""
def __init__(self, database_path, database_backup_dir): # noqa D107
self.database_path = database_path
self.database_backup_dir = database_backup_dir
self.connection = None
self.cursor = None
def backup(self, timestamp_string):
"""Take a backup of the database."""
if not os.path.exists(self.database_backup_dir):
print(f"Database backup directory {self.database_backup_dir} does not exist -> creating...", end="")
os.makedirs(self.database_backup_dir)
print("Done!")
database_backup_path = os.path.join(self.database_backup_dir, f"backup-{timestamp_string}-invokeai.db")
print(f"Making DB Backup at {database_backup_path}...", end="")
shutil.copy2(self.database_path, database_backup_path)
print("Done!")
def connect(self):
"""Open connection to the database."""
self.connection = sqlite3.connect(self.database_path)
self.cursor = self.connection.cursor()
def get_all_image_files(self):
"""Get the full list of image file names from the database."""
sql_get_image_by_name = "SELECT image_name FROM images"
self.cursor.execute(sql_get_image_by_name)
rows = self.cursor.fetchall()
db_files = []
for row in rows:
db_files.append(row[0])
return db_files
def remove_image_file_record(self, filename: str):
"""Remove an image file reference from the database by filename."""
sanitized_filename = str.replace(filename, "'", "''") # prevent injection
sql_command = f"DELETE FROM images WHERE image_name='{sanitized_filename}'"
self.cursor.execute(sql_command)
self.connection.commit()
def does_image_exist(self, image_filename):
"""Check database if a image name already exists and return a boolean."""
sanitized_filename = str.replace(image_filename, "'", "''") # prevent injection
sql_get_image_by_name = f"SELECT image_name FROM images WHERE image_name='{sanitized_filename}'"
self.cursor.execute(sql_get_image_by_name)
rows = self.cursor.fetchall()
return True if len(rows) > 0 else False
def disconnect(self):
"""Disconnect from the db, cleaning up connections and cursors."""
if self.cursor is not None:
self.cursor.close()
if self.connection is not None:
self.connection.close()
class PhysicalFileMapper:
"""Containing class for script functionality."""
def __init__(self, outputs_path, thumbnails_path, archive_path, thumbnails_archive_path): # noqa D107
self.outputs_path = outputs_path
self.archive_path = archive_path
self.thumbnails_path = thumbnails_path
self.thumbnails_archive_path = thumbnails_archive_path
def create_archive_directories(self):
"""Create the directory for archiving orphaned image files."""
if not os.path.exists(self.archive_path):
print(f"Image archive directory ({self.archive_path}) does not exist -> creating...", end="")
os.makedirs(self.archive_path)
print("Created!")
if not os.path.exists(self.thumbnails_archive_path):
print(
f"Image thumbnails archive directory ({self.thumbnails_archive_path}) does not exist -> creating...",
end="",
)
os.makedirs(self.thumbnails_archive_path)
print("Created!")
def get_image_path_for_image_name(self, image_filename): # noqa D102
return os.path.join(self.outputs_path, image_filename)
def image_file_exists(self, image_filename): # noqa D102
return os.path.exists(self.get_image_path_for_image_name(image_filename))
def get_thumbnail_path_for_image(self, image_filename): # noqa D102
return os.path.join(self.thumbnails_path, os.path.splitext(image_filename)[0]) + ".webp"
def get_image_name_from_thumbnail_path(self, thumbnail_path): # noqa D102
return os.path.splitext(os.path.basename(thumbnail_path))[0] + ".png"
def thumbnail_exists_for_filename(self, image_filename): # noqa D102
return os.path.exists(self.get_thumbnail_path_for_image(image_filename))
def archive_image(self, image_filename): # noqa D102
if self.image_file_exists(image_filename):
image_path = self.get_image_path_for_image_name(image_filename)
shutil.move(image_path, self.archive_path)
def archive_thumbnail_by_image_filename(self, image_filename): # noqa D102
if self.thumbnail_exists_for_filename(image_filename):
thumbnail_path = self.get_thumbnail_path_for_image(image_filename)
shutil.move(thumbnail_path, self.thumbnails_archive_path)
def get_all_png_filenames_in_directory(self, directory_path): # noqa D102
filepaths = glob.glob(directory_path + "/*.png", recursive=False)
filenames = []
for filepath in filepaths:
filenames.append(os.path.basename(filepath))
return filenames
def get_all_thumbnails_with_full_path(self, thumbnails_directory): # noqa D102
return glob.glob(thumbnails_directory + "/*.webp", recursive=False)
def generate_thumbnail_for_image_name(self, image_filename): # noqa D102
# create thumbnail
file_path = self.get_image_path_for_image_name(image_filename)
thumb_path = self.get_thumbnail_path_for_image(image_filename)
thumb_size = 256, 256
with PIL.Image.open(file_path) as source_image:
source_image.thumbnail(thumb_size)
source_image.save(thumb_path, "webp")
class MaintenanceOperation(str, enum.Enum):
"""Enum class for operations."""
Ask = "ask"
CleanOrphanedDbEntries = "clean"
CleanOrphanedDiskFiles = "archive"
ReGenerateThumbnails = "thumbnails"
All = "all"
class InvokeAIDatabaseMaintenanceApp:
"""Main processor class for the application."""
_operation: MaintenanceOperation
_headless: bool = False
__stats: MaintenanceStats = MaintenanceStats()
def __init__(self, operation: MaintenanceOperation = MaintenanceOperation.Ask):
"""Initialize maintenance app."""
self._operation = MaintenanceOperation(operation)
self._headless = operation != MaintenanceOperation.Ask
def ask_for_operation(self) -> MaintenanceOperation:
"""Ask user to choose the operation to perform."""
while True:
print()
print("It is recommennded to run these operations as ordered below to avoid additional")
print("work being performed that will be discarded in a subsequent step.")
print()
print("Select maintenance operation:")
print()
print("1) Clean Orphaned Database Image Entries")
print(" Cleans entries in the database where the matching file was removed from")
print(" the outputs directory.")
print("2) Archive Orphaned Image Files")
print(" Files found in the outputs directory without an entry in the database are")
print(" moved to an archive directory.")
print("3) Re-Generate Missing Thumbnail Files")
print(" For files found in the outputs directory, re-generate a thumbnail if it")
print(" not found in the thumbnails directory.")
print()
print("(CTRL-C to quit)")
try:
input_option = int(input("Specify desired operation number (1-3): "))
operations = [
MaintenanceOperation.CleanOrphanedDbEntries,
MaintenanceOperation.CleanOrphanedDiskFiles,
MaintenanceOperation.ReGenerateThumbnails,
]
return operations[input_option - 1]
except (IndexError, ValueError):
print("\nInvalid selection!")
def ask_to_continue(self) -> bool:
"""Ask user whether they want to continue with the operation."""
while True:
input_choice = input("Do you wish to continue? (Y or N)? ")
if str.lower(input_choice) == "y":
return True
if str.lower(input_choice) == "n":
return False
def clean_orphaned_db_entries(
self, config: ConfigMapper, file_mapper: PhysicalFileMapper, db_mapper: DatabaseMapper
):
"""Clean dangling database entries that no longer point to a file in outputs."""
if self._headless:
print(f"Removing database references to images that no longer exist in {config.outputs_path}...")
else:
print()
print("===============================================================================")
print("= Clean Orphaned Database Entries")
print()
print("Perform this operation if you have removed files from the outputs/images")
print("directory but the database was never updated. You may see this as empty imaages")
print("in the app gallery, or images that only show an enlarged version of the")
print("thumbnail.")
print()
print(f"Database File Path : {config.database_path}")
print(f"Database backup will be taken at : {config.database_backup_dir}")
print(f"Outputs/Images Directory : {config.outputs_path}")
print(f"Outputs/Images Archive Directory : {config.archive_path}")
print("\nNotes about this operation:")
print("- This operation will find database image file entries that do not exist in the")
print(" outputs/images dir and remove those entries from the database.")
print("- This operation will target all image types including intermediate files.")
print("- If a thumbnail still exists in outputs/images/thumbnails matching the")
print(" orphaned entry, it will be moved to the archive directory.")
print()
if not self.ask_to_continue():
raise KeyboardInterrupt
file_mapper.create_archive_directories()
db_mapper.backup(config.TIMESTAMP_STRING)
db_mapper.connect()
db_files = db_mapper.get_all_image_files()
for db_file in db_files:
try:
if not file_mapper.image_file_exists(db_file):
print(f"Found orphaned image db entry {db_file}. Cleaning ...", end="")
db_mapper.remove_image_file_record(db_file)
print("Cleaned!")
if file_mapper.thumbnail_exists_for_filename(db_file):
print("A thumbnail was found, archiving ...", end="")
file_mapper.archive_thumbnail_by_image_filename(db_file)
print("Archived!")
self.__stats.count_orphaned_db_entries_cleaned += 1
except Exception as ex:
print("An error occurred cleaning db entry, error was:")
print(ex)
self.__stats.count_errors += 1
def clean_orphaned_disk_files(
self, config: ConfigMapper, file_mapper: PhysicalFileMapper, db_mapper: DatabaseMapper
):
"""Archive image files that no longer have entries in the database."""
if self._headless:
print(f"Archiving orphaned image files to {config.archive_path}...")
else:
print()
print("===============================================================================")
print("= Clean Orphaned Disk Files")
print()
print("Perform this operation if you have files that were copied into the outputs")
print("directory which are not referenced by the database. This can happen if you")
print("upgraded to a version with a fresh database, but re-used the outputs directory")
print("and now new images are mixed with the files not in the db. The script will")
print("archive these files so you can choose to delete them or re-import using the")
print("official import script.")
print()
print(f"Database File Path : {config.database_path}")
print(f"Database backup will be taken at : {config.database_backup_dir}")
print(f"Outputs/Images Directory : {config.outputs_path}")
print(f"Outputs/Images Archive Directory : {config.archive_path}")
print("\nNotes about this operation:")
print("- This operation will find image files not referenced by the database and move to an")
print(" archive directory.")
print("- This operation will target all image types including intermediate references.")
print("- The matching thumbnail will also be archived.")
print("- Any remaining orphaned thumbnails will also be archived.")
if not self.ask_to_continue():
raise KeyboardInterrupt
print()
file_mapper.create_archive_directories()
db_mapper.backup(config.TIMESTAMP_STRING)
db_mapper.connect()
phys_files = file_mapper.get_all_png_filenames_in_directory(config.outputs_path)
for phys_file in phys_files:
try:
if not db_mapper.does_image_exist(phys_file):
print(f"Found orphaned file {phys_file}, archiving...", end="")
file_mapper.archive_image(phys_file)
print("Archived!")
if file_mapper.thumbnail_exists_for_filename(phys_file):
print("Related thumbnail exists, archiving...", end="")
file_mapper.archive_thumbnail_by_image_filename(phys_file)
print("Archived!")
else:
print("No matching thumbnail existed to be cleaned.")
self.__stats.count_orphaned_disk_files_cleaned += 1
except Exception as ex:
print("Error found trying to archive file or thumbnail, error was:")
print(ex)
self.__stats.count_errors += 1
thumb_filepaths = file_mapper.get_all_thumbnails_with_full_path(config.thumbnails_path)
# archive any remaining orphaned thumbnails
for thumb_filepath in thumb_filepaths:
try:
thumb_src_image_name = file_mapper.get_image_name_from_thumbnail_path(thumb_filepath)
if not file_mapper.image_file_exists(thumb_src_image_name):
print(f"Found orphaned thumbnail {thumb_filepath}, archiving...", end="")
file_mapper.archive_thumbnail_by_image_filename(thumb_src_image_name)
print("Archived!")
self.__stats.count_orphaned_thumbnails_cleaned += 1
except Exception as ex:
print("Error found trying to archive thumbnail, error was:")
print(ex)
self.__stats.count_errors += 1
def regenerate_thumbnails(self, config: ConfigMapper, file_mapper: PhysicalFileMapper, *args):
"""Create missing thumbnails for any valid general images both in the db and on disk."""
if self._headless:
print("Regenerating missing image thumbnails...")
else:
print()
print("===============================================================================")
print("= Regenerate Thumbnails")
print()
print("This operation will find files that have no matching thumbnail on disk")
print("and regenerate those thumbnail files.")
print("NOTE: It is STRONGLY recommended that the user first clean/archive orphaned")
print(" disk files from the previous menu to avoid wasting time regenerating")
print(" thumbnails for orphaned files.")
print()
print(f"Outputs/Images Directory : {config.outputs_path}")
print(f"Outputs/Images Directory : {config.thumbnails_path}")
print("\nNotes about this operation:")
print("- This operation will find image files both referenced in the db and on disk")
print(" that do not have a matching thumbnail on disk and re-generate the thumbnail")
print(" file.")
if not self.ask_to_continue():
raise KeyboardInterrupt
print()
phys_files = file_mapper.get_all_png_filenames_in_directory(config.outputs_path)
for phys_file in phys_files:
try:
if not file_mapper.thumbnail_exists_for_filename(phys_file):
print(f"Found file without thumbnail {phys_file}...Regenerating Thumbnail...", end="")
file_mapper.generate_thumbnail_for_image_name(phys_file)
print("Done!")
self.__stats.count_thumbnails_regenerated += 1
except Exception as ex:
print("Error found trying to regenerate thumbnail, error was:")
print(ex)
self.__stats.count_errors += 1
def main(self): # noqa D107
print("\n===============================================================================")
print("Database and outputs Maintenance for Invoke AI 3.0.0 +")
print("===============================================================================\n")
config_mapper = ConfigMapper()
if not config_mapper.load():
print("\nInvalid configuration...exiting.\n")
return
file_mapper = PhysicalFileMapper(
config_mapper.outputs_path,
config_mapper.thumbnails_path,
config_mapper.archive_path,
config_mapper.thumbnails_archive_path,
)
db_mapper = DatabaseMapper(config_mapper.database_path, config_mapper.database_backup_dir)
op = self._operation
operations_to_perform = []
if op == MaintenanceOperation.Ask:
op = self.ask_for_operation()
if op in [MaintenanceOperation.CleanOrphanedDbEntries, MaintenanceOperation.All]:
operations_to_perform.append(self.clean_orphaned_db_entries)
if op in [MaintenanceOperation.CleanOrphanedDiskFiles, MaintenanceOperation.All]:
operations_to_perform.append(self.clean_orphaned_disk_files)
if op in [MaintenanceOperation.ReGenerateThumbnails, MaintenanceOperation.All]:
operations_to_perform.append(self.regenerate_thumbnails)
for operation in operations_to_perform:
operation(config_mapper, file_mapper, db_mapper)
print("\n===============================================================================")
print(f"= Maintenance Complete - Elapsed Time: {MaintenanceStats.get_elapsed_time_string()}")
print()
print(f"Orphaned db entries cleaned : {self.__stats.count_orphaned_db_entries_cleaned}")
print(f"Orphaned disk files archived : {self.__stats.count_orphaned_disk_files_cleaned}")
print(f"Orphaned thumbnail files archived : {self.__stats.count_orphaned_thumbnails_cleaned}")
print(f"Thumbnails regenerated : {self.__stats.count_thumbnails_regenerated}")
print(f"Errors during operation : {self.__stats.count_errors}")
print()
def main(): # noqa D107
parser = argparse.ArgumentParser(
description="InvokeAI image database maintenance utility",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""Operations:
ask Choose operation from a menu [default]
all Run all maintenance operations
clean Clean database of dangling entries
archive Archive orphaned image files
thumbnails Regenerate missing image thumbnails
""",
)
parser.add_argument("--root", default=".", type=Path, help="InvokeAI root directory")
parser.add_argument(
"--operation", default="ask", choices=[x.value for x in MaintenanceOperation], help="Operation to perform."
)
args = parser.parse_args()
try:
os.chdir(args.root)
app = InvokeAIDatabaseMaintenanceApp(args.operation)
app.main()
except KeyboardInterrupt:
print("\n\nUser cancelled execution.")
except FileNotFoundError:
print(f"Invalid root directory '{args.root}'.")
if __name__ == "__main__":
main()

View File

@@ -24,7 +24,7 @@ from invokeai.backend.util.logging import InvokeAILogger
# Modified ControlNetModel with encoder_attention_mask argument added
logger = InvokeAILogger.get_logger(__name__)
logger = InvokeAILogger.getLogger(__name__)
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):

View File

@@ -1,6 +1,7 @@
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
"""invokeai.backend.util.logging
"""
invokeai.backend.util.logging
Logging class for InvokeAI that produces console messages
@@ -8,9 +9,9 @@ Usage:
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger(name='InvokeAI') // Initialization
logger = InvokeAILogger.getLogger(name='InvokeAI') // Initialization
(or)
logger = InvokeAILogger.get_logger(__name__) // To use the filename
logger = InvokeAILogger.getLogger(__name__) // To use the filename
logger.configure()
logger.critical('this is critical') // Critical Message
@@ -33,13 +34,13 @@ IAILogger.debug('this is a debugging message')
## Configuration
The default configuration will print to stderr on the console. To add
additional logging handlers, call get_logger with an initialized InvokeAIAppConfig
additional logging handlers, call getLogger with an initialized InvokeAIAppConfig
object:
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger.get_logger(config=config)
logger = InvokeAILogger.getLogger(config=config)
### Three command-line options control logging:
@@ -172,7 +173,6 @@ InvokeAI:
log_level: info
log_format: color
```
"""
import logging.handlers
@@ -193,35 +193,39 @@ except ImportError:
# module level functions
def debug(msg, *args, **kwargs):
InvokeAILogger.get_logger().debug(msg, *args, **kwargs)
InvokeAILogger.getLogger().debug(msg, *args, **kwargs)
def info(msg, *args, **kwargs):
InvokeAILogger.get_logger().info(msg, *args, **kwargs)
InvokeAILogger.getLogger().info(msg, *args, **kwargs)
def warning(msg, *args, **kwargs):
InvokeAILogger.get_logger().warning(msg, *args, **kwargs)
InvokeAILogger.getLogger().warning(msg, *args, **kwargs)
def error(msg, *args, **kwargs):
InvokeAILogger.get_logger().error(msg, *args, **kwargs)
InvokeAILogger.getLogger().error(msg, *args, **kwargs)
def critical(msg, *args, **kwargs):
InvokeAILogger.get_logger().critical(msg, *args, **kwargs)
InvokeAILogger.getLogger().critical(msg, *args, **kwargs)
def log(level, msg, *args, **kwargs):
InvokeAILogger.get_logger().log(level, msg, *args, **kwargs)
InvokeAILogger.getLogger().log(level, msg, *args, **kwargs)
def disable(level=logging.CRITICAL):
InvokeAILogger.get_logger().disable(level)
InvokeAILogger.getLogger().disable(level)
def basicConfig(**kwargs):
InvokeAILogger.get_logger().basicConfig(**kwargs)
InvokeAILogger.getLogger().basicConfig(**kwargs)
def getLogger(name: str = None) -> logging.Logger:
return InvokeAILogger.getLogger(name)
_FACILITY_MAP = (
@@ -347,7 +351,7 @@ class InvokeAILogger(object):
loggers = dict()
@classmethod
def get_logger(
def getLogger(
cls, name: str = "InvokeAI", config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
) -> logging.Logger:
if name in cls.loggers:
@@ -356,13 +360,13 @@ class InvokeAILogger(object):
else:
logger = logging.getLogger(name)
logger.setLevel(config.log_level.upper()) # yes, strings work here
for ch in cls.get_loggers(config):
for ch in cls.getLoggers(config):
logger.addHandler(ch)
cls.loggers[name] = logger
return cls.loggers[name]
@classmethod
def get_loggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]:
def getLoggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]:
handler_strs = config.log_handlers
handlers = list()
for handler in handler_strs:

View File

@@ -60,9 +60,6 @@ sd-1/main/trinart_stable_diffusion_v2:
description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB)
repo_id: naclbit/trinart_stable_diffusion_v2
recommended: False
sd-1/controlnet/qrcode_monster:
repo_id: monster-labs/control_v1p_sd15_qrcode_monster
subfolder: v2
sd-1/controlnet/canny:
repo_id: lllyasviel/control_v11p_sd15_canny
recommended: True
@@ -106,35 +103,3 @@ sd-1/lora/LowRA:
recommended: True
sd-1/lora/Ink scenery:
path: https://civitai.com/api/download/models/83390
sd-1/ip_adapter/ip_adapter_sd15:
repo_id: InvokeAI/ip_adapter_sd15
recommended: True
requires:
- InvokeAI/ip_adapter_sd_image_encoder
description: IP-Adapter for SD 1.5 models
sd-1/ip_adapter/ip_adapter_plus_sd15:
repo_id: InvokeAI/ip_adapter_plus_sd15
recommended: False
requires:
- InvokeAI/ip_adapter_sd_image_encoder
description: Refined IP-Adapter for SD 1.5 models
sd-1/ip_adapter/ip_adapter_plus_face_sd15:
repo_id: InvokeAI/ip_adapter_plus_face_sd15
recommended: False
requires:
- InvokeAI/ip_adapter_sd_image_encoder
description: Refined IP-Adapter for SD 1.5 models, adapted for faces
sdxl/ip_adapter/ip_adapter_sdxl:
repo_id: InvokeAI/ip_adapter_sdxl
recommended: False
requires:
- InvokeAI/ip_adapter_sdxl_image_encoder
description: IP-Adapter for SDXL models
any/clip_vision/ip_adapter_sd_image_encoder:
repo_id: InvokeAI/ip_adapter_sd_image_encoder
recommended: False
description: Required model for using IP-Adapters with SD-1/2 models
any/clip_vision/ip_adapter_sdxl_image_encoder:
repo_id: InvokeAI/ip_adapter_sdxl_image_encoder
recommended: False
description: Required model for using IP-Adapters with SDXL models

View File

@@ -1,80 +0,0 @@
model:
base_learning_rate: 1.0e-04
target: invokeai.backend.models.diffusion.ddpm.LatentDiffusion
params:
parameterization: "v"
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: invokeai.backend.stable_diffusion.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
personalization_config:
target: invokeai.backend.stable_diffusion.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ['sculpture']
per_image_tokens: false
num_vectors_per_token: 1
progressive_words: False
unet_config:
target: invokeai.backend.stable_diffusion.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: invokeai.backend.stable_diffusion.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: invokeai.backend.stable_diffusion.encoders.modules.WeightedFrozenCLIPEmbedder

View File

@@ -14,6 +14,7 @@ import os
import re
import shutil
import sqlite3
import uuid
from pathlib import Path
import PIL
@@ -26,7 +27,6 @@ from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit.shortcuts import message_dialog
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.util.misc import uuid_string
app_config = InvokeAIAppConfig.get_config()
@@ -421,7 +421,7 @@ VALUES ('{filename}', 'internal', 'general', {width}, {height}, null, null, '{me
return rows[0][0]
else:
board_date_string = datetime.datetime.utcnow().date().isoformat()
new_board_id = uuid_string()
new_board_id = str(uuid.uuid4())
sql_insert_board = f"INSERT INTO boards (board_id, board_name, created_at, updated_at) VALUES ('{new_board_id}', '{board_name}', '{board_date_string}', '{board_date_string}')"
self.cursor.execute(sql_insert_board)
self.connection.commit()

View File

@@ -45,7 +45,7 @@ from invokeai.frontend.install.widgets import (
)
config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger()
logger = InvokeAILogger.getLogger()
# build a table mapping all non-printable characters to None
# for stripping control characters
@@ -101,12 +101,11 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
"STARTER MODELS",
"MAIN MODELS",
"CONTROLNETS",
"IP-ADAPTERS",
"LORA/LYCORIS",
"TEXTUAL INVERSION",
],
value=[self.current_tab],
columns=6,
columns=5,
max_height=2,
relx=8,
scroll_exit=True,
@@ -131,13 +130,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
)
bottom_of_table = max(bottom_of_table, self.nextrely)
self.nextrely = top_of_table
self.ipadapter_models = self.add_model_widgets(
model_type=ModelType.IPAdapter,
window_width=window_width,
)
bottom_of_table = max(bottom_of_table, self.nextrely)
self.nextrely = top_of_table
self.lora_models = self.add_model_widgets(
model_type=ModelType.Lora,
@@ -351,7 +343,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self.starter_pipelines,
self.pipeline_models,
self.controlnet_models,
self.ipadapter_models,
self.lora_models,
self.ti_models,
]
@@ -541,7 +532,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self.starter_pipelines,
self.pipeline_models,
self.controlnet_models,
self.ipadapter_models,
self.lora_models,
self.ti_models,
]
@@ -563,25 +553,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
if downloads := section.get("download_ids"):
selections.install_models.extend(downloads.value.split())
# NOT NEEDED - DONE IN BACKEND NOW
# # special case for the ipadapter_models. If any of the adapters are
# # chosen, then we add the corresponding encoder(s) to the install list.
# section = self.ipadapter_models
# if section.get("models_selected"):
# selected_adapters = [
# self.all_models[section["models"][x]].name for x in section.get("models_selected").value
# ]
# encoders = []
# if any(["sdxl" in x for x in selected_adapters]):
# encoders.append("ip_adapter_sdxl_image_encoder")
# if any(["sd15" in x for x in selected_adapters]):
# encoders.append("ip_adapter_sd_image_encoder")
# for encoder in encoders:
# key = f"any/clip_vision/{encoder}"
# repo_id = f"InvokeAI/{encoder}"
# if key not in self.all_models:
# selections.install_models.append(repo_id)
class AddModelApplication(npyscreen.NPSAppManaged):
def __init__(self, opt):
@@ -681,7 +652,7 @@ def process_and_execute(
translator = StderrToMessage(conn_out)
sys.stderr = translator
sys.stdout = translator
logger = InvokeAILogger.get_logger()
logger = InvokeAILogger.getLogger()
logger.handlers.clear()
logger.addHandler(logging.StreamHandler(translator))
@@ -794,7 +765,7 @@ def main():
if opt.full_precision:
invoke_args.extend(["--precision", "float32"])
config.parse_args(invoke_args)
logger = InvokeAILogger().get_logger(config=config)
logger = InvokeAILogger().getLogger(config=config)
if not config.model_conf_path.exists():
logger.info("Your InvokeAI root directory is not set up. Calling invokeai-configure.")

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,280 +0,0 @@
import{w as s,hY as T,v as l,a2 as I,hZ as R,ae as V,h_ as z,h$ as j,i0 as D,i1 as F,i2 as G,i3 as W,i4 as K,aG as Y,i5 as Z,i6 as H}from"./index-94062f76.js";import{M as U}from"./MantineProvider-a057bfc9.js";var P=String.raw,E=P`
:root,
:host {
--chakra-vh: 100vh;
}
@supports (height: -webkit-fill-available) {
:root,
:host {
--chakra-vh: -webkit-fill-available;
}
}
@supports (height: -moz-fill-available) {
:root,
:host {
--chakra-vh: -moz-fill-available;
}
}
@supports (height: 100dvh) {
:root,
:host {
--chakra-vh: 100dvh;
}
}
`,B=()=>s.jsx(T,{styles:E}),J=({scope:e=""})=>s.jsx(T,{styles:P`
html {
line-height: 1.5;
-webkit-text-size-adjust: 100%;
font-family: system-ui, sans-serif;
-webkit-font-smoothing: antialiased;
text-rendering: optimizeLegibility;
-moz-osx-font-smoothing: grayscale;
touch-action: manipulation;
}
body {
position: relative;
min-height: 100%;
margin: 0;
font-feature-settings: "kern";
}
${e} :where(*, *::before, *::after) {
border-width: 0;
border-style: solid;
box-sizing: border-box;
word-wrap: break-word;
}
main {
display: block;
}
${e} hr {
border-top-width: 1px;
box-sizing: content-box;
height: 0;
overflow: visible;
}
${e} :where(pre, code, kbd,samp) {
font-family: SFMono-Regular, Menlo, Monaco, Consolas, monospace;
font-size: 1em;
}
${e} a {
background-color: transparent;
color: inherit;
text-decoration: inherit;
}
${e} abbr[title] {
border-bottom: none;
text-decoration: underline;
-webkit-text-decoration: underline dotted;
text-decoration: underline dotted;
}
${e} :where(b, strong) {
font-weight: bold;
}
${e} small {
font-size: 80%;
}
${e} :where(sub,sup) {
font-size: 75%;
line-height: 0;
position: relative;
vertical-align: baseline;
}
${e} sub {
bottom: -0.25em;
}
${e} sup {
top: -0.5em;
}
${e} img {
border-style: none;
}
${e} :where(button, input, optgroup, select, textarea) {
font-family: inherit;
font-size: 100%;
line-height: 1.15;
margin: 0;
}
${e} :where(button, input) {
overflow: visible;
}
${e} :where(button, select) {
text-transform: none;
}
${e} :where(
button::-moz-focus-inner,
[type="button"]::-moz-focus-inner,
[type="reset"]::-moz-focus-inner,
[type="submit"]::-moz-focus-inner
) {
border-style: none;
padding: 0;
}
${e} fieldset {
padding: 0.35em 0.75em 0.625em;
}
${e} legend {
box-sizing: border-box;
color: inherit;
display: table;
max-width: 100%;
padding: 0;
white-space: normal;
}
${e} progress {
vertical-align: baseline;
}
${e} textarea {
overflow: auto;
}
${e} :where([type="checkbox"], [type="radio"]) {
box-sizing: border-box;
padding: 0;
}
${e} input[type="number"]::-webkit-inner-spin-button,
${e} input[type="number"]::-webkit-outer-spin-button {
-webkit-appearance: none !important;
}
${e} input[type="number"] {
-moz-appearance: textfield;
}
${e} input[type="search"] {
-webkit-appearance: textfield;
outline-offset: -2px;
}
${e} input[type="search"]::-webkit-search-decoration {
-webkit-appearance: none !important;
}
${e} ::-webkit-file-upload-button {
-webkit-appearance: button;
font: inherit;
}
${e} details {
display: block;
}
${e} summary {
display: list-item;
}
template {
display: none;
}
[hidden] {
display: none !important;
}
${e} :where(
blockquote,
dl,
dd,
h1,
h2,
h3,
h4,
h5,
h6,
hr,
figure,
p,
pre
) {
margin: 0;
}
${e} button {
background: transparent;
padding: 0;
}
${e} fieldset {
margin: 0;
padding: 0;
}
${e} :where(ol, ul) {
margin: 0;
padding: 0;
}
${e} textarea {
resize: vertical;
}
${e} :where(button, [role="button"]) {
cursor: pointer;
}
${e} button::-moz-focus-inner {
border: 0 !important;
}
${e} table {
border-collapse: collapse;
}
${e} :where(h1, h2, h3, h4, h5, h6) {
font-size: inherit;
font-weight: inherit;
}
${e} :where(button, input, optgroup, select, textarea) {
padding: 0;
line-height: inherit;
color: inherit;
}
${e} :where(img, svg, video, canvas, audio, iframe, embed, object) {
display: block;
}
${e} :where(img, video) {
max-width: 100%;
height: auto;
}
[data-js-focus-visible]
:focus:not([data-focus-visible-added]):not(
[data-focus-visible-disabled]
) {
outline: none;
box-shadow: none;
}
${e} select::-ms-expand {
display: none;
}
${E}
`}),g={light:"chakra-ui-light",dark:"chakra-ui-dark"};function Q(e={}){const{preventTransition:o=!0}=e,n={setDataset:r=>{const t=o?n.preventTransition():void 0;document.documentElement.dataset.theme=r,document.documentElement.style.colorScheme=r,t==null||t()},setClassName(r){document.body.classList.add(r?g.dark:g.light),document.body.classList.remove(r?g.light:g.dark)},query(){return window.matchMedia("(prefers-color-scheme: dark)")},getSystemTheme(r){var t;return((t=n.query().matches)!=null?t:r==="dark")?"dark":"light"},addListener(r){const t=n.query(),i=a=>{r(a.matches?"dark":"light")};return typeof t.addListener=="function"?t.addListener(i):t.addEventListener("change",i),()=>{typeof t.removeListener=="function"?t.removeListener(i):t.removeEventListener("change",i)}},preventTransition(){const r=document.createElement("style");return r.appendChild(document.createTextNode("*{-webkit-transition:none!important;-moz-transition:none!important;-o-transition:none!important;-ms-transition:none!important;transition:none!important}")),document.head.appendChild(r),()=>{window.getComputedStyle(document.body),requestAnimationFrame(()=>{requestAnimationFrame(()=>{document.head.removeChild(r)})})}}};return n}var X="chakra-ui-color-mode";function L(e){return{ssr:!1,type:"localStorage",get(o){if(!(globalThis!=null&&globalThis.document))return o;let n;try{n=localStorage.getItem(e)||o}catch{}return n||o},set(o){try{localStorage.setItem(e,o)}catch{}}}}var ee=L(X),M=()=>{};function S(e,o){return e.type==="cookie"&&e.ssr?e.get(o):o}function O(e){const{value:o,children:n,options:{useSystemColorMode:r,initialColorMode:t,disableTransitionOnChange:i}={},colorModeManager:a=ee}=e,d=t==="dark"?"dark":"light",[u,p]=l.useState(()=>S(a,d)),[y,b]=l.useState(()=>S(a)),{getSystemTheme:w,setClassName:k,setDataset:x,addListener:$}=l.useMemo(()=>Q({preventTransition:i}),[i]),v=t==="system"&&!u?y:u,c=l.useCallback(h=>{const f=h==="system"?w():h;p(f),k(f==="dark"),x(f),a.set(f)},[a,w,k,x]);I(()=>{t==="system"&&b(w())},[]),l.useEffect(()=>{const h=a.get();if(h){c(h);return}if(t==="system"){c("system");return}c(d)},[a,d,t,c]);const C=l.useCallback(()=>{c(v==="dark"?"light":"dark")},[v,c]);l.useEffect(()=>{if(r)return $(c)},[r,$,c]);const A=l.useMemo(()=>({colorMode:o??v,toggleColorMode:o?M:C,setColorMode:o?M:c,forced:o!==void 0}),[v,C,c,o]);return s.jsx(R.Provider,{value:A,children:n})}O.displayName="ColorModeProvider";var te=["borders","breakpoints","colors","components","config","direction","fonts","fontSizes","fontWeights","letterSpacings","lineHeights","radii","shadows","sizes","space","styles","transition","zIndices"];function re(e){return V(e)?te.every(o=>Object.prototype.hasOwnProperty.call(e,o)):!1}function m(e){return typeof e=="function"}function oe(...e){return o=>e.reduce((n,r)=>r(n),o)}var ne=e=>function(...n){let r=[...n],t=n[n.length-1];return re(t)&&r.length>1?r=r.slice(0,r.length-1):t=e,oe(...r.map(i=>a=>m(i)?i(a):ae(a,i)))(t)},ie=ne(j);function ae(...e){return z({},...e,_)}function _(e,o,n,r){if((m(e)||m(o))&&Object.prototype.hasOwnProperty.call(r,n))return(...t)=>{const i=m(e)?e(...t):e,a=m(o)?o(...t):o;return z({},i,a,_)}}var q=l.createContext({getDocument(){return document},getWindow(){return window}});q.displayName="EnvironmentContext";function N(e){const{children:o,environment:n,disabled:r}=e,t=l.useRef(null),i=l.useMemo(()=>n||{getDocument:()=>{var d,u;return(u=(d=t.current)==null?void 0:d.ownerDocument)!=null?u:document},getWindow:()=>{var d,u;return(u=(d=t.current)==null?void 0:d.ownerDocument.defaultView)!=null?u:window}},[n]),a=!r||!n;return s.jsxs(q.Provider,{value:i,children:[o,a&&s.jsx("span",{id:"__chakra_env",hidden:!0,ref:t})]})}N.displayName="EnvironmentProvider";var se=e=>{const{children:o,colorModeManager:n,portalZIndex:r,resetScope:t,resetCSS:i=!0,theme:a={},environment:d,cssVarsRoot:u,disableEnvironment:p,disableGlobalStyle:y}=e,b=s.jsx(N,{environment:d,disabled:p,children:o});return s.jsx(D,{theme:a,cssVarsRoot:u,children:s.jsxs(O,{colorModeManager:n,options:a.config,children:[i?s.jsx(J,{scope:t}):s.jsx(B,{}),!y&&s.jsx(F,{}),r?s.jsx(G,{zIndex:r,children:b}):b]})})},le=e=>function({children:n,theme:r=e,toastOptions:t,...i}){return s.jsxs(se,{theme:r,...i,children:[s.jsx(W,{value:t==null?void 0:t.defaultOptions,children:n}),s.jsx(K,{...t})]})},de=le(j);const ue=()=>l.useMemo(()=>({colorScheme:"dark",fontFamily:"'Inter Variable', sans-serif",components:{ScrollArea:{defaultProps:{scrollbarSize:10},styles:{scrollbar:{"&:hover":{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}},thumb:{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}}}}}),[]),ce=L("@@invokeai-color-mode");function he({children:e}){const{i18n:o}=Y(),n=o.dir(),r=l.useMemo(()=>ie({...Z,direction:n}),[n]);l.useEffect(()=>{document.body.dir=n},[n]);const t=ue();return s.jsx(U,{theme:t,children:s.jsx(de,{theme:r,colorModeManager:ce,toastOptions:H,children:e})})}const ve=l.memo(he);export{ve as default};

Some files were not shown because too many files have changed in this diff Show More